
8.2. Gradients of Matrix-Vector Operations

More typically, the functions we’ll need to take the gradient of will themselves be defined in terms of matrix and vector operations. In all of these examples, remember that we’re working with vector-to-scalar functions.

Example: Dot Product

Let $\vec a \in \mathbb{R}^n$ be some fixed vector (the equivalent of a constant in this context). Let's find the gradient of

$$f(\vec x) = \vec a \cdot \vec x$$

I find it helpful to think about $f(\vec x)$ in its expanded form,

$$f(\vec x) = \vec a \cdot \vec x = a_1 x_1 + a_2 x_2 + \cdots + a_n x_n$$

Remember, $\nabla f(\vec x)$ contains all of the partial derivatives of $f$, which we now need to compute. Each $x_i$ appears in exactly one term, $a_i x_i$, so $\frac{\partial f}{\partial x_i} = a_i$ for every $i$.

Putting these together, we get

$$\nabla f(\vec x) = \begin{bmatrix} \frac{\partial f}{\partial x_1} \\ \frac{\partial f}{\partial x_2} \\ \vdots \\ \frac{\partial f}{\partial x_n} \end{bmatrix} = \begin{bmatrix} a_1 \\ a_2 \\ \vdots \\ a_n \end{bmatrix} = \vec a$$
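As a quick sanity check (not part of the derivation above), we can compare $\vec a$ against a finite-difference approximation of the gradient. Here's a sketch using NumPy with made-up vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(5)   # the fixed "constant" vector
x = rng.standard_normal(5)   # the point at which we evaluate the gradient

def f(v):
    return a @ v             # f(x) = a . x

# Central finite differences: (f(x + h e_i) - f(x - h e_i)) / (2h) ≈ ∂f/∂x_i
h = 1e-6
numerical_grad = np.array([
    (f(x + h * e) - f(x - h * e)) / (2 * h)
    for e in np.eye(len(x))
])

print(np.allclose(numerical_grad, a))  # True: the gradient really is a
```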

Example: Norm and Chain Rule

Here’s an extremely important example that shows up everywhere in machine learning. Find the gradients of:

  1. $f(\vec x) = \lVert \vec x \rVert^2$

  2. $f(\vec x) = \lVert \vec x \rVert$
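To check your answers to these two exercises without being told the closed forms, you can compare them against a finite-difference approximation at a specific point. The sketch below (my own helper, not from the text) prints the numerical gradients at $\vec x = (3, 4)$; see whether your formulas produce the same vectors.

```python
import numpy as np

def numerical_gradient(f, x, h=1e-6):
    """Approximate the gradient of a vector-to-scalar function f at x."""
    return np.array([
        (f(x + h * e) - f(x - h * e)) / (2 * h)
        for e in np.eye(len(x))
    ])

x = np.array([3.0, 4.0])   # a sample point, chosen so that ||x|| = 5

print(numerical_gradient(lambda v: np.dot(v, v), x))        # gradient of ||x||^2 at x
print(numerical_gradient(lambda v: np.linalg.norm(v), x))   # gradient of ||x|| at x
# Compare these printed vectors against the formulas you derive by hand.
```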

Example: Norm to an Exponent

Find the gradient of $f(\vec x) = \lVert \vec x \rVert^p$, where $p$ is some real number.

Example: Log Sum Exp

If $\vec x \in \mathbb{R}^n$, we can define the log sum exp function as

$$f(\vec x) = \log \left( \sum_{i=1}^n e^{x_i} \right)$$

What is $\nabla f(\vec x)$? (The answer is called the softmax function, and it comes up all the time in machine learning when we want our models to output predicted probabilities in a classification problem.)
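If you'd like to spot-check your answer numerically, here's a sketch (mine, not the text's): it compares a finite-difference gradient of log sum exp against the softmax vector, using the standard definition $e^{x_i} / \sum_j e^{x_j}$.

```python
import numpy as np

def log_sum_exp(x):
    return np.log(np.sum(np.exp(x)))

def softmax(x):
    # standard definition: softmax(x)_i = e^{x_i} / sum_j e^{x_j}
    e = np.exp(x)
    return e / e.sum()

x = np.array([0.5, -1.0, 2.0])

# Central finite-difference gradient of log sum exp at x.
h = 1e-6
num_grad = np.array([
    (log_sum_exp(x + h * e) - log_sum_exp(x - h * e)) / (2 * h)
    for e in np.eye(len(x))
])

print(np.allclose(num_grad, softmax(x)))  # True: the gradient is the softmax of x
```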

Example: Quadratic Forms

Suppose $\vec x \in \mathbb{R}^n$ and $A$ is an $n \times n$ matrix. The function

$$f(\vec x) = \vec x^T A \vec x$$

is called a quadratic form, and its gradient is given by

$$\nabla f(\vec x) = (A + A^T) \vec x$$

We won’t directly cover the proof of this formula here; one place to find it is here. Instead, we’ll focus our energy on understanding how it works, since it’s extremely important.
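Before working through the exercises below, it can help to see the formula hold numerically for a random, deliberately non-symmetric matrix. Here's a quick sketch using NumPy (my own check, not part of the text):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.standard_normal((n, n))   # deliberately not symmetric
x = rng.standard_normal(n)

def f(v):
    return v @ A @ v              # the quadratic form x^T A x

# Central finite-difference gradient of f at x.
h = 1e-6
num_grad = np.array([
    (f(x + h * e) - f(x - h * e)) / (2 * h)
    for e in np.eye(n)
])

print(np.allclose(num_grad, (A + A.T) @ x))  # True: matches (A + A^T) x
```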

  1. Let $A = \begin{bmatrix} a & b \\ c & d \end{bmatrix}$. Expand out $f(\vec x) = \vec x^T A \vec x$, compute $\nabla f(\vec x)$ directly by taking partial derivatives, and verify that the result matches the formula above.

  2. In quadratic forms, we typically assume that $A$ is symmetric, meaning that $A = A^T$. Why do you think this assumption is made (what does it help with)?

    • Hint: Let $A = \begin{bmatrix} 3 & 2 \\ 6 & 1 \end{bmatrix}$ and $B = \begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix}$. Compute $\nabla (\vec x^T A \vec x)$ and $\nabla (\vec x^T B \vec x)$. (A numerical version of this comparison appears just after this list.)

  3. If $A$ is any symmetric $n \times n$ matrix, what is $\nabla f(\vec x)$?

  4. Suppose $A$ is symmetric and $n \times n$, $\vec b \in \mathbb{R}^n$, and $c \in \mathbb{R}$. Find the gradient of

    $f(\vec x) = \vec x^T A \vec x + \vec b \cdot \vec x + c$
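Here is the numerical comparison promised in the hint for part 2 (a sketch of my own, using the formula above): it evaluates both gradients at an arbitrary point so you can compare the results with your hand computation.

```python
import numpy as np

# The two matrices from the hint in part 2.
A = np.array([[3.0, 2.0],
              [6.0, 1.0]])
B = np.array([[3.0, 4.0],
              [4.0, 1.0]])

x = np.array([1.5, -2.0])   # an arbitrary test point

# Using the formula above, the gradients are (A + A^T) x and (B + B^T) x.
grad_A = (A + A.T) @ x
grad_B = (B + B.T) @ x

print(grad_A, grad_B)                  # compare the two printed vectors
print(np.allclose(grad_A, grad_B))     # are they the same?
```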

Summary of Important Gradient Rules

These are the core rules you need to know moving forward, not just because we’re about to use them in an important proof, but because they’ll come up repeatedly in your future machine learning work.

| Function | Name | Gradient |
| --- | --- | --- |
| $f(\vec x) = \vec a \cdot \vec x$ | dot product | $\nabla f(\vec x) = \vec a$ |
| $f(\vec x) = \lVert \vec x \rVert^2$ | squared norm | $\nabla f(\vec x) = 2\vec x$ |
| $f(\vec x) = \vec x^T A \vec x$ | quadratic form | $\nabla f(\vec x) = (A + A^T) \vec x$; if $A$ is symmetric, $\nabla f(\vec x) = 2A \vec x$ |

With these gradient rules in hand, we can now put them to work: first to find critical points of vector-to-scalar functions in general, and then to minimize mean squared error.


Optimization

In the calculus of scalar-to-scalar functions, we have a well-understood procedure for finding the extrema of a function. The general strategy is to take the derivative, set it to zero, and solve for the inputs (called critical points) that satisfy that condition. To be thorough, we’d perform a second derivative test to check whether each critical point is a maximum, minimum, or neither.

In the land of vector-to-scalar functions, the equivalent is to solve for where the gradient is zero, which corresponds to finding where all partial derivatives are zero. Assessing whether we’ve arrived at a maximum or minimum is more difficult to do in the vector-to-scalar case, and we will save a discussion of this for Chapter 8.5.

As an example, consider

$$f(\vec x) = \vec x^T \begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix} \vec x + \begin{bmatrix} 1 \\ 2 \end{bmatrix} \cdot \vec x + 3$$

As we computed earlier, the gradient of $f(\vec x) = \vec x^T A \vec x + \vec b \cdot \vec x + c$ is $\nabla f(\vec x) = 2A \vec x + \vec b$ for symmetric $A$. So,

$$\nabla f(\vec x) = 2 \begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix} \vec x + \begin{bmatrix} 1 \\ 2 \end{bmatrix} = \begin{bmatrix} 6x_1 + 8x_2 + 1 \\ 8x_1 + 2x_2 + 2 \end{bmatrix}$$

To find the critical points, we set the gradient to zero and solve the resulting system. We can also accomplish this by using the inverse of $A$, if we happen to have it:

$$\nabla f(\vec x) = \vec 0 \implies 2 A \vec x + \vec b = \vec 0 \implies \vec x^* = -\frac{1}{2}A^{-1} \vec b$$

Either way, we find that $\vec x^* = \begin{bmatrix} -7/26 \\ 1/13 \end{bmatrix}$ satisfies $\nabla f(\vec x^*) = \vec 0$, so it is a critical point of $f$. (It turns out that this particular critical point is a saddle point rather than a minimum; we'll see how to classify critical points like this in Chapter 8.5.)
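If you'd like to check this computation in code, here's a sketch using NumPy (not part of the text's derivation). It solves for $\vec x^*$, confirms the gradient vanishes there, and looks at the eigenvalues of $2A$, which reveal that $f$ curves upward in some directions and downward in others at $\vec x^*$, which is what makes it a saddle point:

```python
import numpy as np

A = np.array([[3.0, 4.0],
              [4.0, 1.0]])
b = np.array([1.0, 2.0])

# Critical point: solve 2 A x = -b, i.e. x* = -(1/2) A^{-1} b.
x_star = np.linalg.solve(2 * A, -b)
print(x_star)                          # [-7/26, 1/13] ≈ [-0.269, 0.077]

# The gradient 2 A x + b should be (numerically) zero at x*.
print(2 * A @ x_star + b)              # ≈ [0, 0]

# The matrix of second derivatives of f is 2A; its eigenvalues have mixed
# signs, which is why x* is a saddle point rather than a minimum.
print(np.linalg.eigvalsh(2 * A))       # one negative, one positive eigenvalue
```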


Minimizing Mean Squared Error

Remember, the goal of this section is to minimize mean squared error,

$$R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2$$

In the general case, $X$ is an $n \times (d + 1)$ matrix, $\vec y \in \mathbb{R}^n$, and $\vec w \in \mathbb{R}^{d+1}$.

We're now equipped with the tools to minimize $R_\text{sq}(\vec w)$ by taking its gradient and setting it to zero. Hopefully, we end up with the same conditions on $\vec w^*$ that we derived in Chapter 6.3.

In the most recent example, the critical point $\vec x^*$ we found turned out to be a saddle point, not a minimum. We won't run into that issue here: $R_\text{sq}(\vec w)$ can never output a negative number (it is the average of squared losses), so it is bounded below, and it does attain a global minimum at some $\vec w^*$.

Let's start by rewriting the squared norm as a dot product, and then in terms of matrix multiplication.

$$
\begin{align*}
R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2 &= \frac{1}{n} (\vec y - X \vec w) \cdot (\vec y - X \vec w) \\
&= \underbrace{\frac{1}{n} (\vec y - X \vec w)^T (\vec y - X \vec w)}_{\text{since } \vec u \cdot \vec v = \vec u^T \vec v} \\
&= \frac{1}{n} \left( \vec y^T - (X \vec w)^T \right) (\vec y - X \vec w) \\
&= \frac{1}{n} \left( \vec y^T \vec y - {\color{orange}\vec y^T X \vec w} - {\color{orange}(X \vec w)^T \vec y} + (X \vec w)^T X \vec w \right)
\end{align*}
$$

Let's focus on the two terms in orange. They are equal: each is the dot product of $\vec y$ and $X \vec w$. Ideally, I want to express each term as a dot product of $\vec w$ with something, since I'm taking the gradient with respect to $\vec w$. Remember, the dot product is a scalar, and the transpose of a scalar is just that same scalar. So,

$$\vec y^T X \vec w = (\vec y^T X \vec w)^T = \vec w^T X^T \vec y = \vec w^T (X^T \vec y)$$

so, performing this substitution for both orange terms gives us

$$
\begin{align*}
R_\text{sq}(\vec w) &= \frac{1}{n} \left( \vec y^T \vec y - {\color{orange}\vec w^T (X^T \vec y)} - {\color{orange}\vec w^T (X^T \vec y)} + \vec w^T (X^T X) \vec w \right) \\
&= \frac{1}{n} \left( \vec y^T \vec y - 2 \vec w^T (X^T \vec y) + \vec w^T (X^T X) \vec w \right)
\end{align*}
$$
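If the algebra above feels error-prone, a quick numerical comparison can reassure you that the expanded expression really is the same function we started with. Here's a sketch with made-up data (my own check, not part of the text):

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 8, 3
X = np.column_stack([np.ones(n), rng.standard_normal((n, d))])  # n x (d+1) design matrix
y = rng.standard_normal(n)
w = rng.standard_normal(d + 1)

original = np.linalg.norm(y - X @ w) ** 2 / n
expanded = (y @ y - 2 * w @ (X.T @ y) + w @ (X.T @ X) @ w) / n

print(np.isclose(original, expanded))  # True: the two expressions agree
```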

Now, we’re ready to take the gradient, which we’ll do term by term.

  • $\nabla \left( \vec y^T \vec y \right) = \vec 0$, since $\vec y^T \vec y$ is a constant with respect to $\vec w$

  • $\nabla \left( 2 \vec w^T (X^T \vec y) \right) = 2 X^T \vec y$, using the dot product rule, since this is the dot product between $2X^T \vec y$ (a vector) and $\vec w$ (a vector)

  • $\nabla \left( \vec w^T (X^T X) \vec w \right) = 2X^T X \vec w$, using the quadratic form rule, since $X^T X$ is a symmetric matrix

Plugging these terms in gives us

$$
\begin{align*}
R_\text{sq}(\vec w) &= \frac{1}{n} \left( \vec y^T \vec y - 2 \vec w^T (X^T \vec y) + \vec w^T (X^T X) \vec w \right) \\
\nabla R_\text{sq}(\vec w) &= \frac{1}{n} \left( \nabla \left(\vec y^T \vec y \right) - \nabla \left( 2 \vec w^T (X^T \vec y) \right) + \nabla \left( \vec w^T (X^T X) \vec w \right) \right) \\
&= \frac{1}{n} \left( \vec 0 - 2 X^T \vec y + 2X^T X \vec w \right) \\
&= \boxed{\frac{2}{n} \left(X^T X \vec w - X^T \vec y\right)}
\end{align*}
$$
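Here's a similar sanity check for the boxed gradient (again with made-up data, not part of the text): compare the formula against a finite-difference approximation of $\nabla R_\text{sq}(\vec w)$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d = 8, 3
X = np.column_stack([np.ones(n), rng.standard_normal((n, d))])
y = rng.standard_normal(n)
w = rng.standard_normal(d + 1)

def R_sq(w):
    return np.linalg.norm(y - X @ w) ** 2 / n

# The boxed formula for the gradient.
analytic_grad = (2 / n) * (X.T @ X @ w - X.T @ y)

# Central finite differences, coordinate by coordinate.
h = 1e-6
numerical_grad = np.array([
    (R_sq(w + h * e) - R_sq(w - h * e)) / (2 * h)
    for e in np.eye(d + 1)
])

print(np.allclose(analytic_grad, numerical_grad))  # True
```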

Finally, to find the minimizer $\vec w^*$, we set the gradient to zero and solve.

$$\frac{2}{n} \left(X^T X \vec w^* - X^T \vec y\right) = \vec 0 \implies X^TX \vec w^* = X^T \vec y$$

Stop me if this feels familiar... these are the normal equations once again! It shouldn't be a surprise that we ended up with the same conditions on $\vec w^*$ that we derived in Chapter 6.3, since we were solving the same problem.

We’ve now shown that the minimizer of

$$R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2$$

is given by solving $X^TX \vec w^* = X^T \vec y$. These equations have a unique solution if $X^TX$ is invertible, and infinitely many solutions otherwise. If $\vec w^*$ satisfies the normal equations, then $X \vec w^*$ is the vector in $\text{colsp}(X)$ that is closest to $\vec y$. All of that interpretation from Chapter 6.3 and Chapter 7 carries over; we've just introduced a new way of finding the solution.
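In code, solving the normal equations might look like the sketch below (made-up data again, not from the text). NumPy's np.linalg.lstsq minimizes $\lVert \vec y - X \vec w \rVert^2$ directly, so when $X^TX$ is invertible the two approaches should agree:

```python
import numpy as np

rng = np.random.default_rng(4)
n, d = 20, 3
X = np.column_stack([np.ones(n), rng.standard_normal((n, d))])
y = rng.standard_normal(n)

# Solve the normal equations X^T X w* = X^T y.
w_star = np.linalg.solve(X.T @ X, X.T @ y)

# NumPy's least squares solver minimizes ||y - Xw||^2 directly.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print(np.allclose(w_star, w_lstsq))  # True (here, X^T X is invertible)
```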

Heads up: In Homework 9, you'll follow similar steps to minimize a new objective function that resembles $R_\text{sq}(\vec w)$ but involves an additional term. There, you'll minimize

$$R_\text{ridge}(\vec w) = \lVert \vec y - X \vec w \rVert^2 + \lambda \lVert \vec w \rVert^2$$

where $\lambda > 0$ is a constant, called the regularization hyperparameter. (Notice the missing $\frac{1}{n}$.) A good way to practice what you've learned (and to get a head start on the homework) is to compute the gradient of $R_\text{ridge}(\vec w)$ and set it to zero. We'll walk through the significance of $R_\text{ridge}(\vec w)$ in the homework.
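If you want a way to sanity-check the closed-form answer you derive, without spoiling it here, one option is to minimize $R_\text{ridge}(\vec w)$ numerically and compare the result against what your formula gives. The sketch below is my own and assumes scipy is available; the data is made up.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(5)
n, d = 20, 3
X = np.column_stack([np.ones(n), rng.standard_normal((n, d))])
y = rng.standard_normal(n)
lam = 0.1   # the regularization hyperparameter lambda

def R_ridge(w):
    return np.linalg.norm(y - X @ w) ** 2 + lam * np.linalg.norm(w) ** 2

# Minimize numerically; compare result.x to the w* your closed-form answer produces.
result = minimize(R_ridge, x0=np.zeros(d + 1))
print(result.x)
```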