More typically, the functions we’ll need to take the gradient of will themselves be defined in terms of matrix and vector operations. In all of these examples, remember that we’re working with vector-to-scalar functions.
Here’s an extremely important example that shows up everywhere in machine learning. Find the gradients of:
1. $f(x) = \lVert x \rVert^2$
2. $f(x) = \lVert x \rVert$
Solution
As we did in the previous example, we can expand $f(x) = \lVert x \rVert^2$ to get

$$f(x) = x \cdot x = x_1^2 + x_2^2 + \cdots + x_n^2$$

For each $i$, $\frac{\partial f}{\partial x_i} = 2x_i$. So,

$$\nabla f(x) = \begin{bmatrix} 2x_1 \\ 2x_2 \\ \vdots \\ 2x_n \end{bmatrix} = 2x$$
Think of this as the equivalent of the “power rule” for vectors.
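To sanity-check this rule numerically, here's a minimal sketch that compares a centered finite-difference approximation of the gradient of $\lVert x \rVert^2$ against $2x$ at a made-up test point (the point and step size are arbitrary choices for illustration):

```python
import numpy as np

f = lambda x: np.dot(x, x)        # f(x) = ||x||^2
x = np.array([1.0, -2.0, 3.0])    # arbitrary test point
h = 1e-6                          # finite-difference step size

# Approximate each partial derivative with a centered finite difference.
approx = np.array([(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(3)])

print(approx)    # approximately [ 2. -4.  6.]
print(2 * x)     # exactly       [ 2. -4.  6.]
```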
There are two ways to find the gradient of $f(x) = \lVert x \rVert$: directly, or by using the chain rule. It's not immediately obvious how the chain rule should work here, so we'll start with the direct method and reason about how the chain rule may arise.
Direct method: Let's start by expanding $f(x) = \lVert x \rVert$ like we did above.

$$f(x) = \sqrt{x \cdot x} = (x \cdot x)^{1/2} = (x_1^2 + x_2^2 + \cdots + x_n^2)^{1/2}$$
For each $i$, the (regular, scalar-to-scalar) chain rule tells us that

$$\frac{\partial f}{\partial x_i} = \frac{1}{2}(x_1^2 + x_2^2 + \cdots + x_n^2)^{-1/2} \cdot 2x_i = \frac{x_i}{\lVert x \rVert}$$

so, collecting the partial derivatives into a vector,

$$\nabla f(x) = \frac{1}{\lVert x \rVert}\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} = \frac{x}{\lVert x \rVert}$$
Chain rule method: Let me start by writing $f(x)$ as a composition of two functions.

$$f(x) = \lVert x \rVert = \sqrt{\lVert x \rVert^2} = h(g(x))$$

where $g(x) = \lVert x \rVert^2$ and $h(x) = \sqrt{x}$. Note that $g: \mathbb{R}^n \to \mathbb{R}$ is the vector-to-scalar function we found the gradient of above, and $h: \mathbb{R} \to \mathbb{R}$ is a scalar-to-scalar function.
Then, generalizing the calculation we did with the first method, we have a "chain rule" for a function $h(g(x))$ (where $h$ is scalar-to-scalar and $g$ is vector-to-scalar):

$$\nabla f(x) = h'(g(x)) \, \nabla g(x)$$

where $h'(g(x)) = \frac{dh}{dx}(g(x))$ is the ordinary scalar derivative of $h$, evaluated at $g(x)$.

Remember that $h(x) = \sqrt{x}$, so $\frac{dh}{dx}(x) = \frac{1}{2\sqrt{x}}$ and $\frac{dh}{dx}(g(x)) = \frac{1}{2\sqrt{g(x)}} = \frac{1}{2\lVert x \rVert}$. This means

$$\nabla f(x) = \frac{1}{2\lVert x \rVert} \nabla g(x) = \frac{1}{2\lVert x \rVert} \cdot 2x = \frac{x}{\lVert x \rVert}$$

which matches the result of the direct method.
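The same finite-difference check as before (again at an arbitrary test point) confirms that $\nabla f(x) = \frac{x}{\lVert x \rVert}$:

```python
import numpy as np

f = lambda x: np.linalg.norm(x)   # f(x) = ||x||
x = np.array([1.0, -2.0, 3.0])    # arbitrary test point
h = 1e-6

# Centered finite differences, one coordinate at a time.
approx = np.array([(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(3)])

print(approx)                     # approximately x / ||x||
print(x / np.linalg.norm(x))      # [ 0.2673 -0.5345  0.8018]
```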
If $x \in \mathbb{R}^n$, we can define the log-sum-exp function as

$$f(x) = \log\left(\sum_{i=1}^n e^{x_i}\right)$$

What is $\nabla f(x)$? (The answer is called the softmax function, and comes up all the time in machine learning, when we want our models to output predicted probabilities in a classification problem.)
Solution
Let's look at the partial derivative with respect to each $x_i$:

$$\frac{\partial f}{\partial x_i} = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}$$

Stacking these into a vector gives

$$\nabla f(x) = \frac{1}{\sum_{j=1}^n e^{x_j}}\begin{bmatrix} e^{x_1} \\ e^{x_2} \\ \vdots \\ e^{x_n} \end{bmatrix}$$

There isn't really a way to simplify this expression further using matrix-vector operations, so I'll leave it as-is. As mentioned above, the gradient we're looking at is called the softmax function. The softmax function maps $\mathbb{R}^n \to \mathbb{R}^n$, meaning it's a vector-to-vector function.
Let's suppose we have the vector $\begin{bmatrix} 3 \\ 5 \\ -1 \end{bmatrix}$. What does passing it through the softmax function yield? Working it out, we get approximately $\begin{bmatrix} 0.119 \\ 0.879 \\ 0.002 \end{bmatrix}$.
The output vector has the same number of elements as the input vector, but each element is between 0 and 1, and the sum of the elements is 1, meaning that we can interpret the outputted vector as a probability distribution. Larger values in the output correspond to larger values in the input, and almost all of the "mass" is concentrated at the maximum element of the input vector (position 2), hence the name "soft" max. (The "hard" max might be $\begin{bmatrix} 0 \\ 1 \\ 0 \end{bmatrix}$ in this case.)
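Here's a short numpy sketch of that computation. The `softmax` helper below is written just for illustration; subtracting the maximum entry before exponentiating is a standard trick to avoid overflow and doesn't change the result:

```python
import numpy as np

def softmax(x):
    """Exponentiate each entry, then normalize so the entries sum to 1."""
    shifted = np.exp(x - np.max(x))   # subtract the max for numerical stability
    return shifted / shifted.sum()

x = np.array([3.0, 5.0, -1.0])
print(softmax(x))        # approximately [0.119 0.879 0.002]
print(softmax(x).sum())  # 1.0 (up to floating point)
```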
If $A$ is an $n \times n$ matrix, the function

$$f(x) = x^TAx$$

is called a quadratic form, and its gradient is given by

$$\nabla f(x) = (A + A^T)x$$
We won’t directly cover the proof of this formula here; one place to find it is here. Instead, we’ll focus our energy on understanding how it works, since it’s extremely important.
Let $A = \begin{bmatrix} a & b \\ c & d \end{bmatrix}$. Expand out $f(x) = x^TAx$ and compute $\nabla f(x)$ directly by computing partial derivatives, and verify that the result you get matches the formula above.

In quadratic forms, we typically assume that $A$ is symmetric, meaning that $A = A^T$. Why do you think this assumption is made (what does it help with)?

Hint: Let $A = \begin{bmatrix} 3 & 2 \\ 6 & 1 \end{bmatrix}$ and $B = \begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix}$. Compute $\nabla(x^TAx)$ and $\nabla(x^TBx)$.

If $A$ is any symmetric $n \times n$ matrix, what is $\nabla f(x)$?
Suppose $A$ is symmetric and $n \times n$, $b \in \mathbb{R}^n$, and $c \in \mathbb{R}$. Find the gradient of

$$f(x) = x^TAx + b \cdot x + c$$
For a particular quadratic form, there are infinitely many choices of matrices $A$ that represent it. To illustrate, let's look at $A = \begin{bmatrix} 3 & 2 \\ 6 & 1 \end{bmatrix}$ and $B = \begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix}$ as provided in the hint.

Note that both $x^TAx$ and $x^TBx$ are equal to the expression $3x_1^2 + 8x_1x_2 + x_2^2$. In fact, any matrix of the form $\begin{bmatrix} 3 & b \\ c & 1 \end{bmatrix}$ where $b + c = 8$ would produce the same quadratic form.
So, to avoid this issue of having infinitely many choices of the matrix $A$, we pick the symmetric one, where $A = A^T$. As we're about to see, this choice of $A$ simplifies the calculation of the gradient.
If $A$ is any symmetric $n \times n$ matrix, then $A = A^T$, and $A + A^T = 2A$. So,

$$\nabla(x^TAx) = (A + A^T)x = 2Ax$$
This is also an important rule; don’t forget it.
Think of $f(x) = x^TAx + b \cdot x + c$ as the matrix-vector equivalent of a quadratic function, $ax^2 + bx + c$. The derivative of $ax^2 + bx + c$ is $2ax + b$. Check out what the gradient of $f(x)$ ends up being!
$$\begin{aligned} f(x) &= x^TAx + b \cdot x + c \\ \nabla f(x) &= (A + A^T)x + b \\ &= 2Ax + b \qquad \text{(since } A \text{ is symmetric)} \end{aligned}$$
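As a quick numerical sanity check, here's the same finite-difference sketch as before applied to $f(x) = x^TAx + b \cdot x + c$, using a small symmetric example (the specific $A$, $b$, $c$, and test point are arbitrary choices; they happen to match the example coming up next):

```python
import numpy as np

A = np.array([[3.0, 4.0],
              [4.0, 1.0]])          # symmetric
b = np.array([1.0, 2.0])
c = 3.0

f = lambda x: x @ A @ x + b @ x + c
x = np.array([0.5, -1.5])           # arbitrary test point
h = 1e-6

# Centered finite differences, one coordinate at a time.
approx = np.array([(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(2)])

print(approx)            # approximately 2Ax + b
print(2 * A @ x + b)     # [-8.  3.]
```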
These are the core rules you need to know moving forward, not just because we’re about to use them in an important proof, but because they’ll come up repeatedly in your future machine learning work.
| Function | Name | Gradient |
| --- | --- | --- |
| $f(x) = a \cdot x$ | dot product | $\nabla f(x) = a$ |
| $f(x) = \lVert x \rVert^2$ | squared norm | $\nabla f(x) = 2x$ |
| $f(x) = x^TAx$ | quadratic form | $\nabla f(x) = (A + A^T)x$; if $A$ is symmetric, $\nabla f(x) = 2Ax$ |
We'll now put these gradient rules to use, first to find the critical points of a vector-to-scalar function, and then to minimize mean squared error.
In the calculus of scalar-to-scalar functions, we have a well-understood procedure for finding the extrema of a function. The general strategy is to take the derivative, set it to zero, and solve for the inputs (called critical points) that satisfy that condition. To be thorough, we’d perform a second derivative test to check whether each critical point is a maximum, minimum, or neither.
In the land of vector-to-scalar functions, the equivalent is to solve for where the gradient is zero, which corresponds to finding where all partial derivatives are zero. Assessing whether we’ve arrived at a maximum or minimum is more difficult to do in the vector-to-scalar case, and we will save a discussion of this for Chapter 8.5.
As an example, consider
$$f(x) = x^T\begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix}x + \begin{bmatrix} 1 \\ 2 \end{bmatrix} \cdot x + 3$$
As we computed earlier, the gradient of $f(x) = x^TAx + b \cdot x + c$ is $\nabla f(x) = 2Ax + b$ for symmetric $A$. So,
$$\nabla f(x) = 2\begin{bmatrix} 3 & 4 \\ 4 & 1 \end{bmatrix}x + \begin{bmatrix} 1 \\ 2 \end{bmatrix} = \begin{bmatrix} 6x_1 + 8x_2 + 1 \\ 8x_1 + 2x_2 + 2 \end{bmatrix}$$
To find the critical points, we set the gradient to zero and solve the resulting system. We can also accomplish this by using the inverse of $A$, if we happen to have it:

$$\nabla f(x) = 0 \implies 2Ax + b = 0 \implies x^* = -\frac{1}{2}A^{-1}b$$
Either way, we find that $x^* = \begin{bmatrix} -7/26 \\ 1/13 \end{bmatrix}$ satisfies $\nabla f(x^*) = 0$, so it is a critical point of $f$. (It turns out that this particular critical point is a saddle point rather than a minimum; we'll see how to check that in Chapter 8.5.)
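Here's a quick numpy sketch of that calculation, solving $(2A)x^* = -b$ directly rather than forming $A^{-1}$ explicitly:

```python
import numpy as np

A = np.array([[3.0, 4.0],
              [4.0, 1.0]])
b = np.array([1.0, 2.0])

# Solve 2Ax* + b = 0, i.e. (2A)x* = -b.
x_star = np.linalg.solve(2 * A, -b)
print(x_star)                # approximately [-0.2692  0.0769], i.e. [-7/26, 1/13]
print(2 * A @ x_star + b)    # essentially [0. 0.]: the gradient vanishes here
```

The code below visualizes the surface of $f$ and annotates this critical point.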
```python
import numpy as np
import plotly.graph_objs as go

# Define the function f(x₁, x₂) = 3x₁² + 8x₁x₂ + x₂² + x₁ + 2x₂ + 3.
def f(x, y):
    return 3 * x ** 2 + 8 * x * y + y ** 2 + x + 2 * y + 3

# plot_gradient_on_surface is a plotting helper used in these notes
# (not defined in this snippet).
fig = plot_gradient_on_surface(
    f=f,
    lim=5,
    xaxis_title='x₁',
    yaxis_title='x₂',
    zaxis_title='f(x₁, x₂)',
    title='',
    dfx1=lambda x, y: 6 * x + 8 * y + 1,
    dfx2=lambda x, y: 8 * x + 2 * y + 2,
    point=np.array([-7/26, 1/13]),
)
fig.update_layout(title='', scene_camera=dict(eye=dict(x=1, y=2, z=2)))

# Annotate the critical point.
x_star = -7/26
y_star = 1/13
z_star = f(x_star, y_star)

fig.add_trace(
    go.Scatter3d(
        x=[x_star],
        y=[y_star],
        z=[z_star],
        mode='markers+text',
        marker=dict(size=12, color='gold', symbol='circle'),
        text=["critical point"],
        textposition="top center",
        textfont=dict(color='gold', size=14),
        name="critical point",
    )
)
```
Remember, the goal of this section is to minimize mean squared error,
$$R_\text{sq}(w) = \frac{1}{n} \lVert y - Xw \rVert^2$$

In the general case, $X$ is an $n \times (d+1)$ matrix, $y \in \mathbb{R}^n$, and $w \in \mathbb{R}^{d+1}$.
We're now equipped with the tools to minimize $R_\text{sq}(w)$ by taking its gradient and setting it to zero. Hopefully, we end up with the same conditions on $w^*$ that we derived in Chapter 6.3.
In the most recent example we saw, the critical point $x^*$ we found was not actually a minimum. We know that we won't run into such an issue here, since $R_\text{sq}(w)$ cannot output a negative number (it is the average of squared losses). Its outputs are bounded below by 0, meaning that there will be some global minimizer $w^*$.
Let's start by rewriting the squared norm as a dot product, and eventually in terms of matrix multiplication:

$$\begin{aligned} R_\text{sq}(w) &= \frac{1}{n}\lVert y - Xw \rVert^2 \\ &= \frac{1}{n}(y - Xw) \cdot (y - Xw) \\ &= \frac{1}{n}(y - Xw)^T(y - Xw) \\ &= \frac{1}{n}\left(y^Ty - y^TXw - w^TX^Ty + w^TX^TXw\right) \end{aligned}$$
Let's focus on the two cross terms, $y^TXw$ and $w^TX^Ty$. They are equal: both are the dot product of $y$ and $Xw$. Ideally, I want to express each term as a dot product of $w$ with something, since I'm taking the gradient with respect to $w$. Remember, the dot product is a scalar, and the transpose of a scalar is just that same scalar. So,
$$y^TXw = (y^TXw)^T = w^TX^Ty = w^T(X^Ty)$$
so performing this substitution for both cross terms gives us

$$R_\text{sq}(w) = \frac{1}{n}\left(y^Ty - 2\,w^T(X^Ty) + w^TX^TXw\right)$$

Now we can take the gradient term by term, using the rules from the table above. The first term, $y^Ty$, doesn't involve $w$, so its gradient is $0$. The second term is $-2$ times the dot product of $w$ and $X^Ty$, so its gradient is $-2X^Ty$. The third term is a quadratic form in $w$ with the symmetric matrix $X^TX$, so its gradient is $2X^TXw$. Putting these together,

$$\nabla R_\text{sq}(w) = \frac{1}{n}\left(-2X^Ty + 2X^TXw\right) = \frac{2}{n}\left(X^TXw - X^Ty\right)$$

Finally, to find the minimizer $w^*$, we set the gradient to zero and solve:
$$\frac{2}{n}\left(X^TXw^* - X^Ty\right) = 0 \implies X^TXw^* = X^Ty$$
Stop me if this feels familiar... these are the normal equations once again! It shouldn't be a surprise that we ended up with the same conditions on $w^*$ that we derived in Chapter 6.3, since we were solving the same problem.
We've now shown that the minimizer of

$$R_\text{sq}(w) = \frac{1}{n}\lVert y - Xw \rVert^2$$

is given by solving $X^TXw^* = X^Ty$. These equations have a unique solution if $X^TX$ is invertible, and infinitely many solutions otherwise. If $w^*$ satisfies the normal equations, then $Xw^*$ is the vector in $\text{colsp}(X)$ that is closest to $y$. All of that interpretation from Chapter 6.3 and Chapter 7 carries over; we've just introduced a new way of finding the solution.
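Here's a minimal numpy sketch of this procedure on a tiny made-up dataset (the numbers are arbitrary, chosen just for illustration). It solves $X^TXw^* = X^Ty$ directly, and compares against `np.linalg.lstsq`, which minimizes $\lVert y - Xw \rVert^2$ and also handles the case where $X^TX$ is not invertible:

```python
import numpy as np

# A tiny made-up dataset: 5 observations, an intercept column plus one feature.
X = np.array([[1.0, 0.0],
              [1.0, 1.0],
              [1.0, 2.0],
              [1.0, 3.0],
              [1.0, 4.0]])
y = np.array([1.0, 3.0, 4.0, 4.0, 6.0])

# Solve the normal equations X^T X w* = X^T y directly.
w_star = np.linalg.solve(X.T @ X, X.T @ y)

# np.linalg.lstsq minimizes ||y - Xw||^2, even when X^T X is not invertible.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print(w_star)    # approximately [1.4 1.1]
print(w_lstsq)   # same coefficients
```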
Heads up: In Homework 9, you'll follow similar steps to minimize a new objective function that resembles $R_\text{sq}(w)$ but involves another term. There, you'll minimize

$$R_\text{ridge}(w) = \lVert y - Xw \rVert^2 + \lambda \lVert w \rVert^2$$

where $\lambda > 0$ is a constant, called the regularization hyperparameter. (Notice the missing $\frac{1}{n}$.) A good way to practice what you've learned (and to get a head start on the homework) is to compute the gradient of $R_\text{ridge}(w)$ and set it to zero. We'll walk through the significance of $R_\text{ridge}(w)$ in the homework.