
1.2. Loss Functions and the Constant Model

Motivation

Suppose that you’re looking for an off campus apartment for next year. Unfortunately, none of them are in your price range, so you decide to live with your parents in Detroit and commute. To see if you can save some time on the road each day, you keep track of how long it takes for you to get to school.

Table produced in Jupyter

This is a real dataset, collected by Joseph Hearn, except he lived in Seattle, not Metro Detroit. The full dataset contains more columns than are shown above, but we’ll focus on these few for now.

Our goal is to predict commute times, stored in the minutes column. This is our $y$ variable. The natural first input variable, or feature, to consider is departure_hour. This is our $x$ variable.

We’ll use the subscript $i$ to index the $i$th data point, for $i = 1, 2, \ldots, n$. Using the dataset above, $x_1 = 10.816667$ and $y_1 = 68$, for instance.

Departure hours are stored as decimals, but correspond to times of the day. For example, 7.75 corresponds to 7:45 AM, and 10.816667 corresponds to 10:49 AM.

$$\begin{align*} 10.816667 &= 10 + 0.816667 \text{ hours} \\ &= 10 \text{ hours} + 0.816667 \cdot 60 \text{ minutes} \\ &= 10 \text{ hours} + 49 \text{ minutes} \end{align*}$$
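If you’d like to check this conversion yourself, here’s a minimal sketch in Python (the function name decimal_hour_to_clock is just for illustration, not something from the dataset):

```python
def decimal_hour_to_clock(hour):
    """Convert a decimal departure hour (e.g. 10.816667) to an (hours, minutes) pair."""
    whole_hours = int(hour)
    minutes = round((hour - whole_hours) * 60)
    return whole_hours, minutes

print(decimal_hour_to_clock(10.816667))   # (10, 49), i.e. 10:49 AM
print(decimal_hour_to_clock(7.75))        # (7, 45), i.e. 7:45 AM
```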

Before we get any further, we should look at our data. Since we’re working with two quantitative variables, we should draw a scatter plot.

Image produced in Jupyter

There’s a general downward trend: the later the departure time, the lower the commute time tends to be.

Again, our goal is to predict commute time given departure hour. That is, we’d like to build a useful function $h$ such that:

$$\text{predicted commute time}_i = h(\text{departure hour}_i)$$

This is a regression problem, because the variable we are predicting – commute time – is quantitative.

To build this function, the approach we’ll take is the machine learning approach – that is, to learn a pattern from the dataset that we’ve collected. (This is not the only approach one could take – we could build the function $h$ however we want.)

However, in order to learn a pattern from the dataset that we’ve collected, we need to make an important assumption.

We don’t really need our function $h$ to make good predictions on the dataset that we’ve already collected. We know the actual commute times on day 1, day 2, ..., day $n$. In other words, we’re working with a labeled dataset, in which we’re given the values of $y_1, y_2, \ldots, y_n$.

What we do need is for our function $h$ to make good predictions on unseen data from the future, i.e. for future commutes, the ones we don’t know about yet. This is the only way our function $h$ will be practically useful.

But, if the future doesn’t resemble the past, the patterns we learn from the past will not be generalizable to the future. For example, if a new highway between Detroit and Ann Arbor gets built, the patterns previously learned won’t necessarily still exist. This idea of generalizability is key, so keep it in mind even if we’re not explicitly talking about it.


Models

I’ve used the word “model” loosely, but let’s give it a formal definition.

“All models are wrong, but some are useful.” - George Box

My interpretation of George Box’s famous quote is that no matter how complex a model is, it will never be 100% correct, so sometimes – especially when we’re starting our machine learning journey – it’s better to use a simpler model that is also wrong but interpretable.

We gain value from simple models all the time. In a physics class, you may have learned that acceleration due to gravity is $9.81 \text{ m/s}^2$ towards the center of the Earth. This is not fully accurate – think about how parachutes work, for example – but it’s still a useful approximation, and one that is relatively easy to understand. A related idea is Occam’s razor, which states that the simplest explanation of a phenomenon is often the best.

Image produced in Jupyter

Above, you’ll see a degree-40 polynomial fit to our dataset. We’ll learn how to build such polynomials throughout the semester.

At first glance, it looks to be quite accurate, albeit complex. In fact, it’s a little too complex, and the phenomenon we see above is called overfitting. For the $x_i$’s in the dataset that we collected, sure, the polynomial will make accurate predictions, but for inputs that don’t match the dataset’s exact pattern, the predictions will be off. (For example, it’s unlikely that commutes will take 110 minutes around 10:15 AM, but that’s what the model predicts.) This polynomial model wouldn’t generalize well to unseen data.

If we look at the scatter plot closely, it seems reasonable to start with a line of best fit, much like you may have seen in a statistics class. In fact, we’ll start with something even simpler than that. But first, some notation.


Hypothesis Functions

The hypothesis functions we’ll study have parameters, usually denoted by $w$, which describe the relationship between the input and output. The two hypothesis functions we’ll study are:

  1. Constant model: $h(x_i) = w$
  2. Simple linear regression model: $h(x_i) = w_0 + w_1 x_i$

We’ll study the constant model first, though the role of parameters is easier to understand in the simple linear regression model.
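To make the two hypotheses concrete, here’s a sketch of each as a plain Python function. The parameter values below are arbitrary placeholders, not values fit to the commute data:

```python
def constant_model(x_i, w=73):
    """Constant model: predicts the same value w, regardless of the input x_i."""
    return w

def simple_linear_model(x_i, w0=140, w1=-8):
    """Simple linear regression model: intercept w0 plus slope w1 times x_i."""
    return w0 + w1 * x_i

# Both take a departure hour as input; only the second one actually uses it.
print(constant_model(8.5))        # 73
print(simple_linear_model(8.5))   # 140 + (-8) * 8.5 = 72.0
```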

Image produced in Jupyter

$h(x_i) = w_0 + w_1 x_i$ represents the equation of a line, where $w_0$ is the intercept and $w_1$ is the slope. Above, we see that different choices of parameters $w_0$ and $w_1$ result in different lines. The million dollar question is: among all of the infinitely many choices of $w_0$ and $w_1$, which one is the best?

To fully answer that, we’ll have to wait until Chapter 1.4. Surprisingly, that answer involves multivariable calculus.

For now, let’s return to the constant model, $h(x_i) = w$. The constant model predicts the same value for all $x_i$’s, and looks like a flat line.

Image produced in Jupyter

We’ll use the constant model for the rest of this section to illustrate core ideas in machine learning, and will move to more sophisticated models in Chapter 1.4.

If we’re forced to use a constant model, it’s clear that some choices of $w$ (the height of the line) are better than others. $w = 100$ yields a flat line that is far from most of the data. $w = 60$ and $w = 70$ seem like much more reasonable predictions, but how can we quantify which one is better, and which $w$ would be the best?

Since the constant model doesn’t depend on departure hours $x_i$, we can instead draw a histogram of just the true commute times.

Image produced in Jupyter

An equivalent way of phrasing the problem is: which constant $w$ best summarizes the histogram above? Most commute times seem to be in the 60 to 80 minute range, so a value somewhere in there makes sense. How can we be more precise?


Loss Functions

To illustrate, let’s consider a small dataset of only 5 commute times.

$$y_1 = 72, \quad y_2 = 90, \quad y_3 = 61, \quad y_4 = 85, \quad y_5 = 92$$

If asked to find the constant that best summarizes these 5 numbers, you might think of the mean or median, which are common summary statistics. There are other valid choices too, like the mode, or halfway between the minimum and maximum, or the most recent. What we need is a way to compare these choices.
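Here’s a quick sketch of those candidate summaries for the five commute times above, using numpy:

```python
import numpy as np

# The small dataset of 5 commute times from above.
y = np.array([72, 90, 61, 85, 92])

# A few candidate constant summaries. A loss function will let us compare them.
print(np.mean(y))                # 80.0
print(np.median(y))              # 85.0
print((y.min() + y.max()) / 2)   # 76.5, halfway between the minimum and maximum
print(y[-1])                     # 92, the last value, i.e. the most recent commute
```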

A loss function quantifies how bad a prediction is for a single data point.

  • If our prediction is close ✅ to the actual value, we should have low loss.
  • If our prediction is far ❌ from the actual value, we should have high loss.

We’ll start by computing the error for a single data point, defined as the difference between an actual $y$-value and its corresponding predicted $y$-value.

$$e_i = {\color{3D81F6}y_i} - {\color{orange}h(x_i)}$$

where ${\color{3D81F6}y_i}$ is the actual value and ${\color{orange}h(x_i)}$ is the predicted value.

Could this be a loss function? Let’s think this through. Suppose we have the true commute time $y_i = 80$.

  • If I predict 75, $e_i = 80 - \textcolor{orange}{75} = 5$.
  • If I predict 72, $e_i = 80 - \textcolor{orange}{72} = 8$.
  • If I predict 100, $e_i = 80 - \textcolor{orange}{100} = -20$.

Intuitively, a smaller error should mean a better prediction: 75 (error of 5) is a better prediction than 72 (error of 8). But 100 seems to be the worst of the three predictions, yet it technically has the smallest error (-20). The issue is that some errors are positive and some are negative, so it’s hard to compare them directly.
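In code, the sign problem is easy to see (a small sketch):

```python
def error(y_actual, y_pred):
    """Error for a single data point: actual minus predicted."""
    return y_actual - y_pred

print(error(80, 75))    # 5
print(error(80, 72))    # 8
print(error(80, 100))   # -20: the worst prediction of the three, yet the "smallest" error
```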

So ideally, a loss function shouldn’t have negative outputs. How can we take these errors, in which some are positive and some are negative, and enforce that they’re all positive?

Squared Loss

The most common solution to the problem of negative errors is to square each error. This gives rise to the first loss function we’ll explore, and arguably the most common loss function in machine learning: squared loss.

The squared loss function, $L_\text{sq}$, computes $({\color{3D81F6}\text{actual}} - {\color{orange}\text{predicted}})^2$. That is:

$$L_\text{sq}({\color{3D81F6}y_i}, {\color{orange}h(x_i)}) = ({\color{3D81F6}y_i} - {\color{orange}h(x_i)})^2$$
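As a Python function, squared loss might look like the following sketch; note how the three predictions from earlier now rank sensibly:

```python
def squared_loss(y_actual, y_pred):
    """Squared loss for a single data point: (actual - predicted) squared."""
    return (y_actual - y_pred) ** 2

print(squared_loss(80, 75))    # 25
print(squared_loss(80, 72))    # 64
print(squared_loss(80, 100))   # 400: the worst prediction now has the largest loss
```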

Why did we square instead of taking the absolute value? Absolute loss is a perfectly valid loss function – in fact, we’ll study it in Chapter 1.3 – and different loss functions have different pros and cons. That said, squared loss is a good first choice because:

  • The resulting optimization problem is differentiable, as we’ll see in just a few moments.
  • It has a nice relationship to the normal distribution in statistics, as we’ll see in Chapter 6, at the end of the course.

Let’s return to our small example dataset of 5 commute times.

$$y_1 = 72, \quad y_2 = 90, \quad y_3 = 61, \quad y_4 = 85, \quad y_5 = 92$$

How can we use squared loss to compare two choices of $w$, say $w = 85$ (the median) and $w = 80$ (the mean)? Let’s draw a picture (in which the $x$-axis positions of each point are irrelevant, since we’re not using departure hours).

Image produced in Jupyter

Each output of $L_\text{sq}$, shown in pink, describes the quality of a prediction for a single data point. For example, in the left plot above, the annotated $(-13)^2$ came from an actual value of 72 and a predicted value of 85:

$$L_\text{sq}({\color{3D81F6}72}, {\color{orange}85}) = ({\color{3D81F6}72} - {\color{orange}85})^2 = {\color{D81B60}(-13)^2} = 169$$

What we’d like is a single number which describes the quality of our predictions across the whole dataset, almost like a “score” for each choice of $w$. Then, we can compare scores to choose the best possible $w$. One way to construct such a score is to take the average of the squared losses.

  • For the median, $w = 85$:

    $$\begin{aligned} &\frac{1}{5} \left( (72 - {\color{orange} 85})^2 + (90 - {\color{orange} 85})^2 + (61 - {\color{orange} 85})^2 + (85 - {\color{orange} 85})^2 + (92 - {\color{orange} 85})^2 \right) \\ &= 163.8 \end{aligned}$$

  • For the mean, $w = 80$:

    $$\begin{aligned} &\frac{1}{5} \left( (72 - {\color{orange} 80})^2 + (90 - {\color{orange} 80})^2 + (61 - {\color{orange} 80})^2 + (85 - {\color{orange} 80})^2 + (92 - {\color{orange} 80})^2 \right) \\ &= 138.8 \end{aligned}$$

Losses are bad, so the better choice of $w$ has a lower average squared loss. Since $138.8 < 163.8$, the mean is a better prediction than the median.
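These two averages are easy to verify numerically; here’s a sketch using numpy:

```python
import numpy as np

y = np.array([72, 90, 61, 85, 92])

def average_squared_loss(y, w):
    """Average squared loss of predicting the constant w for every data point."""
    return np.mean((y - w) ** 2)

print(average_squared_loss(y, 85))   # 163.8, using the median
print(average_squared_loss(y, 80))   # 138.8, using the mean
```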

Another term for average squared loss is mean squared error (MSE); this is the more common name for the technique we just defined.


Minimizing Mean Squared Error

Let’s start by generalizing mean squared error to any prediction $w$ for our small commute times dataset.

$$R_\text{sq}(w) = \frac{1}{5} \left( (72 - w)^2 + (90 - w)^2 + (61 - w)^2 + (85 - w)^2 + (92 - w)^2 \right)$$

The function $R_\text{sq}$ takes in any prediction $w$ and outputs the mean squared error of that $w$. We’re searching for the value of $w$ that makes $R_\text{sq}(w)$ as small as possible, as that would correspond to the $w$ that makes the best possible predictions, for our humble constant model.

Where did the letter $R$ come from? It stands for risk, as in “empirical risk”. I’ll speak more on this soon. For now, remember that:

  • $L$ always refers to loss for a single data point.
  • $R$ always refers to average loss across an entire dataset.

What does $R_\text{sq}(w)$ actually look like, if we were to plot it? It is the sum of 5 quadratic functions – namely, $\frac{1}{5}(72 - w)^2$, $\frac{1}{5}(90 - w)^2$, and so on – and so it’s a quadratic function too, and looks like a parabola.

Image produced in Jupyter

The question is, though, what is the $w$-value of the vertex of this parabola? That is, which $w$ minimizes $R_\text{sq}(w)$?

Before we find the answer, let’s cast our problem in more general terms, so that the answer is applicable to any dataset. Suppose we have a dataset of $n$ actual commute times, $y_1, y_2, \ldots, y_n$. Our goal is to find the $w$ that minimizes:

$$\begin{aligned} R_\text{sq}(w) &= \frac{1}{n} \left( (y_1 - w)^2 + (y_2 - w)^2 + \ldots + (y_n - w)^2 \right) \\ &= \frac{1}{n} \sum_{i=1}^n (y_i - w)^2 \end{aligned}$$

While it looks like there are many variables in this equation, we know the actual values in the dataset, so we can treat $y_1, y_2, \ldots, y_n$ as constants. The only true variable is $w$.

How do we minimize $R_\text{sq}(w)$? There are a few approaches. We’ll use a calculus-based approach here, though in Homework 1 you’ll look at an alternative approach. For a refresher on the relevant calculus ideas, see Chapter 0.2.

$R_\text{sq}(w)$ is a function of a single variable, $w$. To minimize a function of a single variable, we should:

  1. Take the derivative of $R_\text{sq}(w)$ with respect to $w$.
  2. Set the derivative equal to 0 and solve for $w$.
  3. Verify that the second derivative at the critical point is positive.

Let’s go through these steps one by one.

Step 1: Take the derivative of $R_\text{sq}(w)$ with respect to $w$

$$\begin{align*} R_\text{sq}(w) &= \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2 \\ \frac{\text{d}}{\text{d}w} R_\text{sq}(w) &= \frac{\text{d}}{\text{d}w} \left( \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2 \right) \end{align*}$$

Remember that constants can be pulled out of derivatives, e.g. the derivative of $2f(x)$ is 2 times the derivative of $f(x)$.

$$\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{sq}(w) &= \frac{1}{n} \left( \frac{\text{d}}{\text{d}w} \sum_{i = 1}^n (y_i - w)^2 \right) \end{align*}$$

From here, we’ll use the fact that the derivative of a sum is the sum of derivatives, to “push” the derivative operator inside the sum.

$$\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{sq}(w) &= \frac{1}{n} \sum_{i = 1}^n \frac{\text{d}}{\text{d}w} (y_i - w)^2 \end{align*}$$

What is $\frac{\text{d}}{\text{d}w} (y_i - w)^2$? Try and work it out on your own, then check the solution below.

By the chain rule, $\frac{\text{d}}{\text{d}w} (y_i - w)^2 = 2(y_i - w) \cdot \frac{\text{d}}{\text{d}w}(y_i - w) = 2(y_i - w) \cdot (-1) = -2(y_i - w)$.

Using that result, we have:

$$\begin{align*} \frac{\text{d}}{\text{d}w} R_\text{sq}(w) &= \frac{1}{n} \sum_{i = 1}^n \left( -2(y_i - w) \right) \end{align*}$$

Finally, we’ll pull the constant of -2 out of the sum.

$$\boxed{\frac{\text{d}}{\text{d}w} R_\text{sq}(w) = -\frac{2}{n} \sum_{i = 1}^n (y_i - w)}$$

We could simplify this further, but this form will do just fine. Don’t forget, though, that the expression on the right side is a function of $w$.
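If you’d like a symbolic sanity check of this derivative, here’s a sketch using sympy and the small dataset of 5 commute times:

```python
import sympy as sp

w = sp.symbols('w')
y = [72, 90, 61, 85, 92]
n = len(y)

# R_sq(w) for the small dataset, and the derivative we just computed by hand.
R_sq = sp.Rational(1, n) * sum((y_i - w) ** 2 for y_i in y)
claimed_derivative = -sp.Rational(2, n) * sum((y_i - w) for y_i in y)

# The difference simplifies to 0, so the hand-computed derivative matches.
print(sp.simplify(sp.diff(R_sq, w) - claimed_derivative))   # 0
```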

Step 2: Set the derivative equal to 0 and solve for $w$

$$-\frac{2}{n} \sum_{i = 1}^n (y_i - w) = 0$$

First, we’ll multiply both sides by $-\frac{n}{2}$ to get rid of the fraction.

$$\sum_{i = 1}^n (y_i - w) = 0$$

Separating the sum into two parts gives us:

$$\sum_{i = 1}^n y_i - \sum_{i = 1}^n w = 0$$

$\displaystyle \sum_{i = 1}^n y_i$ can’t be broken down much further. But, $\displaystyle \sum_{i = 1}^n w$ is the sum of $n$ copies of $w$, i.e. $w + w + \ldots + w$. This is just $nw$!

$$\sum_{i = 1}^n y_i - nw = 0$$

Adding $nw$ to both sides, then dividing both sides by $n$, gives us:

$$\boxed{w^* = \frac{1}{n} \sum_{i = 1}^n y_i}$$

The value of $w$ that minimizes $R_\text{sq}(w)$ is $w^* = \frac{1}{n} \sum_{i = 1}^n y_i$. Notice that I’ve called it $w^*$; think of “star” as meaning “best” or “optimal”.

The formula for $w^*$ should look very familiar. It’s the mean of $y_1, y_2, \ldots, y_n$!
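We can also sanity-check this numerically: for the small dataset from before, the mean achieves a lower mean squared error than any nearby constant. A minimal sketch:

```python
import numpy as np

y = np.array([72, 90, 61, 85, 92])

def mean_squared_error(y, w):
    """Mean squared error of predicting the constant w for every data point."""
    return np.mean((y - w) ** 2)

w_star = np.mean(y)                        # 80.0 for this dataset
print(mean_squared_error(y, w_star))       # 138.8

# Nudging w away from the mean in either direction only increases R_sq(w).
print(mean_squared_error(y, w_star - 1))   # 139.8
print(mean_squared_error(y, w_star + 1))   # 139.8
```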

Step 3: Verify that the second derivative at the critical point is positive

We already know that $R_\text{sq}(w)$ is an upward-opening parabola, which means that its only critical point is a global minimum. But, we’ll be thorough just to set a good example.

Here, we’ll need to find the second derivative of $R_\text{sq}(w)$ with respect to $w$.

$$\begin{align*} \frac{\text{d}^2}{\text{d}w^2} R_\text{sq}(w) &= \frac{\text{d}}{\text{d}w} \left( -\frac{2}{n} \sum_{i = 1}^n (y_i - w) \right) \\ &= -\frac{2}{n} \sum_{i = 1}^n \frac{\text{d}}{\text{d}w} (y_i - w) \\ &= -\frac{2}{n} \sum_{i = 1}^n (-1) \\ &= -\frac{2}{n} (-n) \\ &= 2 \end{align*}$$

The second derivative is 2 for all values of $w$, including at the $w^*$ we found. This tells us that $R_\text{sq}(w)$ is convex – it opens upwards – across its entire domain, so the critical point we’ve found corresponds to a global minimum.
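The same kind of sympy sketch from Step 1 confirms the second derivative, too:

```python
import sympy as sp

w = sp.symbols('w')
y = [72, 90, 61, 85, 92]

R_sq = sp.Rational(1, len(y)) * sum((y_i - w) ** 2 for y_i in y)

# The second derivative is the constant 2: positive everywhere, so the
# critical point we found is indeed a global minimum.
print(sp.diff(R_sq, w, 2))   # 2
```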


Conclusion

What was the point of all of that algebra? To recap:

  • We decided to use the constant model, $h(x_i) = w$, to make predictions.
  • To find the best value of $w$ – a model parameter – we decided to minimize mean squared error:
    $$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$
  • Using calculus, we found that the value of $w$ that minimizes $R_\text{sq}(w)$ is
    $$w^* = \frac{1}{n} \sum_{i = 1}^n y_i = \text{Mean}(y_1, y_2, \ldots, y_n)$$

In other words, the mean minimizes mean squared error. This is a remarkable result. We use the mean all of the time in daily life, and now we’ve proven that it is optimal in some sense. It is the constant with the smallest mean squared error, no matter the dataset we’re working with.

Another name for $w^*$ is an optimal model parameter. In the context of our full commute times dataset, the optimal model parameter is the mean commute time. Visually, the value of $w^* \approx 73$ tells us the optimal “height” at which we should draw the constant model, $h(x_i) = w$.

Image produced in Jupyter

Is this the best possible model? No, of course not – we’re not capturing the fact that later departure times are associated with shorter commute times. But as a first attempt at building a model, the constant model is valuable. If someone asked you how long your commutes are, saying something like “about 73 minutes” is reasonable.

What’s next?

  • In Chapter 1.3, we’ll investigate other loss functions, like absolute loss.
  • In Chapter 1.4, we’ll reintroduce the simple linear regression model, $h(x_i) = w_0 + w_1 x_i$, and see how to find the best values of $w_0$ and $w_1$.