8.4. Gradient Descent for Empirical Risk Minimization

While gradient descent can be used to (attempt to) minimize any differentiable function $f(\vec x)$, we typically use it to minimize empirical risk functions, $R(\vec w)$.

Let’s try using gradient descent to fit a linear regression model – that is, let’s use it to minimize

$$
R_\text{sq}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2
$$

The minimizer of this function has a closed-form solution, but it’s worthwhile to see how gradient descent behaves on it.
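To make the formula concrete, here is a minimal NumPy sketch of $R_\text{sq}$. The function and variable names are ours, and we assume `X` is an $n \times d$ design matrix and `y` is a length-$n$ vector of observed values.

```python
import numpy as np

def mean_squared_error(w, X, y):
    """R_sq(w) = (1/n) * ||y - Xw||^2, computed with vectorized operations."""
    residuals = y - X @ w
    return (residuals @ residuals) / len(y)
```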

In Chapter 8.1, we found that the gradient of $R_\text{sq}(\vec w)$ is

$$
\nabla R_\text{sq}(\vec w) = \frac{2}{n} \left( X^TX \vec w - X^T \vec y \right)
$$

so, the update rule is

$$
\vec w^{(t+1)} = \vec w^{(t)} - \alpha \cdot \frac{2}{n} \left( X^TX \vec w^{(t)} - X^T \vec y \right)
$$
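This update rule translates almost directly into code. The sketch below assumes a fixed learning rate `alpha`, a fixed number of iterations, and an initial guess of $\vec w^{(0)} = \vec 0$; these choices (and the names) are illustrative, not the notebook’s exact code.

```python
import numpy as np

def gradient(w, X, y):
    """Gradient of R_sq at w: (2/n) * (X^T X w - X^T y)."""
    return (2 / len(y)) * (X.T @ X @ w - X.T @ y)

def gradient_descent(X, y, alpha=0.01, iterations=10_000):
    """Repeatedly apply w <- w - alpha * grad R_sq(w), starting from w = 0."""
    w = np.zeros(X.shape[1])
    for _ in range(iterations):
        w = w - alpha * gradient(w, X, y)
    return w
```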

Let’s start by using gradient descent to fit a simple linear regression model to predict commute times in minutes from `departure_hour` – a problem we’ve solved many times.

We now apply gradient descent to this empirical risk minimization problem.
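As an illustrative sketch of how that might look – the file path and column names below are assumptions, not the notebook’s actual data-loading code – we can build a design matrix with an intercept column and reuse `gradient_descent` from the sketch above. With a suitably small learning rate, the result should approximately match the closed-form least-squares solution.

```python
import numpy as np
import pandas as pd

# Hypothetical loading step; the path and column names are placeholders.
df = pd.read_csv('data/commute-times.csv')
X = np.column_stack([np.ones(len(df)), df['departure_hour']])  # intercept column + feature
y = df['minutes'].to_numpy()

# The learning rate may need tuning (or the feature may need standardizing) to converge.
w_gd = gradient_descent(X, y, alpha=0.01, iterations=100_000)

# Sanity check against the closed-form least-squares solution.
w_exact = np.linalg.solve(X.T @ X, X.T @ y)
print(w_gd, w_exact)  # the two should approximately agree
```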


More to come! We’ll cover this example on Tuesday, and talk more about convexity (and, time permitting, variants of gradient descent for large datasets).