
1.3. Empirical Risk Minimization

The Modeling Recipe

In Chapter 1.2, we implicitly introduced a three-step process for building a machine learning model.

Image produced in Jupyter

Most modern supervised learning algorithms follow these same three steps, just with different models, loss functions, and techniques for optimization.

Another name given to this process is empirical risk minimization.

When using squared loss, all three of these mean the same thing:

  • Average squared loss.

  • Mean squared error.

  • Empirical risk.

Risk is an idea from theoretical statistics that we’ll visit in Chapter 6. It refers to the expected error of a model, when considering the probability distribution of the data. “Empirical” risk refers to risk calculated using an actual, concrete dataset, rather than a theoretical distribution. The reason we call the average loss $R$ is precisely because it is empirical risk.

The first half of the course – and in some ways, the entire course – is focused on empirical risk minimization, and so we will make many passes through the three-step modeling recipe ourselves, with differing models and loss functions.

A common question you’ll see in labs, homeworks, and exams will involve finding the optimal model parameters for a given model and loss function – in particular, for a combination of model and loss function that you’ve never seen before. For practice with this sort of exercise, work through the following activity. If you feel stuck, try reading through the rest of this section for context, then come back.


Absolute Loss

When we first introduced the idea of a loss function, we started by computing the error, $e_i$, of each prediction:

$$e_i = {\color{3D81F6}y_i} - {\color{orange}h(x_i)}$$

where ${\color{3D81F6}y_i}$ is the actual value and ${\color{orange}h(x_i)}$ is the predicted value.

The issue was that some errors were positive and some were negative, and so it was hard to compare them directly. We wanted the value of the loss function to be large for bad predictions and small for good predictions.

To get around this, we squared the errors, which gave us squared loss:

$$L_\text{sq}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = ({\color{3D81F6}y_i} - {\color{orange}h(x_i)})^2$$

But, instead, we could have taken the absolute value of the errors. Doing so gives us absolute loss:

$$L_\text{abs}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = |{\color{3D81F6}y_i} - {\color{orange}h(x_i)}|$$

Below, I’ve visualized the absolute loss and squared loss for just a single data point.

Image produced in Jupyter
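If you’d like to recreate a plot like the one above yourself, here’s a minimal sketch using numpy and matplotlib. (This code isn’t from the original notebook, and the choice of $y_i = 85$ is arbitrary.)

```python
import numpy as np
import matplotlib.pyplot as plt

y_i = 85                                  # actual value (an arbitrary choice)
h = np.linspace(y_i - 10, y_i + 10, 500)  # a range of candidate predictions h(x_i)

abs_loss = np.abs(y_i - h)                # L_abs(y_i, h(x_i)) = |y_i - h(x_i)|
sq_loss = (y_i - h) ** 2                  # L_sq(y_i, h(x_i)) = (y_i - h(x_i))^2

plt.plot(h, abs_loss, label='absolute loss')
plt.plot(h, sq_loss, label='squared loss')
plt.xlabel('prediction $h(x_i)$')
plt.ylabel('loss')
plt.legend()
plt.show()
```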

You should notice two key differences between the two loss functions:

  1. The absolute loss function is not differentiable when $y_i = h(x_i)$. The absolute value function, $f(x) = |x|$, does not have a derivative at $x=0$, because its slope to the left of $x=0$ (-1) is different from its slope to the right of $x=0$ (1). For more on this idea, see Chapter 0.2.

  2. The squared loss function grows much faster than the absolute loss function, as the prediction $h(x_i)$ gets further away from the actual value $y_i$.

We know the optimal constant prediction, $w^*$, when using squared loss, is the mean. What is the optimal constant prediction when using absolute loss? The answer is not still the mean; rather, the answer reflects some of these differences between squared loss and absolute loss.

Let’s find that new optimal constant prediction, $w^*$, by revisiting the three-step modeling recipe.

  1. Choose a model.

    We’ll stick with the constant model, $h(x_i) = w$.

  2. Choose a loss function.

    We’ll use absolute loss:

    $$L_\text{abs}(y_i, h(x_i)) = |y_i - h(x_i)|$$

    For the constant model, since $h(x_i) = w$, we have:

    $$L_\text{abs}(y_i, w) = |y_i - w|$$

  3. Minimize average loss to find optimal model parameters.

    The average loss – also known as mean absolute error here – is:

    $$R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|$$

In Chapter 1.2, we minimized $\displaystyle R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2$ by taking the derivative of $R_\text{sq}(w)$ with respect to $w$ and setting it equal to 0. That will be more challenging in the case of $R_\text{abs}(w)$, because the absolute value function is not differentiable when its input is 0, as we just discussed.


Mean Absolute Error for the Constant Model

We need to minimize the mean absolute error, $R_\text{abs}(w)$, for the constant model, $h(x_i) = w$, but we have to address the fact that $R_\text{abs}(w)$ is not differentiable across its entire domain.

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|$$

Graphing Mean Absolute Error

I think it’ll help to visualize what $R_\text{abs}(w)$ looks like. To do so, let’s reintroduce the small dataset of 5 values we used in Chapter 1.2.

$$y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92$$

Then, $R_\text{abs}(w)$ is:

$$R_\text{abs}(w) = \frac{1}{5} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|)$$
Image produced in Jupyter
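A graph like this only takes a few lines of code to produce; here’s one possible sketch, using numpy and matplotlib (the names `ys` and `R_abs` are my own).

```python
import numpy as np
import matplotlib.pyplot as plt

ys = np.array([72, 90, 61, 85, 92])   # the small dataset from Chapter 1.2

def R_abs(w):
    """Mean absolute error of the constant prediction w."""
    return np.mean(np.abs(ys - w))

ws = np.linspace(50, 100, 1001)       # a fine grid of candidate predictions
plt.plot(ws, [R_abs(w) for w in ws])
plt.xlabel('$w$')
plt.ylabel(r'$R_{\mathrm{abs}}(w)$')
plt.show()
```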

This is a piecewise linear function. Where are the “bends” in the graph? Precisely where the data points, $y_1, y_2, \ldots, y_5$, are! It’s at exactly these points where $R_\text{abs}(w)$ is not differentiable. At each of those points, the slope of the line segment approaching from the left is different from the slope of the line segment approaching from the right, and for a function to be differentiable at a point, the slope of the tangent line must be the same when approaching from the left and the right.

The graph of $R_\text{abs}(w)$ above, while not differentiable at any of the data points, still shows us something about the optimal constant prediction. If there is a bend at each data point, and at each bend the slope increases – that is, becomes more positive – then the optimal constant prediction seems to be in the middle, where the slope goes from negative to positive. I’ll make this more precise in a moment.

For now, you might notice that the value of $w$ that minimizes the graph of $R_\text{abs}(w)$ above is a familiar summary statistic, but not the mean. I won’t spell it out just yet, since I’d like for you to reason about it yourself.

Let me show you one more graph of $R_\text{abs}(w)$, but this time, in a case where there are an even number of data points. Suppose we have a sixth point, $y_6=78$.

$$y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92, \quad y_6=78$$

Then, $R_\text{abs}(w)$ is:

$$R_\text{abs}(w) = \frac{1}{6} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w| + |78 - w|)$$

And its graph is:

Image produced in Jupyter

This graph is broken into 7 segments, with 6 bends (one per data point). Between the 3rd and 4th bends – that is, between the 3rd and 4th data points in sorted order – the slope is 0, and all values in that interval minimize $R_\text{abs}(w)$. So, it seems that the value of $w^*$ doesn’t have to be unique!

Minimizing Mean Absolute Error

From the two graphs above, you may have a clear picture of what the optimal constant prediction, $w^*$, is. But, to avoid relying too heavily on visual intuition and just a single set of example data points, let’s try and minimize $R_\text{abs}(w)$ mathematically, for an arbitrary set of data points.

To be clear, the goal is to minimize:

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|$$

To do so, we’ll take the derivative of $R_\text{abs}(w)$ with respect to $w$ and set it equal to 0.

$$\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{\text{d}}{\text{d}w} \left( \frac{1}{n} \sum_{i=1}^n |y_i - w| \right)$$

Using the familiar facts that the derivative of a sum is the sum of the derivatives, and that constants can be pulled out of the derivative, we have:

$$\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n \frac{\text{d}}{\text{d}w} |y_i - w|$$

Here’s where the challenge comes in. What is $\frac{\text{d}}{\text{d}w} |y_i - w|$?

Let’s start by remembering the derivative of the absolute value function. The absolute value function itself can be thought of as a piecewise function:

$$|x| = \begin{cases} x & x \geq 0 \\ -x & x < 0 \end{cases}$$

Note that the $x=0$ case can be lumped into either the $x$ or $-x$ case, since 0 and $-0$ are both 0.

Using this logic, I’ll write $|y_i - w|$ as a piecewise function of $w$:

$$|y_i - w| = \begin{cases} y_i - w & w \leq y_i \\ w - y_i & w > y_i \end{cases}$$

I have written the two conditions with $w$ on the left, since it’s easier to think in terms of $w$ in my mind, but this means that the inequalities are flipped relative to how I presented them in the definition of $|x|$. Remember, $|y_i - w|$ is a function of $w$; we’re treating $y_i$ as some constant. If it helps, replace every instance of $y_i$ with a concrete number, like 5, then reason through the resulting graph.

Image produced in Jupyter

Now we can take the derivative of each piece:

$$\frac{\text{d}}{\text{d}w} |y_i - w| = \begin{cases} -1 & w < y_i \\ \text{undefined} & w = y_i \\ 1 & w > y_i \end{cases}$$

Great. Remember, this is the derivative of the absolute loss for a single data point. But our main objective is to find the derivative of the average absolute loss, $R_\text{abs}(w)$. Using this piecewise definition of $\frac{\text{d}}{\text{d}w} |y_i - w|$, we have:

$$\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{abs}(w) &= \frac{1}{n} \sum_{i=1}^n \frac{\text{d}}{\text{d}w} |y_i - w| \\ &= \frac{1}{n} \sum_{i=1}^n \begin{cases} -1 & w < y_i \\ \text{undefined} & w = y_i \\ 1 & w > y_i \end{cases} \end{align*}$$

At any point where $w = y_i$, for any value of $i$, $\frac{\text{d}}{\text{d}w} R_\text{abs}(w)$ is undefined. (This makes any point where $w = y_i$ a critical point.) Let’s exclude those values of $w$ from our consideration. In all other cases, the sum in the expression above involves only two possible values: -1 and 1.

  • The sum adds -1 for all data points greater than $w$, i.e. where $w < y_i$.

  • The sum adds 1 for all data points less than $w$, i.e. where $w > y_i$.

Using some creative notation, I’ll re-write $\frac{\text{d}}{\text{d}w} R_\text{abs}(w)$ as:

$$\frac{\text{d}}{\text{d}w} R_\text{abs}(w) = \frac{1}{n} \left( \sum_{w < y_i} -1 + \sum_{w > y_i} 1 \right)$$

The sum $\displaystyle \sum_{w < y_i} -1$ is the sum of -1 for all data points greater than $w$, so perhaps a more intuitive way to write it is:

$$\sum_{w < y_i} -1 = \underbrace{(-1) + (-1) + \ldots + (-1)}_{\text{add once per data point to the right of } w} = -(\text{\# right of } w)$$

Equivalently, $\displaystyle \sum_{w > y_i} 1 = (\text{\# left of } w)$, meaning that:

$$\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{abs}(w) &= \frac{1}{n} \left( -(\text{\# right of } w) + (\text{\# left of } w) \right) \\ &= \boxed{\frac{\text{\# left of } w - \text{\# right of } w}{n}} \end{align*}$$

By “left of $w$”, I mean less than $w$.
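If you want to convince yourself that the boxed formula is right, you can compare it against a numerical estimate of the slope at a few values of $w$ that aren’t data points. Here’s a rough sketch (my own code, using the small dataset from above):

```python
import numpy as np

ys = np.array([72, 90, 61, 85, 92])
n = len(ys)

def R_abs(w):
    return np.mean(np.abs(ys - w))

def slope_formula(w):
    """(# left of w - # right of w) / n; valid only when w is not a data point."""
    return (np.sum(ys < w) - np.sum(ys > w)) / n

for w in [65.0, 75.0, 87.0, 95.0]:    # none of these are data points
    eps = 1e-6
    numerical_slope = (R_abs(w + eps) - R_abs(w - eps)) / (2 * eps)  # central difference
    print(w, slope_formula(w), round(numerical_slope, 6))
```

Both columns should agree at every $w$ tested, since each test point lies strictly inside one of the linear segments.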

This boxed form gives us the slope of $R_\text{abs}(w)$, for any point $w$ that is not an original data point. To put it in perspective, let’s revisit the first graph we saw in this section, where we plotted $R_\text{abs}(w)$ for the dataset:

$$y_1=72, \quad y_2=90, \quad y_3=61, \quad y_4=85, \quad y_5=92$$

$$R_\text{abs}(w) = \frac{1}{5} (|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|)$$
Image produced in Jupyter

Now that we have a formula for $\frac{\text{d}}{\text{d}w} R_\text{abs}(w)$, the easy thing to claim is that we could set it to 0 and solve for $w$. Doing so would give us:

$$\frac{\text{\# left of } w - \text{\# right of } w}{n} = 0$$

Which yields the condition:

$$\text{\# left of } w = \text{\# right of } w$$

The optimal value of $w$ is the one that satisfies this condition, and that’s precisely the median of the data, as you may have noticed earlier.

This logic isn’t fully rigorous, however, because the formula for $\frac{\text{d}}{\text{d}w} R_\text{abs}(w)$ is only valid for $w$’s that aren’t original data points, and if we have an odd number of data points, the median is indeed one of the original data points. In the graph above, there is never a point where the slope is 0.

To fully justify why the median minimizes mean absolute error even when there are an odd number of data points, I’ll say that:

  • If $w$ is just to the left of the median, there are more points to the right of $w$ than to the left of $w$, so $(\text{\# left of } w) < (\text{\# right of } w)$ and $\frac{(\text{\# left of } w) - (\text{\# right of } w)}{n}$ is negative.

  • If $w$ is just to the right of the median, there are more points to the left of $w$ than to the right of $w$, so $(\text{\# left of } w) > (\text{\# right of } w)$ and $\frac{(\text{\# left of } w) - (\text{\# right of } w)}{n}$ is positive.

So even though the slope is undefined at the median, we know it is a point at which the sign of the derivative switches from negative to positive, and as we discussed in Chapter 0.2, this sign change implies at least a local minimum.

To summarize:

  • If $n$ is odd, the median minimizes mean absolute error.

  • If $n$ is even, any value between the middle two values (when sorted) minimizes mean absolute error. (It’s common to call the mean of the middle two values the median.)
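Here’s a quick numerical sanity check of both bullet points, using a dense grid of candidate $w$’s. It’s a sketch of my own, not a proof, and the grid spacing of 0.5 is arbitrary.

```python
import numpy as np

def best_constant_abs(ys, grid):
    """Return the grid value(s) of w with the smallest mean absolute error."""
    risks = np.array([np.mean(np.abs(ys - w)) for w in grid])
    return grid[np.isclose(risks, risks.min())]

grid = np.arange(50, 100, 0.5)

odd = np.array([72, 90, 61, 85, 92])         # n = 5: a unique minimizer
even = np.array([72, 90, 61, 85, 92, 78])    # n = 6: a whole interval of minimizers

print(best_constant_abs(odd, grid), np.median(odd))    # [85.]                     85.0
print(best_constant_abs(even, grid), np.median(even))  # grid points in [78, 85]   81.5
```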

We’ve just made a second pass through the three-step modeling recipe:

  1. Choose a model.

    $h(x_i) = w$

  2. Choose a loss function.

    $L_\text{abs}(y_i, w) = |y_i - w|$

  3. Minimize average loss to find optimal model parameters.

    $$R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w| \implies w^* = \text{Median}(y_1, y_2, \ldots, y_n)$$

Conclusion

What we’ve now discovered is that the optimal model parameter (in this case, the optimal constant prediction) depends on the loss function we choose!

In the context of the commute times dataset from Chapter 1.2, our two optimal constant predictions can be visualized as flat lines, as shown below.

Image produced in Jupyter

Depending on your criteria for what makes a good or bad prediction (i.e., the loss function you choose), optimal model parameters may change.


Comparing Loss Functions

Let’s compare the behavior of the mean and median, and reason about how their differences in behavior are related to the differences in the loss functions used to derive them.

Let’s consider our example dataset of 5 commute times, with a mean of 80 and a median of 85:

$$61 \qquad 72 \qquad 85 \qquad 90 \qquad 92$$

Suppose 200 is added to the largest commute time:

$$61 \qquad 72 \qquad 85 \qquad 90 \qquad 292$$

The median is still 85, but the mean is now $80 + \frac{200}{5} = 120$. This example illustrates the fact that the mean is sensitive to outliers, while the median is robust to outliers.
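This is easy to verify directly; here’s a tiny sketch.

```python
import numpy as np

commutes = np.array([61, 72, 85, 90, 92])
with_outlier = np.array([61, 72, 85, 90, 292])  # 200 added to the largest value

print(np.mean(commutes), np.median(commutes))          # 80.0   85.0
print(np.mean(with_outlier), np.median(with_outlier))  # 120.0  85.0
```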

But why? I like to think of the mean and median as different “balance points” of a dataset, each satisfying a different “balance condition”.

| Summary Statistic | Minimizes | Balance Condition (comes from setting $\frac{\text{d}}{\text{d}w} R(w) = 0$) |
| --- | --- | --- |
| Median | $\displaystyle R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n \lvert y_i - w \rvert$ | $\text{\# left of } w = \text{\# right of } w$ |
| Mean | $\displaystyle R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2$ | $\displaystyle\sum_{i=1}^n (y_i - w) = 0$ |

In both cases, the “balance condition” comes from setting the derivative of empirical risk, $\frac{\text{d}}{\text{d}w} R(w)$, to 0. The logic for the median and mean absolute error is fresher in your mind from this section, so let’s think in terms of the mean and mean squared error, from Chapter 1.2. There, we found that:

$$\frac{\text{d}}{\text{d}w} R_\text{sq}(w) = -\frac{2}{n} \sum_{i=1}^n (y_i - w)$$

Setting this to 0 gave us the balance equation above.

$$\sum_{i=1}^n (y_i - w^*) = 0$$

In English, this is saying that the sum of deviations from each data point to the mean is 0. (“Deviation” just means “difference from the mean.”) Or, in other words, the positive differences and negative differences cancel each other out at the mean, and the mean is the unique point where this happens.

Let me illustrate using our familiar small example dataset.

$$61 \qquad 72 \qquad 85 \qquad 90 \qquad 92$$

The mean is 80. Then:

$$\underbrace{(61 - 80)}_{\color{#d81b60}{-19}} + \underbrace{(72 - 80)}_{\color{#d81b60}{-8}} + \underbrace{(85 - 80)}_{\color{#3d81f6}{5}} + \underbrace{(90 - 80)}_{\color{#3d81f6}{10}} + \underbrace{(92 - 80)}_{\color{#3d81f6}{12}} = 0$$

Note that the negative deviations and positive deviations both total 27 in magnitude.

While the mean balances the positive and negative deviations, the median balances the number of points on either side, without regard to how far the values are from the median.
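Both balance conditions are easy to check numerically on our small dataset; here’s a sketch (my own code).

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 92])
mean, median = np.mean(ys), np.median(ys)        # 80.0 and 85.0

# The mean balances signed deviations: they sum to 0.
print(np.sum(ys - mean))                         # 0.0

# The median balances counts: equally many points on either side
# (ignoring any points exactly equal to the median).
print(np.sum(ys < median), np.sum(ys > median))  # 2 2
```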

Here’s another perspective: the squared loss more heavily penalizes outliers, and so the resulting predictions “cater to” or are “pulled” towards these outliers.

Image produced in Jupyter

In the example above, the top plot visualizes the squared loss for the constant prediction of $w = 4$ against the dataset 1, 2, 3, and 14. While it has relatively small squared losses to the three points on the left, it has a very large squared loss to the point at 14, of $(14 - 4)^2 = 100$, which causes the mean squared error to be large.

In efforts to reduce the overall mean squared error, the optimal $w^*$ is pulled towards 14. $w^* = 5$ has larger squared losses to the points at 1, 2, and 3 than $w = 4$ did, but a much smaller squared loss to the point at 14, of $(14 - 5)^2 = 81$. The “savings” from going from a squared loss of $10^2 = 100$ to $9^2 = 81$ more than makes up for the additional squared losses to the points at 1, 2, and 3.
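Here’s the arithmetic from that example written out in code, using the dataset 1, 2, 3, 14 from the plots above (a sketch; the variable names are mine).

```python
import numpy as np

ys = np.array([1, 2, 3, 14])

def mse(w):
    """Mean squared error of the constant prediction w."""
    return np.mean((ys - w) ** 2)

print(mse(4), mse(5))              # 28.5  27.5 -- moving towards 14 lowers the MSE
print(np.mean(ys), np.median(ys))  # 5.0 (pulled towards 14)  vs.  2.5
```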

In short: models that are fit using squared loss are strongly pulled towards outliers, in an effort to keep mean squared error low. Models fit using absolute loss don’t have this tendency.

To conclude, let me visualize the behavior of the mean and median with a larger dataset – the full dataset of commute times we first saw at the start of Chapter 1.2.

Image produced in Jupyter

The median is the point at which half the values are below it and half are above it. In the histogram above, half of the area is to the left of the median and half is to the right.

The mean is the point at which the sum of deviations from each value to the mean is 0. Another interpretation: if you placed this histogram on a playground see-saw, the mean would be the point at which the see-saw is balanced. Wikipedia has a good illustration of this general idea.

We say the distribution above is right-skewed or right-tailed because the tail is on the right side of the distribution. (This is counterintuitive to me, because most of the data is on the left of the distribution.)

In general, the mean is pulled in the direction of the tail of a distribution:

  • If a distribution is symmetric, i.e. has roughly the same-shaped tail on the left and right, the mean and median are similar.

  • If a distribution is right-skewed, the mean is pulled to the right of the median, i.e. $\text{Mean} > \text{Median}$.

  • If a distribution is left-skewed, the mean is pulled to the left of the median, i.e. $\text{Mean} < \text{Median}$.

This explains why ${\color{orange} \text{Mean}} > {\color{purple} \text{Median}}$ in the histogram above, and equivalently, in the scatter plot below.

Image produced in Jupyter

Many common distributions in the real world are right-skewed, including incomes and net worths, and in such cases, the mean doesn’t tell the full story.

Mean (average) and median incomes for several countries (source).


When we move to more sophisticated models with (many) more parameters, the optimal parameter values won’t be as easily interpretable as the mean and median of our data, but the effects of our choice of loss function will still be felt in the predictions we make.


Beyond Absolute and Squared Loss

You may have noticed that the absolute loss and squared loss functions both look relatively similar:

$$L_\text{abs}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = |{\color{3D81F6}y_i} - {\color{orange}h(x_i)}|$$

$$L_\text{sq}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = ({\color{3D81F6}y_i} - {\color{orange}h(x_i)})^2$$

Both of these loss functions are special cases of a more general class of loss functions, known as $L_p$ loss functions. For any $p \geq 1$, define the $L_p$ loss as follows:

$$L_p(y_i, h(x_i)) = |{\color{3D81F6}y_i} - {\color{orange}h(x_i)}|^p$$

Suppose we continue to use the constant model, $h(x_i) = w$. Then, the corresponding empirical risk for $L_p$ loss is:

$$R_p(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^p$$

We’ve studied, in depth, the minimizers of $R_p(w)$ for $p = 1$ (the median) and $p = 2$ (the mean). What about when $p = 3$, or $p = 4$, or $p = 100$? What happens as $p \rightarrow \infty$?

Let me be a bit less abstract. Suppose we have $p = 6$. Then, we’re looking for the constant prediction $w$ that minimizes the following:

$$R_6(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^6 = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^6$$

Note that I dropped the absolute value, because $(y_i - w)^6$ is always non-negative, since 6 is an even number.

To find $w^*$ here, we need to take the derivative of $R_6(w)$ with respect to $w$ and set it equal to 0.

$$\frac{\text{d}}{\text{d}w} R_6(w) = -\frac{6}{n} \sum_{i = 1}^n (y_i - w)^5$$

Setting the above to 0 gives us a new balance condition: $\displaystyle\sum_{i = 1}^n (y_i - w)^5 = 0$. The minimizer of $R_2(w)$ was the point at which the balance condition $\displaystyle\sum_{i = 1}^n (y_i - w) = 0$ was satisfied; analogously, the minimizer of $R_6(w)$ is the point at which the balance condition $\displaystyle\sum_{i = 1}^n (y_i - w)^5 = 0$ is satisfied. You’ll notice that the degree of the differences in the balance condition is one lower than the degree of the differences in the loss function – this comes from the power rule of differentiation.

At what point $w^*$ does $\displaystyle\sum_{i = 1}^n (y_i - w^*)^5 = 0$? It’s challenging to determine the value by hand, but the computer can approximate solutions for us, as you’ll see in Lab 2.
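For example, here’s one way the computer could approximate $w^*$ for $p = 6$ on our small dataset, by numerically minimizing $R_6(w)$ with scipy. (This is just a sketch; the approach you’ll use in Lab 2 may differ.)

```python
import numpy as np
from scipy.optimize import minimize_scalar

ys = np.array([61, 72, 85, 90, 92])

def R_6(w):
    """Average L_6 loss of the constant prediction w."""
    return np.mean(np.abs(ys - w) ** 6)

result = minimize_scalar(R_6, bounds=(ys.min(), ys.max()), method='bounded')
print(result.x)   # the (approximate) value of w* that minimizes R_6(w)
```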

Below, you’ll find a computer-generated graph where:

  • The $x$-axis is $p$.

  • The $y$-axis represents the value of $w^*$ that minimizes $R_p(w)$, for the dataset

    $$61 \qquad 72 \qquad 85 \qquad 90 \qquad 292$$

    that we saw earlier. Note the maximum value in our dataset is 292.

Image produced in Jupyter
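Here’s a sketch of how a graph like this could be generated: for each $p$, numerically minimize $R_p(w)$ and record the minimizer. (The $p$-th root below is my own numerical trick; it doesn’t change the minimizer, it just keeps the numbers at a manageable scale for large $p$.)

```python
import numpy as np
from scipy.optimize import minimize_scalar

ys = np.array([61, 72, 85, 90, 292])   # the dataset with the outlier

def best_w(p):
    """Numerically minimize R_p(w); the p-th root is a monotone transform,
    so it doesn't change where the minimum is."""
    risk = lambda w: np.mean(np.abs(ys - w) ** p) ** (1 / p)
    return minimize_scalar(risk, bounds=(ys.min(), ys.max()), method='bounded').x

for p in [1, 2, 6, 20, 100]:
    print(p, round(best_w(p), 2))      # the minimizer drifts upward as p grows
```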

As $p \rightarrow \infty$, $w^*$ approaches a particular value: the midrange of the dataset, i.e. the midpoint of its minimum and maximum values, which here is $\frac{61 + 292}{2} = 176.5$.

At the other extreme, let me introduce yet another loss function, 0-1 loss:

$$L_{0,1}(y_i, h(x_i)) = \begin{cases} 0 & y_i = h(x_i) \\ 1 & y_i \neq h(x_i) \end{cases}$$

The corresponding empirical risk, for the constant model $h(x_i) = w$, is:

$$R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n L_{0, 1}(y_i, w)$$

This is the sum of 0s and 1s, divided by $n$. A 1 is added to the sum each time $y_i \neq w$. So, in other words, $R_{0,1}(w)$ is:

$$R_{0,1}(w) = \frac{\text{number of points not equal to } w}{n}$$

To minimize empirical risk, we want the number of points not equal to $w$ to be as small as possible. So, $w^*$ is the mode (i.e. most frequent value) of the dataset. If all values in the dataset are unique, they all minimize average 0-1 loss. This is not a useful loss function for regression, since our predictions are drawn from the continuous set of real numbers, but is useful for classification.
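Here’s a small sketch of average 0-1 loss for the constant model, on a toy dataset with a repeated value (the dataset is my own example, chosen so that the mode is unique).

```python
import numpy as np
from collections import Counter

ys = np.array([61, 72, 85, 85, 92])   # 85 appears twice

def R_01(w):
    """Fraction of data points not equal to w."""
    return np.mean(ys != w)

for w in [61, 72, 85, 92]:
    print(w, R_01(w))                 # w = 85 achieves the smallest value, 0.6

print(Counter(ys.tolist()).most_common(1))   # [(85, 2)] -- the mode
```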


Center and Spread

Prior to taking EECS 245, you knew about the mean, median, and mode of a dataset. What you now know is that each one of these summary statistics comes from minimizing empirical risk (i.e. average loss) for a different loss function. All three measure the center of the dataset in some way.

| Loss | Minimizer of Empirical Risk | Always Unique? | Robust to Outliers? | Empirical Risk Differentiable? |
| --- | --- | --- | --- | --- |
| $L_\text{sq}$ | mean | yes ✅ | no ❌ | yes ✅ |
| $L_\text{abs}$ | median | no ❌ | yes ✅ | no ❌ |
| $L_\infty$ | midrange | yes ✅ | no ❌ | no ❌ |
| $L_{0,1}$ | mode | no ❌ | no ❌ | no ❌ |

So far, we’ve focused on finding model parameters that minimize empirical risk. But, we never stopped to think about what the minimum empirical risk itself is! Consider the empirical risk for squared loss and the constant model:

$$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$

$R_\text{sq}(w)$ is minimized when $w^*$ is the mean, which I’ll denote with $\bar{y}$. What happens if I plug $w^* = \bar{y}$ back into $R_\text{sq}$?

$$R_\text{sq}(w^*) = R_\text{sq}(\bar{y}) = {\color{orange}\frac{1}{n} \sum_{i = 1}^n} {\color{#d81b60}(y_i - \bar{y})}^{\color{#3d81f6}2}$$

This is the variance of the dataset $y_1, y_2, ..., y_n$! The variance is nothing but the average squared deviation of each value from the mean of the dataset.
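You can verify this identity numerically; here’s a quick sketch using our small dataset.

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 92])

mse_at_mean = np.mean((ys - np.mean(ys)) ** 2)   # R_sq(w*) with w* = the mean
print(mse_at_mean, np.var(ys))                   # identical -- both are the variance, 138.8
```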

This gives context to the $y$-axis value of the vertex of the parabola we saw in Chapter 1.2.

$$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$
Image produced in Jupyter

Practically speaking, this gives us a nice “worst-case” baseline for the mean squared error of any regression model on a dataset. If we learn how to build a sophisticated regression model, and its mean squared error is somehow greater than the variance of the dataset, we know that we’re doing something wrong, since we could do better just by predicting the mean!

The units of the variance are the square of the units of the $y$-values. So, if the $y_i$’s represent commute times in minutes, the variance is in $\text{minutes}^2$. This makes it a bit difficult to interpret. So, we typically take the square root of the variance, which gives us the standard deviation, $\sigma$:

$$\sigma = \sqrt{\text{variance}} = \sqrt{\frac{1}{n} \sum_{i = 1}^n (y_i - \bar{y})^2}$$

The standard deviation has the same units as the $y$-values themselves, so it’s a more interpretable measure of spread.

How does this work in the context of absolute loss?

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$

Plugging $w^* = \text{Median}(y_1, y_2, ..., y_n)$ into $R_\text{abs}(w)$ gives us:

$$\begin{align*} R_\text{abs}(w^*) &= \frac{1}{n} \sum_{i = 1}^n |y_i - w^*| \\ &= \frac{1}{n} \sum_{i = 1}^n |y_i - \text{Median}(y_1, y_2, ..., y_n)| \end{align*}$$

I’ll admit, this result doesn’t have a special name. It is the mean absolute deviation from the median. And, like the variance and standard deviation, it measures roughly how spread out the data is from its center. Its units are the same as the $y$-values themselves (since there’s no squaring involved).

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$
Image produced in Jupyter
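Computing this quantity is straightforward; here’s a sketch that compares it to the standard deviation for our small dataset.

```python
import numpy as np

ys = np.array([61, 72, 85, 90, 92])

mad_from_median = np.mean(np.abs(ys - np.median(ys)))   # R_abs(w*) with w* = the median
std = np.std(ys)                                        # square root of the variance

print(mad_from_median, std)   # 9.8 and about 11.78 -- both measure spread, in the original units
```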

In the real world, be careful when you hear the term “mean absolute deviation”: it’s sometimes used to refer to the mean absolute deviation from the mean, rather than from the median as above.

To reiterate, in practice, our models will have many, many more parameters than just one, as is the case for the constant model. But, by deeply studying the effects of choosing squared loss vs. absolute loss vs. other loss functions in the context of the constant model, we can develop a better intuition for how to choose loss functions in more complex situations.