
1.4. Comparing Loss Functions

We now know that:

  • The mean is the constant prediction that minimizes mean squared error, $R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2$.

  • The median is the constant prediction that minimizes mean absolute error, $R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n |y_i - w|$.

Let’s compare the behavior of the mean and median, and reason about how their differences in behavior are related to the differences in the loss functions used to derive them.

Outliers and Balance

Let’s consider our example dataset of 5 commute times, with a mean of 80 and median of 85:

$61 \qquad 72 \qquad 85 \qquad 90 \qquad 92$

Suppose 200 is added to the largest commute time:

$61 \qquad 72 \qquad 85 \qquad 90 \qquad 292$

The median is still 85, but the mean is now $80 + \frac{200}{5} = 120$. This example illustrates the fact that the mean is sensitive to outliers, while the median is robust to outliers.
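If you’d like to check this behavior yourself, here’s a quick sketch (assuming numpy is available, as it is in a standard Jupyter environment):

```python
import numpy as np

commute_times = np.array([61, 72, 85, 90, 92])
with_outlier = np.array([61, 72, 85, 90, 292])  # largest value increased by 200

# The median ignores the outlier entirely, while the mean shifts by 200 / 5 = 40.
print(np.mean(commute_times), np.median(commute_times))  # 80.0 85.0
print(np.mean(with_outlier), np.median(with_outlier))    # 120.0 85.0
```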

But why? I like to think of the mean and median as different “balance points” of a dataset, each satisfying a different “balance condition”.

| Summary Statistic | Minimizes | Balance Condition (comes from setting $\frac{\text{d}}{\text{d}w} R(w) = 0$) |
| --- | --- | --- |
| Median | $R_\text{abs}(w) = \frac{1}{n} \sum_{i=1}^n \lvert y_i - w \rvert$ | $\text{\# left of } w = \text{\# right of } w$ |
| Mean | $R_\text{sq}(w) = \frac{1}{n} \sum_{i=1}^n (y_i - w)^2$ | $\sum_{i=1}^n (y_i - w) = 0$ |

In both cases, the “balance condition” comes from setting the derivative of empirical risk, $\frac{\text{d}}{\text{d}w} R(w)$, to 0. The logic for the median and mean absolute error is fresher from this section, so let’s think in terms of the mean and mean squared error, from Chapter 1.2. There, we found that:

$$\frac{\text{d}}{\text{d}w} R_\text{sq}(w) = -\frac{2}{n} \sum_{i=1}^n (y_i - w)$$

Setting this to 0 gave us the balance equation above.

$$\sum_{i=1}^n (y_i - w^*) = 0$$

In English, this is saying that the sum of deviations from each data point to the mean is 0. (“Deviation” just means “difference from the mean.”) Or, in other words, the positive differences and negative differences cancel each other out at the mean, and the mean is the unique point where this happens.

Let me illustrate using our familiar small example dataset.

$61 \qquad 72 \qquad 85 \qquad 90 \qquad 92$

The mean is 80. Then:

$$\underbrace{(61 - 80)}_{\color{#d81b60}{-19}} + \underbrace{(72 - 80)}_{\color{#d81b60}{-8}} + \underbrace{(85 - 80)}_{\color{#3d81f6}{5}} + \underbrace{(90 - 80)}_{\color{#3d81f6}{10}} + \underbrace{(92 - 80)}_{\color{#3d81f6}{12}} = 0$$

Note that the negative deviations and positive deviations both total 27 in magnitude.
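Here’s a quick numerical check of this balance condition, again assuming numpy:

```python
import numpy as np

y = np.array([61, 72, 85, 90, 92])
deviations = y - np.mean(y)              # [-19., -8., 5., 10., 12.]

print(deviations.sum())                  # 0.0 (up to floating-point error)
print(deviations[deviations < 0].sum())  # -27.0
print(deviations[deviations > 0].sum())  # 27.0
```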

While the mean balances the positive and negative deviations, the median balances the number of points on either side, without regard to how far the values are from the median.

Here’s another perspective: the squared loss more heavily penalizes outliers, and so the resulting predictions “cater to” or are “pulled” towards these outliers.

We just derived the absolute-loss solution for the constant model; now we’ll compare it to squared loss and examine how the choice of loss changes what “best” means.

Image produced in Jupyter

In the example above, the top plot visualizes the squared loss for the constant prediction of $w = 4$ against the dataset 1, 2, 3, and 14. While it has relatively small squared losses to the three points on the left, it has a very large squared loss to the point at 14, of $(14 - 4)^2 = 100$, which causes the mean squared error to be large.

In an effort to reduce the overall mean squared error, the optimal $w^*$ is pulled towards 14. $w^* = 5$ has larger squared losses to the points at 1, 2, and 3 than $w = 4$ did, but a much smaller squared loss to the point at 14, of $(14 - 5)^2 = 81$. The “savings” from going from a squared loss of $10^2 = 100$ to $9^2 = 81$ more than make up for the additional squared losses to the points at 1, 2, and 3.

In short: models that are fit using squared loss are strongly pulled towards outliers, in an effort to keep mean squared error low. Models fit using absolute loss don’t have this tendency.
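To make the arithmetic above concrete, here’s a small sketch (assuming numpy) that compares the mean squared errors of $w = 4$ and $w^* = 5$ on the dataset 1, 2, 3, 14:

```python
import numpy as np

y = np.array([1, 2, 3, 14])

def mean_squared_error(w, y):
    """Empirical risk for squared loss and the constant prediction w."""
    return np.mean((y - w) ** 2)

# The mean of the dataset is 5, so w* = 5 achieves the smaller value.
print(mean_squared_error(4, y))  # 28.5
print(mean_squared_error(5, y))  # 27.5
```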

To conclude, let me visualize the behavior of the mean and median with a larger dataset – the full dataset of commute times we first saw at the start of Chapter 1.2.

Image produced in Jupyter

The median is the point at which half the values are below it and half are above it. In the histogram above, half of the area is to the left of the median and half is to the right.

The mean is the point at which the sum of deviations from each value to the mean is 0. Another interpretation: if you placed this histogram on a playground see-saw, the mean would be the point at which the see-saw is balanced. Wikipedia has a good illustration of this general idea.

We say the distribution above is right-skewed or right-tailed because the tail is on the right side of the distribution. (This is counterintuitive to me, because most of the data is on the left of the distribution.)

In general, the mean is pulled in the direction of the tail of a distribution:

  • If a distribution is symmetric, i.e. has roughly the same-shaped tail on the left and right, the mean and median are similar.

  • If a distribution is right-skewed, the mean is pulled to the right of the median, i.e. $\text{Mean} > \text{Median}$.

  • If a distribution is left-skewed, the mean is pulled to the left of the median, i.e. $\text{Mean} < \text{Median}$.

This explains why ${\color{orange} \text{Mean}} > {\color{purple} \text{Median}}$ in the histogram above, and equivalently, in the scatter plot below.

Image produced in Jupyter

Many common distributions in the real world are right-skewed, including incomes and net worths, and in such cases, the mean doesn’t tell the full story.
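As a quick illustration (using a simulated right-skewed sample, not our commute times data, and assuming numpy), the mean of such a sample lands to the right of its median:

```python
import numpy as np

rng = np.random.default_rng(42)

# Exponential distributions have a long right tail, so samples from them are right-skewed.
sample = rng.exponential(scale=30, size=10_000)

print(np.mean(sample))    # roughly 30
print(np.median(sample))  # roughly 30 * ln(2) ≈ 21, so Mean > Median
```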

Mean (average) and median incomes for several countries (source).


When we move to more sophisticated models with (many) more parameters, the optimal parameter values won’t be as easily interpretable as the mean and median of our data, but the effects of our choice of loss function will still be felt in the predictions we make.


Beyond Absolute and Squared Loss

You may have noticed that the absolute loss and squared loss functions both look relatively similar:

$$L_\text{abs}({\color{#3D81F6}y_i}, {\color{orange}h(x_i)}) = |{\color{#3D81F6}y_i} - {\color{orange}h(x_i)}|$$

$$L_\text{sq}({\color{#3D81F6}y_i}, {\color{orange}h(x_i)}) = ({\color{#3D81F6}y_i} - {\color{orange}h(x_i)})^2$$

Both of these loss functions are special cases of a more general class of loss functions, known as $L_p$ loss functions. For any $p \geq 1$, define the $L_p$ loss as follows:

$$L_p(y_i, h(x_i)) = |{\color{#3D81F6}y_i} - {\color{orange}h(x_i)}|^p$$

Suppose we continue to use the constant model, $h(x_i) = w$. Then, the corresponding empirical risk for $L_p$ loss is:

$$R_p(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^p$$

We’ve studied, in depth, the minimizers of $R_p(w)$ for $p = 1$ (the median) and $p = 2$ (the mean). What about when $p = 3$, or $p = 4$, or $p = 100$? What happens as $p \rightarrow \infty$?

Let me be a bit less abstract. Suppose we have $p = 6$. Then, we’re looking for the constant prediction $w$ that minimizes the following:

$$R_6(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^6 = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^6$$

Note that I dropped the absolute value, because $(y_i - w)^6$ is always non-negative, since 6 is an even number.

To find $w^*$ here, we need to take the derivative of $R_6(w)$ with respect to $w$ and set it equal to 0.

$$\frac{\text{d}}{\text{d}w} R_6(w) = -\frac{6}{n} \sum_{i = 1}^n (y_i - w)^5$$

Setting the above to 0 gives us a new balance condition: $\displaystyle\sum_{i = 1}^n (y_i - w)^5 = 0$. The minimizer of $R_2(w)$ was the point at which the balance condition $\displaystyle\sum_{i = 1}^n (y_i - w) = 0$ was satisfied; analogously, the minimizer of $R_6(w)$ is the point at which the balance condition $\displaystyle\sum_{i = 1}^n (y_i - w)^5 = 0$ is satisfied. You’ll notice that the degree of the differences in the balance condition is one lower than the degree of the differences in the loss function, which comes from the power rule of differentiation.

At what point $w^*$ does $\displaystyle\sum_{i = 1}^n (y_i - w^*)^5 = 0$? It’s challenging to determine the value by hand, but the computer can approximate solutions for us, as you’ll see in Lab 2.
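Here’s one way you might approximate these minimizers, as a sketch (not necessarily the approach Lab 2 takes), assuming numpy and scipy are available. Since $R_p$ is convex for $p \geq 1$, a bounded scalar minimizer works well:

```python
import numpy as np
from scipy.optimize import minimize_scalar

y = np.array([61, 72, 85, 90, 292])

def empirical_risk_p(w, y, p):
    """R_p(w): average L_p loss of the constant prediction w."""
    return np.mean(np.abs(y - w) ** p)

# Approximate w* for several values of p. As p grows, w* moves away from
# the median (85) and the mean (120) toward larger values.
for p in [1, 2, 6, 100]:
    result = minimize_scalar(empirical_risk_p, args=(y, p),
                             bounds=(y.min(), y.max()), method="bounded")
    print(p, round(result.x, 2))
```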

Below, you’ll find a computer-generated graph where:

  • The $x$-axis is $p$.

  • The $y$-axis represents the value of $w^*$ that minimizes $R_p(w)$, for the dataset

    $61 \qquad 72 \qquad 85 \qquad 90 \qquad 292$

    that we saw earlier. Note the maximum value in our dataset is 292.

Image produced in Jupyter

As $p \rightarrow \infty$, $w^*$ approaches the midrange of the dataset, $\frac{\min + \max}{2} = \frac{61 + 292}{2} = 176.5$.

At the other extreme, let me introduce yet another loss function, 0-1 loss:

$$L_{0,1}(y_i, h(x_i)) = \begin{cases} 0 & y_i = h(x_i) \\ 1 & y_i \neq h(x_i) \end{cases}$$

The corresponding empirical risk, for the constant model $h(x_i) = w$, is:

$$R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n L_{0, 1}(y_i, w)$$

This is the sum of 0s and 1s, divided by $n$. A 1 is added to the sum each time $y_i \neq w$. So, in other words, $R_{0,1}(w)$ is:

$$R_{0,1}(w) = \frac{\text{number of points not equal to } w}{n}$$

To minimize empirical risk, we want the number of points not equal to $w$ to be as small as possible. So, $w^*$ is the mode (i.e. most frequent value) of the dataset. If all values in the dataset are unique, they all minimize average 0-1 loss. This is not a useful loss function for regression, since our predictions are drawn from the continuous set of real numbers, but is useful for classification.
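For a concrete check, here’s a tiny sketch (assuming numpy, and using a made-up dataset with a repeated value) showing that the mode attains the smallest average 0-1 loss:

```python
import numpy as np

y = np.array([3, 7, 7, 7, 10, 10])  # made-up dataset whose mode is 7

def zero_one_risk(w, y):
    """R_{0,1}(w): the fraction of points not equal to w."""
    return np.mean(y != w)

# The mode, 7, has the smallest fraction of mismatches.
for w in np.unique(y):
    print(w, zero_one_risk(w, y))  # 3: 0.833..., 7: 0.5, 10: 0.666...
```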


Center and Spread

Prior to taking EECS 245, you knew about the mean, median, and mode of a dataset. What you now know is that each one of these summary statistics comes from minimizing empirical risk (i.e. average loss) for a different loss function. All three measure the center of the dataset in some way.

| Loss | Minimizer of Empirical Risk | Always Unique? | Robust to Outliers? | Empirical Risk Differentiable? |
| --- | --- | --- | --- | --- |
| $L_\text{sq}$ | mean | yes ✅ | no ❌ | yes ✅ |
| $L_\text{abs}$ | median | no ❌ | yes ✅ | no ❌ |
| $L_\infty$ | midrange | yes ✅ | no ❌ | no ❌ |
| $L_{0,1}$ | mode | no ❌ | no ❌ | no ❌ |
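If you’d like to sanity-check the first three rows of this table, here’s one possible sketch (assuming numpy): a brute-force grid search over candidate values of $w$, treating the $L_\infty$ case as minimizing the maximum absolute deviation (the limit of $R_p$ as $p \rightarrow \infty$). The mode row was checked in the 0-1 loss example above.

```python
import numpy as np

y = np.array([61, 72, 85, 90, 292])
grid = np.linspace(y.min(), y.max(), 100_001)  # candidate constant predictions w

# Evaluate each empirical risk on the whole grid at once via broadcasting.
R_sq  = np.mean((y[:, None] - grid) ** 2, axis=0)
R_abs = np.mean(np.abs(y[:, None] - grid), axis=0)
R_max = np.max(np.abs(y[:, None] - grid), axis=0)  # p -> infinity limit

print(grid[np.argmin(R_sq)])   # ≈ 120.0, the mean
print(grid[np.argmin(R_abs)])  # ≈ 85.0, the median
print(grid[np.argmin(R_max)])  # ≈ 176.5, the midrange (61 + 292) / 2
```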

So far, we’ve focused on finding model parameters that minimize empirical risk. But, we never stopped to think about what the minimum empirical risk itself is! Consider the empirical risk for squared loss and the constant model:

$$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$

$R_\text{sq}(w)$ is minimized when $w^*$ is the mean, which I’ll denote with $\bar{y}$. What happens if I plug $w^* = \bar{y}$ back into $R_\text{sq}$?

$$R_\text{sq}(w^*) = R_\text{sq}(\bar{y}) = {\color{orange}\frac{1}{n} \sum_{i = 1}^n} {\color{#d81b60}(y_i - \bar{y})}^{\color{#3d81f6}2}$$

This is the variance of the dataset $y_1, y_2, ..., y_n$! The variance is nothing but the average squared deviation of each value from the mean of the dataset.
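For instance, for our five original commute times, the minimum mean squared error and the variance coincide (a quick check assuming numpy):

```python
import numpy as np

y = np.array([61, 72, 85, 90, 92])

# Mean squared error of the best constant prediction, w* = mean(y) = 80...
mse_at_mean = np.mean((y - np.mean(y)) ** 2)

# ...equals the variance (np.var computes the same average squared deviation, with ddof=0).
print(mse_at_mean, np.var(y))  # both 138.8
```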

This gives context to the $y$-axis value of the vertex of the parabola we saw in Chapter 1.2.

$$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$
Image produced in Jupyter

Practically speaking, this gives us a nice “worst-case” mean squared error of any regression model on a dataset. If we learn how to build a sophisticated regression model, and its mean squared error is somehow greater than the variance of the dataset, we know that we’re doing something wrong, since we could do better just by predicting the mean!

The units of the variance are the square of the units of the $y$-values. So, if the $y_i$’s represent commute times in minutes, the variance is in $\text{minutes}^2$. This makes it a bit difficult to interpret. So, we typically take the square root of the variance, which gives us the standard deviation, $\sigma$:

$$\sigma = \sqrt{\text{variance}} = \sqrt{\frac{1}{n} \sum_{i = 1}^n (y_i - \bar{y})^2}$$

The standard deviation has the same units as the $y$-values themselves, so it’s a more interpretable measure of spread.
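Continuing the check above (again assuming numpy), the standard deviation of the five original commute times is just the square root of that variance:

```python
import numpy as np

y = np.array([61, 72, 85, 90, 92])

sigma = np.sqrt(np.mean((y - np.mean(y)) ** 2))  # square root of the variance
print(sigma, np.std(y))  # both ≈ 11.78 minutes
```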

How does this work in the context of absolute loss?

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$

Plugging in $w^* = \text{Median}(y_1, y_2, ..., y_n)$ into $R_\text{abs}(w)$ gives us:

$$\begin{align*} R_\text{abs}(w^*) &= \frac{1}{n} \sum_{i = 1}^n |y_i - w^*| \\ &= \frac{1}{n} \sum_{i = 1}^n |y_i - \text{Median}(y_1, y_2, ..., y_n)| \end{align*}$$

I’ll admit, this result doesn’t have a special name; it is simply the mean absolute deviation from the median. Like the variance and standard deviation, it measures roughly how spread out the data is around its center. Its units are the same as the $y$-values themselves (since there’s no squaring involved).
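Here’s what that number works out to for our five original commute times (a quick check assuming numpy):

```python
import numpy as np

y = np.array([61, 72, 85, 90, 92])

# Average absolute deviation from the median (85):
# (24 + 13 + 0 + 5 + 7) / 5 = 9.8
mad_from_median = np.mean(np.abs(y - np.median(y)))
print(mad_from_median)  # 9.8
```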

$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$
Image produced in Jupyter

In the real world, be careful when you hear the term “mean absolute deviation”: it sometimes refers to the mean absolute deviation from the mean, rather than from the median as above.

To reiterate, in practice, our models will have many, many more parameters than just one, as is the case for the constant model. But, by deeply studying the effects of choosing squared loss vs. absolute loss vs. other loss functions in the context of the constant model, we can develop a better intuition for how to choose loss functions in more complex situations.