
1.3. Empirical Risk Minimization

The Modeling Recipe

In Chapter 1.2, we implicitly introduced a three-step process for building a machine learning model.

Image produced in Jupyter

Most modern supervised learning algorithms follow these same three steps, just with different models, loss functions, and techniques for optimization.

Another name given to this process is empirical risk minimization.

When using squared loss, all three of these mean the same thing:

  • Average squared loss.
  • Mean squared error.
  • Empirical risk.

Risk is an idea from theoretical statistics that we’ll visit in Chapter 6. It refers to the expected error of a model, when considering the probability distribution of the data. “Empirical” risk refers to risk calculated using an actual, concrete dataset, rather than a theoretical distribution. The reason we call the average loss R is precisely because it is empirical risk.
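
In code, empirical risk is just an average of per-point losses. Here’s a minimal sketch of that idea (the function names and the tiny example dataset below are mine, just for illustration):

```python
import numpy as np

def squared_loss(y_i, pred):
    return (y_i - pred) ** 2

def empirical_risk(loss, ys, pred):
    # Average the loss over every data point; this is R for the chosen loss.
    return np.mean([loss(y_i, pred) for y_i in ys])

ys = np.array([3, 1, 4, 1, 5])                    # a made-up dataset
print(empirical_risk(squared_loss, ys, pred=2))   # mean squared error of the constant prediction 2
```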

The first half of the course – and in some ways, the entire course – is focused on empirical risk minimization, and so we will make many passes through the three-step modeling recipe ourselves, with differing models and loss functions.


Absolute Loss

When we first introduced the idea of a loss function, we started by computing the error, e_i, of each prediction:

e_i = {\color{3D81F6}y_i} - {\color{orange}h(x_i)}

where {\color{3D81F6}y_i} is the actual value and {\color{orange}h(x_i)} is the predicted value.

The issue was that some errors were positive and some were negative, and so it was hard to compare them directly. We wanted the value of the loss function to be large for bad predictions and small for good predictions.

To get around this, we squared the errors, which gave us squared loss:

L_\text{sq}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = ({\color{3D81F6}y_i} - {\color{orange}h(x_i)})^2

But, instead, we could have taken the absolute value of the errors. Doing so gives us absolute loss:

L_\text{abs}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = |{\color{3D81F6}y_i} - {\color{orange}h(x_i)}|

Below, I’ve visualized the absolute loss and squared loss for just a single data point.

Image produced in Jupyter

You should notice two key differences between the two loss functions:

  1. The absolute loss function is not differentiable when y_i = h(x_i). The absolute value function, f(x) = |x|, does not have a derivative at x = 0, because its slope to the left of x = 0 (-1) is different from its slope to the right of x = 0 (1). For more on this idea, see Chapter 0.2.
  2. The squared loss function grows much faster than the absolute loss function as the prediction h(x_i) gets further away from the actual value y_i, as illustrated below.
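
Here’s a small numeric illustration of that second point (the specific error sizes are arbitrary):

```python
import numpy as np

errors = np.array([0.5, 1, 2, 5, 10])    # arbitrary values of y_i - h(x_i)
print("absolute loss:", np.abs(errors))  # grows linearly with the size of the error
print("squared loss: ", errors ** 2)     # grows quadratically, so large errors are punished much more
```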

We know the optimal constant prediction, w^*, when using squared loss, is the mean. What is the optimal constant prediction when using absolute loss? The answer is no longer the mean; rather, it reflects some of these differences between squared loss and absolute loss.

Let’s find that new optimal constant prediction, w^*, by revisiting the three-step modeling recipe.

  1. Choose a model.

    We’ll stick with the constant model, h(x_i) = w.

  2. Choose a loss function.

    We’ll use absolute loss:

    L_\text{abs}(y_i, h(x_i)) = |y_i - h(x_i)|

    For the constant model, since h(x_i) = w, we have:

    L_\text{abs}(y_i, w) = |y_i - w|
  3. Minimize average loss to find optimal model parameters.

    The average loss – also known as mean absolute error here – is:

    R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|

In Chapter 1.2, we minimized \displaystyle R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2 by taking the derivative of R_\text{sq}(w) with respect to w and setting it equal to 0. That will be more challenging in the case of R_\text{abs}(w), because the absolute value function is not differentiable when its input is 0, as we just discussed.


Mean Absolute Error for the Constant Model

We need to minimize the mean absolute error, R_\text{abs}(w), for the constant model, h(x_i) = w, but we have to address the fact that R_\text{abs}(w) is not differentiable across its entire domain.

R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|

Graphing Mean Absolute Error

I think it’ll help to visualize what R_\text{abs}(w) looks like. To do so, let’s reintroduce the small dataset of 5 values we used in Chapter 1.2.

y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92

Then, R_\text{abs}(w) is:

R_\text{abs}(w) = \frac{1}{5} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|)
Image produced in Jupyter

This is a piecewise linear function. Where are the “bends” in the graph? Precisely where the data points, y_1, y_2, \ldots, y_5, are! It’s at exactly these points where R_\text{abs}(w) is not differentiable. At each of those points, the slope of the line segment approaching from the left is different from the slope of the line segment approaching from the right, and for a function to be differentiable at a point, the slope of the tangent line must be the same when approaching from the left and the right.

The graph of R_\text{abs}(w) above, while not differentiable at any of the data points, still shows us something about the optimal constant prediction. If there is a bend at each data point, and at each bend the slope increases – that is, becomes more positive – then the optimal constant prediction seems to be in the middle, when the slope goes from negative to positive. I’ll make this more precise in a moment.

For now, you might notice that the value of w that minimizes R_\text{abs}(w) in the graph above is a familiar summary statistic, but not the mean. I won’t spell it out just yet, since I’d like for you to reason about it yourself.
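
If you’d like to experiment yourself, here’s a rough sketch of how a plot like the one above might be produced, assuming NumPy and Matplotlib are available (the exact styling of the original figure isn’t reproduced):

```python
import numpy as np
import matplotlib.pyplot as plt

ys = np.array([72, 90, 61, 85, 92])

def R_abs(w):
    # Mean absolute error of the constant prediction w.
    return np.mean(np.abs(ys - w))

ws = np.linspace(50, 100, 500)
plt.plot(ws, [R_abs(w) for w in ws])
plt.xlabel("w")
plt.ylabel("R_abs(w)")
plt.title("Mean absolute error of the constant prediction w")
plt.show()
```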

Let me show you one more graph of R_\text{abs}(w), but this time, in a case where there are an even number of data points. Suppose we have a sixth point, y_6=78.

y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92, \quad y_6=78

Then, R_\text{abs}(w) is:

R_\text{abs}(w) = \frac{1}{6} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w| + |78 - w|)

And its graph is:

Image produced in Jupyter

This graph is broken into 7 segments, with 6 bends (one per data point). Between the 3rd and 4th bends – that is, the 3rd and 4th data points, when sorted – the slope is 0, and all values in that interval minimize R_\text{abs}(w). So, it seems that the value of w^* doesn’t have to be unique!
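
One way to convince yourself of this is to evaluate R_\text{abs}(w) at several values between the middle two data points; for the dataset above, they should all come out equal (the particular test values are arbitrary):

```python
import numpy as np

ys = np.array([72, 90, 61, 85, 92, 78])

def R_abs(w):
    return np.mean(np.abs(ys - w))

# 78 and 85 are the middle two values when sorted; everything in between ties.
for w in [78, 80, 82.5, 85]:
    print(w, R_abs(w))
```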

Minimizing Mean Absolute Error

From the two graphs above, you may have a clear picture of what the optimal constant prediction, w^*, is. But, to avoid relying too heavily on visual intuition and just a single set of example data points, let’s try to minimize R_\text{abs}(w) mathematically, for an arbitrary set of data points.

To be clear, the goal is to minimize:

R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|

To do so, we’ll take the derivative of R_\text{abs}(w) with respect to w and set it equal to 0.

\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{\text{d}}{\text{d}w} \left( \frac{1}{n} \sum_{i=1}^n |y_i - w| \right)

Using the familiar facts that the derivative of a sum is the sum of the derivatives, and that constants can be pulled out of the derivative, we have:

\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n \frac{\text{d}}{\text{d}w} |y_i - w|

Here’s where the challenge comes in. What is \frac{\text{d}}{\text{d}w} |y_i - w|?

Let’s start by remembering the derivative of the absolute value function. The absolute value function itself can be thought of as a piecewise function:

|x| = \begin{cases} x & x \geq 0 \\ -x & x < 0 \end{cases}

Note that the x = 0 case can be lumped into either the x or -x case, since 0 and -0 are both 0.

Using this logic, I’ll write |y_i - w| as a piecewise function of w:

|y_i - w| = \begin{cases} y_i - w & w \leq y_i \\ w - y_i & w > y_i \end{cases}

I have written the two conditions with w on the left, since it’s easier to think in terms of w in my mind, but this means that the inequalities are flipped relative to how I presented them in the definition of |x|. Remember, |y_i - w| is a function of w; we’re treating y_i as some constant. If it helps, replace every instance of y_i with a concrete number, like 5, then reason through the resulting graph.

Image produced in Jupyter

Now we can take the derivative of each piece:

\frac{\text{d}}{\text{d}w} |y_i - w| = \begin{cases} -1 & w < y_i \\ \text{undefined} & w = y_i \\ 1 & w > y_i \end{cases}

Great. Remember, this is the derivative of the absolute loss for a single data point. But our main objective is to find the derivative of the average absolute loss, R_\text{abs}(w). Using this piecewise definition of \frac{\text{d}}{\text{d}w} |y_i - w|, we have:

\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{abs}(w) &= \frac{1}{n} \sum_{i=1}^n \frac{\text{d}}{\text{d}w} |y_i - w| \\ &= \frac{1}{n} \sum_{i=1}^n \begin{cases} -1 & w < y_i \\ \text{undefined} & w = y_i \\ 1 & w > y_i \end{cases} \end{align*}

At any point where w = y_i, for any value of i, \frac{\text{d}}{\text{d}w} R_\text{abs}(w) is undefined. (This makes any point where w = y_i a critical point.) Let’s exclude those values of w from our consideration. In all other cases, the sum in the expression above involves only two possible values: -1 and 1.

  • The sum adds -1 for all data points greater than w, i.e. where w < y_i.
  • The sum adds 1 for all data points less than w, i.e. where w > y_i.

Using some creative notation, I’ll re-write \frac{\text{d}}{\text{d}w} R_\text{abs}(w) as:

\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{1}{n} \left( \sum_{w < y_i} -1 + \sum_{w > y_i} 1 \right)

The sum \displaystyle \sum_{w < y_i} -1 is the sum of -1 for all data points greater than w, so perhaps a more intuitive way to write it is:

\sum_{w < y_i} -1 = \underbrace{(-1) + (-1) + \ldots + (-1)}_{\text{add once per data point to the right of } w} = -(\text{\# right of } w)

Equivalently, \displaystyle \sum_{w > y_i} 1 = (\text{\# left of } w), meaning that:

\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{abs}(w) &= \frac{1}{n} \left( -(\text{\# right of } w) + (\text{\# left of } w) \right) \\ &= \boxed{\frac{\text{\# left of } w - \text{\# right of } w}{n}} \end{align*}
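
We can sanity-check this boxed formula numerically by comparing it against a finite-difference estimate of the slope at a w that isn’t a data point (the choice of w and the step size below are arbitrary):

```python
import numpy as np

ys = np.array([72, 90, 61, 85, 92])

def R_abs(w):
    return np.mean(np.abs(ys - w))

def slope_formula(w):
    # (# of data points left of w  minus  # of data points right of w) / n
    return (np.sum(ys < w) - np.sum(ys > w)) / len(ys)

w, h = 80, 1e-6   # 80 is not one of the data points
print(slope_formula(w))                          # -0.2 from the boxed formula
print((R_abs(w + h) - R_abs(w - h)) / (2 * h))   # roughly -0.2 from a finite difference
```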

This boxed form gives us the slope of R_\text{abs}(w), for any point w that is not an original data point. To put it in perspective, let’s revisit the first graph we saw in this section, where we plotted R_\text{abs}(w) for the dataset:

y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92
R_\text{abs}(w) = \frac{1}{5} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|)
Image produced in Jupyter

Now that we have a formula for \frac{\text{d}}{\text{d}w} R_\text{abs}(w), the easy thing to claim is that we could set it to 0 and solve for w. Doing so would give us:

\frac{\text{\# left of } w - \text{\# right of } w}{n} = 0

Which yields the condition:

\text{\# left of } w = \text{\# right of } w

The optimal value of w is the one that satisfies this condition, and that’s precisely the median of the data, as you may have noticed earlier.

This logic isn’t fully rigorous, however, because the formula for \frac{\text{d}}{\text{d}w} R_\text{abs}(w) is only valid for w’s that aren’t original data points, and the median – if we have an odd number of data points – is indeed one of the original data points. In the graph above, there is never a point where the slope is 0.

To fully justify why the median (in the case of an odd number of data points) minimizes mean absolute error, I’ll say that:

  • If w is just to the left of the median, there are more points to the right of w than to the left of w, so (\text{\# left of } w) < (\text{\# right of } w) and \frac{(\text{\# left of } w) - (\text{\# right of } w)}{n} is negative.
  • If w is just to the right of the median, there are more points to the left of w than to the right of w, so (\text{\# left of } w) > (\text{\# right of } w) and \frac{(\text{\# left of } w) - (\text{\# right of } w)}{n} is positive.

So even though the slope is undefined at the median, we know it is a point at which the sign of the derivative switches from negative to positive, and as we discussed in Chapter 0.2, this sign change implies at least a local minimum.

To summarize:

  • If n is odd, the median minimizes mean absolute error.
  • If n is even, any value between the middle two values (when sorted) minimizes mean absolute error.
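
As a quick numeric check of this summary, we can compare the median against a dense grid of candidate predictions (the grid itself is just an arbitrary choice for illustration):

```python
import numpy as np

ys = np.array([72, 90, 61, 85, 92])   # an odd number of data points

def R_abs(w):
    return np.mean(np.abs(ys - w))

ws = np.linspace(50, 100, 1001)
risks = [R_abs(w) for w in ws]
print(np.median(ys))           # 85.0
print(ws[np.argmin(risks)])    # also 85.0, the grid point with the smallest mean absolute error
```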

We’ve just made a second pass through the three-step modeling recipe:

  1. Choose a model.

    h(x_i) = w

  2. Choose a loss function.

    L_\text{abs}(y_i, w) = |y_i - w|

  3. Minimize average loss to find optimal model parameters.

    R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w| \implies w^* = \text{Median}(y_1, y_2, \ldots, y_n)

Conclusion

What we’ve now discovered is that the optimal model parameter (in this case, the optimal constant prediction) depends on the loss function we choose!

In the context of the commute times dataset from Chapter 1.2, our two optimal constant predictions can be visualized as flat lines, as shown below.

Image produced in Jupyter

Depending on your criteria for what makes a good or bad prediction (i.e., the loss function you choose), optimal model parameters may change.


Comparing Loss Functions

We now know that:

  • The mean is the constant prediction that minimizes mean squared error,

    R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2 \implies w^* = \text{Mean}(y_1, y_2, \ldots, y_n)
  • The median is the constant prediction that minimizes mean absolute error,

    R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w| \implies w^* = \text{Median}(y_1, y_2, \ldots, y_n)

Let’s compare the behavior of the mean and median, and reason about how their differences in behavior are related to the differences in their loss functions.

Outliers

Let’s consider our example dataset of 5 commute times, sorted from least to greatest:

y_1 = 61 \:\:\:\:\:\:\:\:\:\: y_2 = 72 \:\:\:\:\:\:\:\:\:\: y_3 = 85 \:\:\:\:\:\:\:\:\:\: y_4 = 90 \:\:\:\:\:\:\:\:\:\: y_5 = 92

Here, the median is 85 and the mean is 80. But what if we try adding 200 to the largest commute time?

y_1 = 61 \:\:\:\:\:\:\:\:\:\: y_2 = 72 \:\:\:\:\:\:\:\:\:\: y_3 = 85 \:\:\:\:\:\:\:\:\:\: y_4 = 90 \:\:\:\:\:\:\:\:\:\: y_5 = 292

The median is still 85, but the mean is now 120! In other words, the mean is sensitive to outliers!

Image produced in Jupyter

We can see that compared to the median, the mean is being “pulled” towards the outlier. We often say the median is robust to outliers. This makes sense when you consider that the mean and median balance different aspects of the distribution. The median is the point where the number of points to its left and to its right are equal, while the mean is the point where the total deviation of the points to its left equals the total deviation of the points to its right. This gives us a trade-off between the loss functions: do we want to choose one with a minimizer that’s easier to calculate, or one that’s less sensitive to outliers?
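
Here’s a quick check of the numbers above:

```python
import numpy as np

original = np.array([61, 72, 85, 90, 92])
with_outlier = np.array([61, 72, 85, 90, 292])   # the largest commute time, plus 200

print(np.mean(original), np.median(original))            # 80.0 85.0
print(np.mean(with_outlier), np.median(with_outlier))    # 120.0 85.0: the mean moves, the median doesn't
```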

Now that we’ve seen absolute and squared loss, let’s explore minimizing empirical risk for other loss functions. For any p \geq 1, define the L_p loss as follows:

L_p(y_i, w) = |y_i - w|^p

Note that we need the absolute value to avoid negative loss. Given an L_p loss, the corresponding empirical risk is:

R_p(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^p
  • When p = 1, or absolute loss, w^* = \text{Median}(y_1, y_2, ..., y_n).
  • When p = 2, or squared loss, w^* = \text{Mean}(y_1, y_2, ..., y_n).

What about when p \rightarrow \infty? How can we find the minimizer w^*? Let’s visualize the dataset from before, 61, 72, 85, 90, 292, on a different graph. The x axis is p, and the y axis is the optimal constant prediction w^* for that L_p loss function.

Image produced in Jupyter

As p \rightarrow \infty, w^* approaches 176.5, which is actually the midpoint of the minimum and maximum values of the dataset. We call this the midrange, and it’s the best prediction when the measure of a “good” prediction is not being too far from any point in the dataset.
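
Here’s a sketch of how these minimizers might be computed numerically, using a grid search over w (the grid bounds and the particular values of p below are my choices):

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 292])

def R_p(w, p):
    return np.mean(np.abs(ys - w) ** p)

ws = np.linspace(61, 292, 2311)   # candidate predictions, spaced about 0.1 apart
for p in [1, 2, 4, 8, 16]:
    risks = [R_p(w, p) for w in ws]
    # The minimizer drifts from the median (p = 1) toward the midrange as p grows.
    print(p, ws[np.argmin(risks)])

print("midrange:", (ys.min() + ys.max()) / 2)   # 176.5
```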

Let’s consider another loss function, this time 0-1 loss:

L_{0,1}(y_i, w) = \begin{cases} 0 & y_i = w \\ 1 & y_i \neq w \end{cases}

The corresponding empirical risk is:

R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n L_{0, 1}(y_i, w)

The empirical risk is the proportion of points not equal to the prediction w. To minimize risk, we want there to be as many y_i’s as possible which are equal to our prediction, so w^* is the mode of the dataset.
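
Here’s a small sketch of that idea, using a made-up dataset with a repeated value (in our 5 commute times, no value repeats, so every value would tie for the mode):

```python
import numpy as np
from collections import Counter

ys = np.array([61, 72, 85, 85, 92])   # hypothetical data with a repeated value

def R_01(w):
    # Proportion of data points not equal to the prediction w.
    return np.mean(ys != w)

mode = Counter(ys.tolist()).most_common(1)[0][0]
print(mode, R_01(mode))   # 85 and 0.6: three of the five points differ from the mode
```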

To recap, choosing a loss function is important because it determines what the best prediction will be! Each optimal prediction w^* is a different summary statistic that measures the center of the dataset.

| Loss | Minimizer | Always Unique? | Robust to Outliers? | Differentiable? |
| --- | --- | --- | --- | --- |
| L_\text{sq} | mean | yes ✅ | no ❌ | yes ✅ |
| L_\text{abs} | median | no ❌ | yes ✅ | no ❌ |
| L_\infty | midrange | yes ✅ | no ❌ | no ❌ |
| L_\text{0,1} | mode | no ❌ | yes ✅ | no ❌ |

Center and Spread

We know that the w^* which minimizes empirical risk is some measure of the center of the dataset, but what does this tell us about the minimum risk itself? Consider the empirical risk for squared loss:

R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2

R_\text{sq}(w) is minimized at w^* = \text{Mean}(y_1, y_2, ..., y_n), so let’s try plugging that value back into R_\text{sq} to find the minimum value:

\begin{align*} R_\text{sq}(w^*) &= R_\text{sq}\left( \text{Mean}(y_1, y_2, ..., y_n) \right) \\ &= \frac{1}{n} \sum_{i = 1}^n \left( y_i - \text{Mean}(y_1, y_2, ..., y_n) \right)^2 \end{align*}

The minimum value comes from taking each data point’s deviation from the mean of the dataset, squaring those deviations, and averaging them. There’s a special name for this average: variance. You may be familiar with its square root, known as standard deviation. Both of these measure how spread out the data is around the mean of the dataset.
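
For instance, with the commute times from this section, the minimum value of R_\text{sq} matches NumPy’s variance (which, by default, divides by n rather than n - 1):

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 92])

R_sq_min = np.mean((ys - np.mean(ys)) ** 2)   # empirical risk at w* = mean
print(R_sq_min, np.var(ys))                   # both are 138.8
```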

Let’s try the same process on absolute loss next:

R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|

We’ll plug w^* = \text{Median}(y_1, y_2, ..., y_n) into R_\text{abs}(w):

\begin{align*} R_\text{abs}(w^*) &= R_\text{abs}\left( \text{Median}(y_1, y_2, ..., y_n) \right) \\ &= \frac{1}{n} \sum_{i = 1}^n |y_i - \text{Median}(y_1, y_2, ..., y_n)| \end{align*}

Our minimum value is the mean absolute deviation from the median. Similarly to variance, it also measures the average distance from each data point to a center, this time the median. Both of these are ways of measuring the spread in our dataset, using their respective centers.
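
And the analogous check for absolute loss:

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 92])

R_abs_min = np.mean(np.abs(ys - np.median(ys)))   # empirical risk at w* = median
print(R_abs_min)                                  # 9.8, the mean absolute deviation from the median
```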

What if we try using 0-1 loss?

R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n \begin{cases} 0 & y_i = w \\ 1 & y_i \neq w \end{cases}

The minimizer is the mode of the dataset, so the minimum value is the proportion of values not equal to a mode of the dataset. However, this doesn’t tell us much about how spread out the data is. If R_{0,1} is higher, that only means that there’s less data clustered at exactly a mode. The mode is a very basic way to measure the center, so it follows that the corresponding measure of spread is also basic and rather uninformative.

Ultimately, choosing a loss function comes down to understanding which measures of center and spread are most important for your prediction problem and dataset.