
8.4. Gradient Descent for Empirical Risk Minimization

While gradient descent can be used to (attempt to) minimize any differentiable function $f(\vec x)$, we typically use it to minimize empirical risk functions, $R(\vec w)$. As I said in Chapter 8.3, gradient descent is the tool used in practice for finding optimal model parameters, because most empirical risk functions in practice don’t have closed-form solutions, i.e. a formula for $\vec w^*$ that we can derive algebraically by hand.


Simple Linear Regression

Let’s try using gradient descent to fit a linear regression model – that is, let’s use it to minimize

$$R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2$$

This function has a closed-form minimizer, $\vec w^* = (X^TX)^{-1}X^T\vec y$, so we don’t need gradient descent. Still, it’s worthwhile to see how gradient descent works on it.

In Chapter 8.2, we found that the gradient of $R_\text{sq}(\vec w)$ is

$$\nabla R_\text{sq}(\vec w) = \frac{2}{n} (X^TX \vec w - X^T \vec y)$$

so, the update rule is

$$\vec w^{(t+1)} = \vec w^{(t)} - \alpha \frac{2}{n} (X^TX \vec w^{(t)} - X^T \vec y)$$
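To make the update rule concrete, here’s a minimal sketch of it in NumPy on small synthetic data. The data-generating parameters, step size, and iteration count below are hypothetical, chosen just for illustration:

```python
import numpy as np

# Hypothetical synthetic data: a noisy line, for illustration only.
rng = np.random.default_rng(0)
n = 50
x = rng.uniform(0, 10, size=n)
y = 3.0 - 2.0 * x + rng.normal(0, 0.5, size=n)
X = np.column_stack([np.ones(n), x])   # design matrix with an intercept column

w = np.zeros(2)        # w^(0), the initial guess
alpha = 0.01           # step size
for _ in range(20000):
    # The update rule: w^(t+1) = w^(t) - alpha * (2/n) (X^T X w^(t) - X^T y)
    w = w - alpha * (2 / n) * (X.T @ X @ w - X.T @ y)

# The iterates approach the closed-form minimizer.
w_star = np.linalg.solve(X.T @ X, X.T @ y)
```

After enough iterations, `w` agrees with `w_star` to several decimal places, which is exactly what we’ll see on real data below.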

Let’s start by using gradient descent to fit a simple linear regression model to predict commute times in minutes from departure_hour – a problem we’ve solved many times.


First, for reference, we’ll compute $\vec w^*$ using the closed-form solution,

$$\vec w^* = (X^TX)^{-1}X^T\vec y$$

x = df["departure_hour"].to_numpy(dtype=float)
y = df["minutes"].to_numpy(dtype=float)
n = len(df)

X = np.column_stack([np.ones(n), x])

w_star_closed_form = np.linalg.solve(X.T @ X, X.T @ y)
w_star_closed_form
array([142.44824159, -8.18694172])

The code below implements gradient descent. At each iteration, it computes the MSE and its gradient, and logs the current $\vec w^{(t)}$ vector, which is sometimes called an “iterate”. Every 500 iterations, it displays the current values of the MSE, the norm of the gradient vector, and $\vec w^{(t)}$. I’ve collapsed the code since it’s relatively lengthy.

def mse(w):
    return np.mean((y - X @ w) ** 2)


def grad(w):
    return (2 / n) * (X.T @ (X @ w - y))


def run_gradient_descent(
    w0,
    alpha,
    loss_fn,
    grad_fn,
    tol=1e-2,
    max_iter=50000,
    record_every=500,
    log=True,
    loss_name="loss",
    param_names=None,
):
    w = np.array(w0, dtype=float)

    if param_names is None:
        param_names = [f"w{i}" for i in range(len(w))]

    rows = []

    for t in range(max_iter + 1):
        gradient = grad_fn(w)
        grad_norm = np.linalg.norm(gradient)
        current_loss = loss_fn(w)

        should_record = (t % record_every == 0) or (grad_norm < tol) or (t == max_iter)

        if should_record:
            row = {"t": t, loss_name: current_loss, "grad_norm": grad_norm}
            row.update({name: value for name, value in zip(param_names, w)})
            rows.append(row)

            if log:
                params_str = ", ".join(f"{value:8.3f}" for value in w)
                print(
                    f"t = {t:>5d} | w^({t}) = [{params_str}] | "
                    f"{loss_name} = {current_loss:9.4f} | ||grad|| = {grad_norm:9.6f}"
                )

        # The most relevant pieces are below.
        # --------------------------------
        if grad_norm < tol:
            break

        w = w - alpha * gradient
        # --------------------------------

    return pd.DataFrame(rows)

alpha = 0.01

history = run_gradient_descent(
    w0=np.zeros(2),
    alpha=alpha,
    loss_fn=mse,
    grad_fn=grad,
    tol=1e-2,
    max_iter=50000,
    record_every=500,
    loss_name="mse",
    param_names=["w0", "w1"],
    log=True
)

history.round(4);
t =     0 | w^(0) = [   0.000,    0.000] | mse = 5523.5231 | ||grad|| = 1229.842640
t =   500 | w^(500) = [  19.804,    6.102] | mse =  314.8511 | ||grad|| =  3.527944
t =  1000 | w^(1000) = [  36.133,    4.200] | mse =  260.7135 | ||grad|| =  3.058219
t =  1500 | w^(1500) = [  50.289,    2.551] | mse =  220.0323 | ||grad|| =  2.651035
t =  2000 | w^(2000) = [  62.559,    1.121] | mse =  189.4630 | ||grad|| =  2.298065
t =  2500 | w^(2500) = [  73.196,   -0.118] | mse =  166.4919 | ||grad|| =  1.992091
t =  3000 | w^(3000) = [  82.416,   -1.193] | mse =  149.2306 | ||grad|| =  1.726856
t =  3500 | w^(3500) = [  90.409,   -2.124] | mse =  136.2598 | ||grad|| =  1.496935
t =  4000 | w^(4000) = [  97.338,   -2.931] | mse =  126.5130 | ||grad|| =  1.297627
t =  4500 | w^(4500) = [ 103.344,   -3.631] | mse =  119.1889 | ||grad|| =  1.124856
t =  5000 | w^(5000) = [ 108.551,   -4.237] | mse =  113.6852 | ||grad|| =  0.975088
t =  5500 | w^(5500) = [ 113.064,   -4.763] | mse =  109.5496 | ||grad|| =  0.845261
t =  6000 | w^(6000) = [ 116.976,   -5.219] | mse =  106.4419 | ||grad|| =  0.732719
t =  6500 | w^(6500) = [ 120.368,   -5.614] | mse =  104.1067 | ||grad|| =  0.635162
t =  7000 | w^(7000) = [ 123.308,   -5.957] | mse =  102.3519 | ||grad|| =  0.550594
t =  7500 | w^(7500) = [ 125.856,   -6.254] | mse =  101.0333 | ||grad|| =  0.477285
t =  8000 | w^(8000) = [ 128.065,   -6.511] | mse =  100.0424 | ||grad|| =  0.413738
t =  8500 | w^(8500) = [ 129.980,   -6.734] | mse =   99.2978 | ||grad|| =  0.358651
t =  9000 | w^(9000) = [ 131.640,   -6.928] | mse =   98.7383 | ||grad|| =  0.310899
t =  9500 | w^(9500) = [ 133.079,   -7.095] | mse =   98.3179 | ||grad|| =  0.269504
t = 10000 | w^(10000) = [ 134.327,   -7.241] | mse =   98.0020 | ||grad|| =  0.233621
t = 10500 | w^(10500) = [ 135.408,   -7.367] | mse =   97.7646 | ||grad|| =  0.202516
t = 11000 | w^(11000) = [ 136.345,   -7.476] | mse =   97.5862 | ||grad|| =  0.175552
t = 11500 | w^(11500) = [ 137.158,   -7.571] | mse =   97.4521 | ||grad|| =  0.152179
t = 12000 | w^(12000) = [ 137.862,   -7.653] | mse =   97.3514 | ||grad|| =  0.131917
t = 12500 | w^(12500) = [ 138.473,   -7.724] | mse =   97.2757 | ||grad|| =  0.114353
t = 13000 | w^(13000) = [ 139.002,   -7.785] | mse =   97.2188 | ||grad|| =  0.099127
t = 13500 | w^(13500) = [ 139.461,   -7.839] | mse =   97.1761 | ||grad|| =  0.085929
t = 14000 | w^(14000) = [ 139.859,   -7.885] | mse =   97.1440 | ||grad|| =  0.074488
t = 14500 | w^(14500) = [ 140.204,   -7.925] | mse =   97.1198 | ||grad|| =  0.064571
t = 15000 | w^(15000) = [ 140.502,   -7.960] | mse =   97.1017 | ||grad|| =  0.055973
t = 15500 | w^(15500) = [ 140.761,   -7.990] | mse =   97.0881 | ||grad|| =  0.048521
t = 16000 | w^(16000) = [ 140.986,   -8.017] | mse =   97.0778 | ||grad|| =  0.042061
t = 16500 | w^(16500) = [ 141.181,   -8.039] | mse =   97.0701 | ||grad|| =  0.036460
t = 17000 | w^(17000) = [ 141.350,   -8.059] | mse =   97.0644 | ||grad|| =  0.031606
t = 17500 | w^(17500) = [ 141.496,   -8.076] | mse =   97.0600 | ||grad|| =  0.027398
t = 18000 | w^(18000) = [ 141.623,   -8.091] | mse =   97.0567 | ||grad|| =  0.023750
t = 18500 | w^(18500) = [ 141.733,   -8.104] | mse =   97.0543 | ||grad|| =  0.020588
t = 19000 | w^(19000) = [ 141.828,   -8.115] | mse =   97.0524 | ||grad|| =  0.017847
t = 19500 | w^(19500) = [ 141.910,   -8.124] | mse =   97.0511 | ||grad|| =  0.015470
t = 20000 | w^(20000) = [ 141.982,   -8.133] | mse =   97.0500 | ||grad|| =  0.013411
t = 20500 | w^(20500) = [ 142.044,   -8.140] | mse =   97.0492 | ||grad|| =  0.011625
t = 21000 | w^(21000) = [ 142.098,   -8.146] | mse =   97.0486 | ||grad|| =  0.010077
t = 21027 | w^(21027) = [ 142.101,   -8.146] | mse =   97.0486 | ||grad|| =  0.010000

As you can see in the first line, we started with an initial guess of $\vec w^{(0)} = \begin{bmatrix}0 & 0\end{bmatrix}$. We also chose a step size of $\alpha = 0.01$ arbitrarily.

By the final iteration (number 21027), gradient descent has essentially matched the closed-form least-squares solution. After roughly 10,000 iterations, the norm of the gradient is already close to 0, and the MSE barely changes.

Instead of just looking at printed logs, here’s an interactive figure. Drag the slider from left to right: the left panel tracks the MSE over time, while the right panel shows the regression line corresponding to the selected iteration. The title reports the MSE of the model at that iteration number. This is the sort of figure that machine learning practitioners draw frequently when training large-scale models.

(Interactive figure: the left panel tracks the MSE over iterations; the right panel shows the fitted regression line at the selected iteration.)

To be clear, the function we are actually minimizing doesn’t appear in either of the plots above. That function, $R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2$, is a vector-to-scalar function whose graph we’d need to draw in $\mathbb{R}^3$ here.

Gradient descent is beautiful: once we can write down a model, a loss for a single example, and the gradient of the resulting average loss, we can search for optimal parameters for far more complicated model-loss combinations. Least squares is just one example. As long as the average loss is differentiable, gradient descent gives us a general recipe for finding the model’s optimal parameters.
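As a sketch of that generality, here’s the same recipe applied to a different differentiable loss: the average Huber loss, which is less sensitive to outliers than squared loss. Only the gradient function changes; the update rule is identical. The Huber loss isn’t part of this chapter, and the data, threshold `delta`, and step size below are illustrative assumptions:

```python
import numpy as np

# Hypothetical data: a noisy line with a few planted outliers.
rng = np.random.default_rng(0)
n = 100
x = rng.uniform(0, 10, size=n)
y = 5.0 + 2.0 * x + rng.normal(0, 1, size=n)
y[:5] += 80            # a few large outliers
X = np.column_stack([np.ones(n), x])

def huber_grad(w, delta=1.0):
    # Gradient of the average Huber loss of the residuals y - Xw.
    # The Huber loss's derivative clips each residual to [-delta, delta],
    # so large residuals (outliers) exert only a bounded pull on w.
    r = y - X @ w
    return -(1 / n) * (X.T @ np.clip(r, -delta, delta))

w = np.zeros(2)
alpha = 0.02
for _ in range(50000):
    w = w - alpha * huber_grad(w)
```

Unlike least squares, this fit stays close to the underlying trend of the inliers, because the clipped gradient caps each outlier’s influence.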


Another Example: Logistic Regression

Coming soon!
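Until then, here’s a rough preview sketch, my own illustration rather than this chapter’s eventual treatment. For logistic regression, the average cross-entropy loss is differentiable, and its gradient works out to $\frac{1}{n} X^T (\sigma(X \vec w) - \vec y)$, so the exact same update rule applies. The data-generating parameters below are hypothetical:

```python
import numpy as np

# Hypothetical 1-feature data: labels drawn from a true logistic model
# with parameters [-1, 2], for illustration only.
rng = np.random.default_rng(0)
n = 200
x = rng.normal(0, 1, size=n)
p_true = 1 / (1 + np.exp(-(-1 + 2 * x)))
y = (rng.uniform(size=n) < p_true).astype(float)
X = np.column_stack([np.ones(n), x])

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def cross_entropy_grad(w):
    # Gradient of the mean cross-entropy loss: (1/n) X^T (sigmoid(Xw) - y).
    return (1 / n) * (X.T @ (sigmoid(X @ w) - y))

w = np.zeros(2)
alpha = 0.5
for _ in range(20000):
    w = w - alpha * cross_entropy_grad(w)
```

Note that relative to the least-squares code earlier in this section, only the gradient function changed; the descent loop itself is untouched.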


Issues of Scale

When I get a chance, I’ll flesh this section out more. For now, refer to this blog post by Sebastian Ruder. It discusses stochastic gradient descent, a variant of gradient descent used in practical machine learning applications.
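In the meantime, here’s a minimal sketch of the core idea: each stochastic gradient descent update uses the gradient of the loss on one randomly chosen example, rather than the average over all $n$ examples. The data and hyperparameters below are illustrative assumptions:

```python
import numpy as np

# Hypothetical synthetic data: a noisy line, for illustration only.
rng = np.random.default_rng(0)
n = 500
x = rng.uniform(0, 10, size=n)
y = 3.0 - 2.0 * x + rng.normal(0, 0.5, size=n)
X = np.column_stack([np.ones(n), x])

w = np.zeros(2)
alpha = 0.001
for t in range(50000):
    i = rng.integers(n)                  # sample one example uniformly
    # Gradient of the single-example squared loss (y_i - w . x_i)^2:
    g = 2 * (X[i] @ w - y[i]) * X[i]
    w = w - alpha * g
```

With a constant step size, the iterates never settle exactly at $\vec w^*$; they hover in a small neighborhood of it, which is one reason decaying step sizes are common in practice.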