While gradient descent can be used to (attempt to) minimize any differentiable function $f$, we typically use it to minimize empirical risk functions, $R(\vec{w})$. As I said in Chapter 8.3, gradient descent is the tool used in practice for finding optimal model parameters. This is because most empirical risk functions in practice don’t have closed-form solutions, i.e. a formula for $\vec{w}^*$ that we can derive algebraically by hand.
Simple Linear Regression
Let’s try using gradient descent to fit a linear regression model - that is, let’s use it to minimize

$$R_{\text{sq}}(\vec{w}) = \frac{1}{n} \lVert \vec{y} - X \vec{w} \rVert^2$$

This function has a closed-form minimizer, $\vec{w}^* = (X^T X)^{-1} X^T \vec{y}$, so we don’t need gradient descent. Still, it’s worthwhile to see how gradient descent works on it.

In Chapter 8.2, we found that the gradient of $R_{\text{sq}}$ is

$$\nabla R_{\text{sq}}(\vec{w}) = -\frac{2}{n} X^T (\vec{y} - X \vec{w})$$

so the update rule is

$$\vec{w}^{(t+1)} = \vec{w}^{(t)} - \alpha \nabla R_{\text{sq}}\left(\vec{w}^{(t)}\right)$$
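To make the update rule concrete, here's a single step computed on a tiny made-up dataset of two points (the data, initial guess, and step size here are chosen purely for illustration, so the arithmetic can be checked by hand):

```python
import numpy as np

# Tiny made-up dataset: two points, so we can verify the arithmetic by hand.
X = np.array([[1.0, 1.0],   # intercept column of 1s, then the x values
              [1.0, 2.0]])
y = np.array([3.0, 5.0])
n = len(y)

w = np.zeros(2)   # initial guess w^(0) = [0, 0]
alpha = 0.1       # step size

# Gradient of the MSE at w: (2/n) X^T (Xw - y).
# Here: Xw - y = [-3, -5], so the gradient is [-8, -13].
gradient = (2 / n) * (X.T @ (X @ w - y))

# One gradient descent step: w^(1) = w^(0) - alpha * gradient = [0.8, 1.3].
w_next = w - alpha * gradient
print(w_next)
```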
Let’s start by using gradient descent to fit a simple linear regression model to predict commute times in minutes from departure_hour – a problem we’ve solved many times.
First, for reference, we’ll compute $\vec{w}^*$ using the closed-form solution, $\vec{w}^* = (X^T X)^{-1} X^T \vec{y}$:
import numpy as np

x = df["departure_hour"].to_numpy(dtype=float)
y = df["minutes"].to_numpy(dtype=float)
n = len(df)

# Design matrix: a column of 1s for the intercept, then departure_hour.
X = np.column_stack([np.ones(n), x])

# Solve the normal equations, X^T X w = X^T y.
w_star_closed_form = np.linalg.solve(X.T @ X, X.T @ y)
w_star_closed_form

array([142.44824159,  -8.18694172])

The code below implements gradient descent. At each iteration, it computes the MSE and its gradient, and logs the current vector $\vec{w}^{(t)}$. That current vector is sometimes called an “iterate”. Every 500 iterations, it displays the current values of the MSE, the norm of the gradient vector, and $\vec{w}^{(t)}$. I’ve collapsed the code since it’s relatively lengthy.
import pandas as pd

def mse(w):
    return np.mean((y - X @ w) ** 2)

def grad(w):
    return (2 / n) * (X.T @ (X @ w - y))

def run_gradient_descent(
    w0,
    alpha,
    loss_fn,
    grad_fn,
    tol=1e-2,
    max_iter=50000,
    record_every=500,
    log=True,
    loss_name="loss",
    param_names=None,
):
    w = np.array(w0, dtype=float)
    if param_names is None:
        param_names = [f"w{i}" for i in range(len(w))]
    rows = []
    for t in range(max_iter + 1):
        gradient = grad_fn(w)
        grad_norm = np.linalg.norm(gradient)
        current_loss = loss_fn(w)
        should_record = (t % record_every == 0) or (grad_norm < tol) or (t == max_iter)
        if should_record:
            row = {"t": t, loss_name: current_loss, "grad_norm": grad_norm}
            row.update({name: value for name, value in zip(param_names, w)})
            rows.append(row)
            if log:
                params_str = ", ".join(f"{value:8.3f}" for value in w)
                print(
                    f"t = {t:>5d} | w^({t}) = [{params_str}] | "
                    f"{loss_name} = {current_loss:9.4f} | ||grad|| = {grad_norm:9.6f}"
                )
        # The most relevant pieces are below.
        # --------------------------------
        if grad_norm < tol:
            break
        w = w - alpha * gradient
        # --------------------------------
    return pd.DataFrame(rows)
alpha = 0.01
history = run_gradient_descent(
w0=np.zeros(2),
alpha=alpha,
loss_fn=mse,
grad_fn=grad,
tol=1e-2,
max_iter=50000,
record_every=500,
loss_name="mse",
param_names=["w0", "w1"],
log=True
)
history.round(4);
t = 0 | w^(0) = [ 0.000, 0.000] | mse = 5523.5231 | ||grad|| = 1229.842640
t = 500 | w^(500) = [ 19.804, 6.102] | mse = 314.8511 | ||grad|| = 3.527944
t = 1000 | w^(1000) = [ 36.133, 4.200] | mse = 260.7135 | ||grad|| = 3.058219
t = 1500 | w^(1500) = [ 50.289, 2.551] | mse = 220.0323 | ||grad|| = 2.651035
t = 2000 | w^(2000) = [ 62.559, 1.121] | mse = 189.4630 | ||grad|| = 2.298065
t = 2500 | w^(2500) = [ 73.196, -0.118] | mse = 166.4919 | ||grad|| = 1.992091
t = 3000 | w^(3000) = [ 82.416, -1.193] | mse = 149.2306 | ||grad|| = 1.726856
t = 3500 | w^(3500) = [ 90.409, -2.124] | mse = 136.2598 | ||grad|| = 1.496935
t = 4000 | w^(4000) = [ 97.338, -2.931] | mse = 126.5130 | ||grad|| = 1.297627
t = 4500 | w^(4500) = [ 103.344, -3.631] | mse = 119.1889 | ||grad|| = 1.124856
t = 5000 | w^(5000) = [ 108.551, -4.237] | mse = 113.6852 | ||grad|| = 0.975088
t = 5500 | w^(5500) = [ 113.064, -4.763] | mse = 109.5496 | ||grad|| = 0.845261
t = 6000 | w^(6000) = [ 116.976, -5.219] | mse = 106.4419 | ||grad|| = 0.732719
t = 6500 | w^(6500) = [ 120.368, -5.614] | mse = 104.1067 | ||grad|| = 0.635162
t = 7000 | w^(7000) = [ 123.308, -5.957] | mse = 102.3519 | ||grad|| = 0.550594
t = 7500 | w^(7500) = [ 125.856, -6.254] | mse = 101.0333 | ||grad|| = 0.477285
t = 8000 | w^(8000) = [ 128.065, -6.511] | mse = 100.0424 | ||grad|| = 0.413738
t = 8500 | w^(8500) = [ 129.980, -6.734] | mse = 99.2978 | ||grad|| = 0.358651
t = 9000 | w^(9000) = [ 131.640, -6.928] | mse = 98.7383 | ||grad|| = 0.310899
t = 9500 | w^(9500) = [ 133.079, -7.095] | mse = 98.3179 | ||grad|| = 0.269504
t = 10000 | w^(10000) = [ 134.327, -7.241] | mse = 98.0020 | ||grad|| = 0.233621
t = 10500 | w^(10500) = [ 135.408, -7.367] | mse = 97.7646 | ||grad|| = 0.202516
t = 11000 | w^(11000) = [ 136.345, -7.476] | mse = 97.5862 | ||grad|| = 0.175552
t = 11500 | w^(11500) = [ 137.158, -7.571] | mse = 97.4521 | ||grad|| = 0.152179
t = 12000 | w^(12000) = [ 137.862, -7.653] | mse = 97.3514 | ||grad|| = 0.131917
t = 12500 | w^(12500) = [ 138.473, -7.724] | mse = 97.2757 | ||grad|| = 0.114353
t = 13000 | w^(13000) = [ 139.002, -7.785] | mse = 97.2188 | ||grad|| = 0.099127
t = 13500 | w^(13500) = [ 139.461, -7.839] | mse = 97.1761 | ||grad|| = 0.085929
t = 14000 | w^(14000) = [ 139.859, -7.885] | mse = 97.1440 | ||grad|| = 0.074488
t = 14500 | w^(14500) = [ 140.204, -7.925] | mse = 97.1198 | ||grad|| = 0.064571
t = 15000 | w^(15000) = [ 140.502, -7.960] | mse = 97.1017 | ||grad|| = 0.055973
t = 15500 | w^(15500) = [ 140.761, -7.990] | mse = 97.0881 | ||grad|| = 0.048521
t = 16000 | w^(16000) = [ 140.986, -8.017] | mse = 97.0778 | ||grad|| = 0.042061
t = 16500 | w^(16500) = [ 141.181, -8.039] | mse = 97.0701 | ||grad|| = 0.036460
t = 17000 | w^(17000) = [ 141.350, -8.059] | mse = 97.0644 | ||grad|| = 0.031606
t = 17500 | w^(17500) = [ 141.496, -8.076] | mse = 97.0600 | ||grad|| = 0.027398
t = 18000 | w^(18000) = [ 141.623, -8.091] | mse = 97.0567 | ||grad|| = 0.023750
t = 18500 | w^(18500) = [ 141.733, -8.104] | mse = 97.0543 | ||grad|| = 0.020588
t = 19000 | w^(19000) = [ 141.828, -8.115] | mse = 97.0524 | ||grad|| = 0.017847
t = 19500 | w^(19500) = [ 141.910, -8.124] | mse = 97.0511 | ||grad|| = 0.015470
t = 20000 | w^(20000) = [ 141.982, -8.133] | mse = 97.0500 | ||grad|| = 0.013411
t = 20500 | w^(20500) = [ 142.044, -8.140] | mse = 97.0492 | ||grad|| = 0.011625
t = 21000 | w^(21000) = [ 142.098, -8.146] | mse = 97.0486 | ||grad|| = 0.010077
t = 21027 | w^(21027) = [ 142.101, -8.146] | mse = 97.0486 | ||grad|| = 0.010000
As you can see in the first line, we started with an initial guess of $\vec{w}^{(0)} = \begin{bmatrix} 0 \\ 0 \end{bmatrix}$. We also chose the step size $\alpha = 0.01$ arbitrarily.
By the final iteration (number 21027), gradient descent has essentially matched the closed-form least-squares solution. After ~10000 iterations, the norm of the gradient vector is already close to 0, and the MSE stops changing very much.
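This agreement between gradient descent and the closed-form solution isn't specific to the commute-time data. Here's a self-contained sketch on synthetic data (the data-generating line and all constants below are made up for illustration) that runs the same plain descent loop and compares its result to `np.linalg.lstsq`:

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic linear data, made up for illustration: y ≈ 10 - 3x + noise.
x = rng.uniform(0, 5, size=200)
y = 10 - 3 * x + rng.normal(scale=0.5, size=200)
n = len(y)
X = np.column_stack([np.ones(n), x])

# Closed-form least-squares solution, for reference.
w_closed, *_ = np.linalg.lstsq(X, y, rcond=None)

# Plain gradient descent on the MSE, same update rule as in the text.
w = np.zeros(2)
alpha = 0.01
for _ in range(50000):
    gradient = (2 / n) * (X.T @ (X @ w - y))
    if np.linalg.norm(gradient) < 1e-6:
        break
    w -= alpha * gradient

print(w_closed, w)  # the two vectors should agree to several decimal places
```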
Instead of just looking at printed logs, here’s an interactive figure. Drag the slider from left to right: the left panel tracks the MSE over time, while the right panel shows the regression line corresponding to the selected iteration. The title reports the MSE of the model at that iteration number. This is the sort of figure that machine learning practitioners draw frequently when training large-scale models.
To be clear, the function we are actually minimizing doesn’t appear in either of the plots above. That function, $R_{\text{sq}}(\vec{w})$, is a vector-to-scalar function that we’d need to draw in 3D.
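One way to get a feel for that surface without a 3D plot is to evaluate the MSE on a grid of $(w_0, w_1)$ pairs: the grid cell with the smallest MSE should land near the closed-form minimizer. A sketch, using synthetic data invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data, made up for illustration: y ≈ 2 + 4x + noise.
x = rng.uniform(0, 5, size=100)
y = 2 + 4 * x + rng.normal(scale=0.3, size=100)
X = np.column_stack([np.ones_like(x), x])

# Closed-form minimizer, for reference.
w_star, *_ = np.linalg.lstsq(X, y, rcond=None)

# Evaluate the MSE surface on a grid of (w0, w1) pairs around w_star.
w0s = np.linspace(w_star[0] - 2, w_star[0] + 2, 201)
w1s = np.linspace(w_star[1] - 2, w_star[1] + 2, 201)
W0, W1 = np.meshgrid(w0s, w1s)
preds = W0[..., None] + W1[..., None] * x          # shape (201, 201, 100)
mse_surface = np.mean((y - preds) ** 2, axis=-1)   # MSE at each grid point

# The grid point with the smallest MSE sits near the closed-form solution.
i, j = np.unravel_index(np.argmin(mse_surface), mse_surface.shape)
print(W0[i, j], W1[i, j], w_star)
```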
Gradient descent is beautiful: once we can write down a model, a loss for a single example, and the gradient of the resulting average loss, we can search for optimal parameters for far more complicated model-loss combinations. Least squares is just one example. As long as the average loss is differentiable, gradient descent gives us a general recipe for finding the model’s optimal parameters.
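To illustrate that generality, here's a sketch that reuses the exact same descent loop with a different differentiable loss: the mean Huber loss, which is less sensitive to outliers than squared loss. The data, the choice of Huber loss, and all constants here are my own illustrative choices, not part of the commute-time example:

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data with a few outliers, made up for illustration: y ≈ 1 + 2x.
x = rng.uniform(0, 5, size=100)
y = 1 + 2 * x + rng.normal(scale=0.2, size=100)
y[:3] += 25  # three large outliers
X = np.column_stack([np.ones_like(x), x])
n = len(y)
delta = 1.0  # Huber transition point

def mean_huber(w):
    # Quadratic for small residuals, linear for large ones.
    r = y - X @ w
    return np.mean(np.where(np.abs(r) <= delta,
                            0.5 * r ** 2,
                            delta * (np.abs(r) - 0.5 * delta)))

def grad_mean_huber(w):
    # dR/dw = -(1/n) X^T L'(r), where L'(r) clips large residuals to ±delta.
    r = y - X @ w
    dldr = np.where(np.abs(r) <= delta, r, delta * np.sign(r))
    return -(1 / n) * (X.T @ dldr)

# The same gradient descent loop as before; only the loss/gradient changed.
w = np.zeros(2)
alpha = 0.05
for _ in range(20000):
    g = grad_mean_huber(w)
    if np.linalg.norm(g) < 1e-6:
        break
    w -= alpha * g

print(w)  # close to [1, 2], despite the outliers
```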
Issues of Scale
When I get a chance, I’ll flesh this section out more. For now, refer to this blog post by Sebastian Ruder. It discusses stochastic gradient descent, a variant of gradient descent used in practical machine learning applications.
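As a preview of the idea: instead of computing the gradient over all $n$ points at every step, stochastic (mini-batch) gradient descent estimates it from a small random sample, which is far cheaper when $n$ is large. A minimal sketch, with synthetic data and all constants chosen purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(7)

# Synthetic data, made up for illustration: y ≈ 5 - 2x + noise.
x = rng.uniform(0, 5, size=1000)
y = 5 - 2 * x + rng.normal(scale=0.5, size=1000)
X = np.column_stack([np.ones_like(x), x])

w = np.zeros(2)
alpha = 0.01
batch_size = 32

for t in range(5000):
    # Estimate the gradient of the MSE from a random mini-batch of 32 points.
    idx = rng.integers(0, len(y), size=batch_size)
    Xb, yb = X[idx], y[idx]
    gradient = (2 / batch_size) * (Xb.T @ (Xb @ w - yb))
    w -= alpha * gradient

print(w)  # roughly [5, -2], up to noise from the random batches
```

Because each step uses a noisy gradient estimate, the iterates jiggle around the minimizer rather than settling exactly on it; in practice the step size is usually decayed over time to dampen that noise.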