Most modern supervised learning algorithms follow these same three steps, just with different models, loss functions, and techniques for optimization.
Another name given to this process is empirical risk minimization.
When using squared loss, all three of these mean the same thing:
Average squared loss.
Mean squared error.
Empirical risk.
Risk is an idea from theoretical statistics that we’ll visit in Chapter 6. It refers to the expected error of a model, when considering the probability distribution of the data. “Empirical” risk refers to risk calculated using an actual, concrete dataset, rather than a theoretical distribution. The reason we call the average loss R is precisely because it is empirical risk.
The first half of the course – and in some ways, the entire course – is focused on empirical risk minimization, and so we will make many passes through the three-step modeling recipe ourselves, with differing models and loss functions.
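To make the recipe concrete before we change anything, here's a minimal sketch (not from the text) that computes the empirical risk of a constant prediction under squared loss, and checks numerically that the mean minimizes it, using the five example values reintroduced later in this section.

import numpy as np

# The five example values from Chapter 1.2, reused throughout this section.
y = np.array([72, 90, 61, 85, 92])

def empirical_risk_sq(w):
    """Average squared loss (mean squared error) of the constant prediction w."""
    return np.mean((y - w) ** 2)

# Brute-force the best constant prediction over a fine grid of candidates.
candidates = np.linspace(50, 110, 6001)
best = candidates[np.argmin([empirical_risk_sq(w) for w in candidates])]

print(best, y.mean())  # both are (approximately) 80.0

The rest of this section repeats exactly this process, just with different loss functions inside the average.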
When we first introduced the idea of a loss function, we started by computing the error, $e_i$, of each prediction:
$$e_i = y_i - h(x_i)$$
where $y_i$ is the actual value and $h(x_i)$ is the predicted value.
The issue was that some errors were positive and some were negative, so they were hard to compare directly, and they could cancel each other out when averaged. We wanted the value of the loss function to be large for bad predictions and small for good predictions.
To get around this, we squared the errors, which gave us squared loss:
$$L_\text{sq}(y_i, h(x_i)) = (y_i - h(x_i))^2$$
But, instead, we could have taken the absolute value of the errors. Doing so gives us absolute loss:
$$L_\text{abs}(y_i, h(x_i)) = |y_i - h(x_i)|$$
Below, I’ve visualized the absolute loss and squared loss for just a single data point.
You should notice two key differences between the two loss functions:
The absolute loss function is not differentiable when $y_i = h(x_i)$. The absolute value function, $f(x) = |x|$, does not have a derivative at $x = 0$, because its slope to the left of $x = 0$ (which is $-1$) is different from its slope to the right of $x = 0$ (which is $1$). For more on this idea, see Chapter 0.2.
The squared loss function grows much faster than the absolute loss function as the prediction $h(x_i)$ gets further away from the actual value $y_i$, as the quick comparison below shows.
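For instance, here's a small sketch (not from the text) of how the two losses scale with the size of the error:

import numpy as np

# Squared loss grows quadratically with the error; absolute loss grows linearly.
errors = np.array([1, 2, 5, 10])
print(errors ** 2)     # [  1   4  25 100]  -> squared loss
print(np.abs(errors))  # [ 1  2  5 10]      -> absolute loss

An error of 10 is penalized 100 times as heavily as an error of 1 under squared loss, but only 10 times as heavily under absolute loss.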
We know that the optimal constant prediction, $w^*$, when using squared loss is the mean. What is the optimal constant prediction when using absolute loss? It turns out not to be the mean; rather, the answer reflects some of these differences between squared loss and absolute loss.
Let's find that new optimal constant prediction, $w^*$, by revisiting the three-step modeling recipe.
Choose a model.
We'll stick with the constant model, $h(x_i) = w$.
Choose a loss function.
We’ll use absolute loss:
$$L_\text{abs}(y_i, h(x_i)) = |y_i - h(x_i)|$$
For the constant model, since $h(x_i) = w$, we have:
$$L_\text{abs}(y_i, w) = |y_i - w|$$
Minimize average loss to find optimal model parameters.
The average loss – also known as mean absolute error here – is:
$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$
In Chapter 1.2, we minimized $R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$ by taking the derivative of $R_\text{sq}(w)$ with respect to $w$ and setting it equal to 0. That will be more challenging in the case of $R_\text{abs}(w)$, because the absolute value function is not differentiable when its input is 0, as we just discussed.
We need to minimize the mean absolute error, $R_\text{abs}(w)$, for the constant model, $h(x_i) = w$, but we have to address the fact that $R_\text{abs}(w)$ is not differentiable across its entire domain.
I think it'll help to visualize what $R_\text{abs}(w)$ looks like. To do so, let's reintroduce the small dataset of 5 values we used in Chapter 1.2.
$$y_1 = 72, \quad y_2 = 90, \quad y_3 = 61, \quad y_4 = 85, \quad y_5 = 92$$
Then, $R_\text{abs}(w)$ is:
$$R_\text{abs}(w) = \frac{1}{5}\left(|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|\right)$$
import numpy as np
import plotly.graph_objects as go
# Mean absolute error of a constant prediction h, for the five example values.
f = lambda h: (abs(72-h) + abs(90-h) + abs(61-h) + abs(85-h) + abs(92-h)) / 5
x = np.linspace(50, 110, 100)
y = np.array([f(h) for h in x])
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode='lines',
name='Data',
line=dict(color='#D81B60', width=4)
)
)
# Add vertical dotted lines at the 5 data points
data_points = [72, 90, 61, 85, 92]
for pt in data_points:
y_val = f(pt)
fig.add_trace(
go.Scatter(
x=[pt, pt],
y=[0, y_val-0.5],
mode='lines',
line=dict(color='gray', width=2, dash='dot'),
showlegend=False
)
)
# Add annotation halfway up the dotted line
halfway_y = (y_val-0.5) / 2 if pt != max(data_points) else y_val-5
fig.add_annotation(
x=pt,
y=halfway_y,
text=f"w={pt}",
showarrow=False,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="gray",
size=12
),
bgcolor="white",
xanchor="center",
yanchor="middle"
)
# Annotate (72, f(72)) with an arrow and text
fig.add_annotation(
x=72,
y=f(72),
text=r"slope undefined",
showarrow=True,
arrowhead=2,
ax=40,
ay=-35,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black",
size=12
),
bgcolor="white",
bordercolor="black"
)
fig.update_xaxes(
showticklabels=True,
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
)
fig.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
range=[0, 30]
)
fig.update_layout(
xaxis_title=r'$w$',
yaxis_title=r'$R_\text{abs}(w)$',
plot_bgcolor='white',
paper_bgcolor='white',
margin=dict(l=60, r=60, t=60, b=60),
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black"
),
showlegend=False
)
fig.show(renderer='png', scale=4)
This is a piecewise linear function. Where are the "bends" in the graph? Precisely where the data points, $y_1, y_2, \ldots, y_5$, are! It's at exactly these points where $R_\text{abs}(w)$ is not differentiable. At each of those points, the slope of the line segment approaching from the left is different from the slope of the line segment approaching from the right, and for a function to be differentiable at a point, the slope of the tangent line must be the same when approaching from the left and the right.
The graph of $R_\text{abs}(w)$ above, while not differentiable at any of the data points, still shows us something about the optimal constant prediction. If there is a bend at each data point, and at each bend the slope increases – that is, becomes more positive – then the optimal constant prediction seems to be in the middle, where the slope goes from negative to positive. I'll make this more precise in a moment.
For now, you might notice that the value of $w$ that minimizes $R_\text{abs}(w)$ in the graph above is a familiar summary statistic, but not the mean. I won't spell it out just yet, since I'd like for you to reason about it yourself.
Let me show you one more graph of $R_\text{abs}(w)$, but this time, in a case where there are an even number of data points. Suppose we have a sixth point, $y_6 = 78$.
import numpy as np
import plotly.graph_objects as go
# Mean absolute error of a constant prediction h, now with a sixth value, 78.
f = lambda h: (abs(72-h) + abs(90-h) + abs(61-h) + abs(85-h) + abs(92-h) + abs(78 - h)) / 6
x = np.linspace(50, 110, 100)
y = np.array([f(h) for h in x])
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode='lines',
name='Data',
line=dict(color='#D81B60', width=4)
)
)
data_points = [72, 90, 61, 85, 92, 78]
for pt in data_points:
# Find the y value on the curve at this x (pt)
y_val = f(pt)
fig.add_trace(
go.Scatter(
x=[pt, pt],
y=[0, y_val-0.5],
mode='lines',
line=dict(color='gray', width=2, dash='dot'),
showlegend=False
)
)
halfway_y = (y_val-0.5) / 2 if pt != max(data_points) else y_val-5
fig.add_annotation(
x=pt,
y=halfway_y,
text=f"w={pt}",
showarrow=False,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="gray",
size=12
),
bgcolor="white",
xanchor="center",
yanchor="middle"
)
# Annotate the flat region between the middle two points with a box that says slope = 0
sorted_points = sorted(data_points)
mid1, mid2 = sorted_points[2], sorted_points[3]
# Find the y value at the midpoint for annotation placement
mid_x = (mid1 + mid2) / 2
mid_y = f(mid_x)
# Add annotation box
fig.add_annotation(
x=mid_x,
y=mid_y+1.3,
text="slope = 0",
showarrow=False,
font=dict(size=12, color="black"),
bordercolor="black",
bgcolor="white",
borderwidth=1,
# borderpad=4
)
fig.update_xaxes(
showticklabels=True,
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
)
fig.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
range=[0, 30]
)
fig.update_layout(
xaxis_title=r'$w$',
yaxis_title=r'$R_\text{abs}(w)$',
plot_bgcolor='white',
paper_bgcolor='white',
margin=dict(l=60, r=60, t=60, b=60),
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black"
),
showlegend=False
)
fig.show(renderer='png', scale=4)
This graph is broken into 7 segments, with 6 bends (one per data point). Between the 3rd and 4th bends – that is, between the 3rd and 4th data points in sorted order – the slope is 0, and all values in that interval minimize $R_\text{abs}(w)$. So, it seems that the value of $w^*$ doesn't have to be unique!
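As a quick check (a sketch, not from the text), we can evaluate this six-point mean absolute error at a few values between the middle two sorted points, 78 and 85, and see that they all produce the same risk:

import numpy as np

y = np.array([72, 90, 61, 85, 92, 78])

def mae(w):
    """Mean absolute error of the constant prediction w."""
    return np.mean(np.abs(y - w))

# Every w between the middle two sorted values (78 and 85) gives the same risk.
print([round(mae(w), 4) for w in [78, 80, 81.5, 83, 85]])  # five copies of 9.3333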
From the two graphs above, you may have a clear picture of what the optimal constant prediction, $w^*$, is. But, to avoid relying too heavily on visual intuition and just a single set of example data points, let's try to minimize $R_\text{abs}(w)$ mathematically, for an arbitrary set of data points.
To be clear, the goal is to minimize:
$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$
To do so, we'll take the derivative of $R_\text{abs}(w)$ with respect to $w$ and set it equal to 0.
$$\frac{d}{dw} R_\text{abs}(w) = \frac{d}{dw} \left( \frac{1}{n} \sum_{i = 1}^n |y_i - w| \right)$$
Using the familiar facts that the derivative of a sum is the sum of the derivatives, and that constants can be pulled out of the derivative, we have:
$$\frac{d}{dw} R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n \frac{d}{dw} |y_i - w|$$
Here's where the challenge comes in. What is $\frac{d}{dw} |y_i - w|$?
Let’s start by remembering the derivative of the absolute value function. The absolute value function itself can be thought of as a piecewise function:
$$|x| = \begin{cases} x & x \geq 0 \\ -x & x < 0 \end{cases}$$
Note that the $x = 0$ case can be lumped into either the $x$ or $-x$ case, since $0$ and $-0$ are both $0$.
Using this logic, I'll write $|y_i - w|$ as a piecewise function of $w$:
$$|y_i - w| = \begin{cases} y_i - w & w \leq y_i \\ w - y_i & w > y_i \end{cases}$$
I have written the two conditions with $w$ on the left, since I find it easier to think in terms of $w$, but this means that the inequalities are flipped relative to how I presented them in the definition of $|x|$. Remember, $|y_i - w|$ is a function of $w$; we're treating $y_i$ as some constant. If it helps, replace every instance of $y_i$ with a concrete number, like 5, then reason through the resulting graph.
import numpy as np
import plotly.graph_objects as go
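# Plot the absolute loss |y_i - w| as a function of w, for a single data point y_i.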
h = np.linspace(-10, 10, 400)
y_i = 0
abs_loss = np.abs(y_i - h)
fig = go.Figure(
data=go.Scatter(
x=h,
y=abs_loss,
mode='lines',
line=dict(color='#004d40', width=3),
showlegend=False
)
)
# Annotate left side (slope = -1)
fig.add_annotation(
x=-7,
y=4,
text="slope = -1",
showarrow=False,
font=dict(size=12, color="black"),
align="center",
bordercolor="black"
)
# Annotate right side (slope = 1)
fig.add_annotation(
x=7,
y=4,
text="slope = 1",
showarrow=False,
font=dict(size=12, color="black"),
align="center",
bordercolor="black",
)
# Annotate y_i at x=0
fig.add_annotation(
x=0,
y=-0.5,
text="$y_i$",
showarrow=False,
font=dict(size=13, color="black"),
yanchor="top"
)
# Annotate "w ->" towards the right of the x-axis
fig.add_annotation(
x=9,
y=-0.5,
text="w →",
showarrow=False,
font=dict(size=13, color="black"),
yanchor="top"
)
fig.update_xaxes(showticklabels=False, title_text=None)
fig.update_yaxes(showticklabels=False, title_text=None, range=[-0.75, 10])
fig.update_layout(
width=350,
height=300,
plot_bgcolor='white',
paper_bgcolor='white',
margin=dict(l=40, r=40, t=40, b=40),
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black"
),
)
fig.show(renderer='png', scale=4)
Now we can take the derivative of each piece:
$$\frac{d}{dw} |y_i - w| = \begin{cases} -1 & w < y_i \\ \text{undefined} & w = y_i \\ 1 & w > y_i \end{cases}$$
Great. Remember, this is the derivative of the absolute loss for a single data point. But our main objective is to find the derivative of the average absolute loss, $R_\text{abs}(w)$. Using this piecewise definition of $\frac{d}{dw} |y_i - w|$, we have:
$$\frac{d}{dw} R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n \frac{d}{dw} |y_i - w|$$
At any point where $w = y_i$, for any value of $i$, $\frac{d}{dw} R_\text{abs}(w)$ is undefined. (This makes any point where $w = y_i$ a critical point.) Let's exclude those values of $w$ from our consideration. In all other cases, each term of the sum in the expression above is one of only two possible values: $-1$ and $1$.
The sum adds $-1$ for all data points greater than $w$, i.e. where $w < y_i$.
The sum adds $1$ for all data points less than $w$, i.e. where $w > y_i$.
Using some creative notation, I'll re-write $\frac{d}{dw} R_\text{abs}(w)$ as:
$$\frac{d}{dw} R_\text{abs}(w) = \frac{1}{n} \left( \sum_{w < y_i} (-1) + \sum_{w > y_i} 1 \right)$$
The sum $\sum_{w < y_i} (-1)$ is the sum of $-1$ for all data points greater than $w$, so perhaps a more intuitive way to write it is:
$$\sum_{w < y_i} (-1) = \underbrace{(-1) + (-1) + \ldots + (-1)}_{\text{add once per data point to the right of } w} = -(\#\text{ right of } w)$$
Equivalently, $\sum_{w > y_i} 1 = (\#\text{ left of } w)$, meaning that:
$$\frac{d}{dw} R_\text{abs}(w) = \frac{1}{n} \left( -(\#\text{ right of } w) + (\#\text{ left of } w) \right) = \boxed{\frac{\#\text{ left of } w - \#\text{ right of } w}{n}}$$
This boxed form gives us the slope of $R_\text{abs}(w)$, for any point $w$ that is not an original data point. To put it in perspective, let's revisit the first graph we saw in this section, where we plotted $R_\text{abs}(w)$ for the dataset:
$$y_1 = 72, \quad y_2 = 90, \quad y_3 = 61, \quad y_4 = 85, \quad y_5 = 92$$
$$R_\text{abs}(w) = \frac{1}{5}\left(|72 - w| + |90 - w| + |61 - w| + |85 - w| + |92 - w|\right)$$
import numpy as np
import plotly.graph_objects as go
# Same five-value mean absolute error as before, re-plotted with a slope annotation at w = 65.
f = lambda h: (abs(72-h) + abs(90-h) + abs(61-h) + abs(85-h) + abs(92-h)) / 5
x = np.linspace(50, 110, 100)
y = np.array([f(h) for h in x])
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode='lines',
name='Data',
line=dict(color='#D81B60', width=4)
)
)
# Add vertical dotted lines at the 5 data points
data_points = [72, 90, 61, 85, 92]
for pt in data_points:
y_val = f(pt)
fig.add_trace(
go.Scatter(
x=[pt, pt],
y=[0, y_val-0.5],
mode='lines',
line=dict(color='gray', width=2, dash='dot'),
showlegend=False
)
)
# Add annotation halfway up the dotted line
halfway_y = (y_val-0.5) / 2 if pt != max(data_points) else y_val-5
fig.add_annotation(
x=pt,
y=halfway_y,
text=f"w={pt}",
showarrow=False,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="gray",
size=12
),
bgcolor="white",
xanchor="center",
yanchor="middle"
)
# Annotate (65, f(65)) with an arrow and the specified text
fig.add_annotation(
x=65,
y=f(65),
text=r"$\text{slope} = \frac{\text{L} - \text{R}}{5} = \frac{1-4}{5} = -\frac{3}{5}$",
showarrow=True,
arrowhead=2,
ax=90, # Move the annotation box further to the right
ay=-45,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black",
size=16
),
bgcolor="white",
bordercolor="black",
borderwidth=1,
borderpad=12 # Keep the box tall
)
fig.update_xaxes(
showticklabels=True,
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
)
fig.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor='#f0f0f0',
showline=True,
linecolor="black",
linewidth=1,
range=[0, 30]
)
fig.update_layout(
xaxis_title=r'$w$',
yaxis_title=r'$R_\text{abs}(w)$',
plot_bgcolor='white',
paper_bgcolor='white',
margin=dict(l=60, r=60, t=60, b=60),
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black"
),
showlegend=False
)
fig.show(renderer='png', scale=4)
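As a numerical sanity check (a sketch, not from the text), we can compare the boxed formula to a finite-difference estimate of the slope at $w = 65$, the point annotated above:

import numpy as np

y = np.array([72, 90, 61, 85, 92])

def mae(w):
    return np.mean(np.abs(y - w))

w = 65
formula_slope = (np.sum(y < w) - np.sum(y > w)) / len(y)  # (# left of w - # right of w) / n
numeric_slope = (mae(w + 1e-6) - mae(w - 1e-6)) / 2e-6    # centered finite difference

print(formula_slope, round(numeric_slope, 6))  # both are -0.6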
Now that we have a formula for $\frac{d}{dw} R_\text{abs}(w)$, the easy thing to claim is that we could set it to 0 and solve for $w$. Doing so would give us:
$$\frac{\#\text{ left of } w - \#\text{ right of } w}{n} = 0$$
Which yields the condition:
$$\#\text{ left of } w = \#\text{ right of } w$$
The optimal value of $w$ is the one that satisfies this condition, and that's precisely the median of the data, as you may have noticed earlier.
This logic isn't fully rigorous, however, because the formula for $\frac{d}{dw} R_\text{abs}(w)$ is only valid for $w$'s that aren't original data points, and the median – if we have an odd number of data points – is indeed one of the original data points. In the graph above, there is never a point where the slope is 0.
To fully justify why the median (in the case of an odd number of data points) minimizes mean absolute error, I’ll say that:
If $w$ is just to the left of the median, there are more points to the right of $w$ than to the left of $w$, so $(\#\text{ left of } w) < (\#\text{ right of } w)$ and $\frac{\#\text{ left of } w - \#\text{ right of } w}{n}$ is negative.
If $w$ is just to the right of the median, there are more points to the left of $w$ than to the right of $w$, so $(\#\text{ left of } w) > (\#\text{ right of } w)$ and $\frac{\#\text{ left of } w - \#\text{ right of } w}{n}$ is positive.
So even though the slope is undefined at the median, we know it is a point at which the sign of the derivative switches from negative to positive, and as we discussed in Chapter 0.2, this sign change implies at least a local minimum.
To summarize:
If $n$ is odd, the median minimizes mean absolute error.
If $n$ is even, any value between the middle two values (when sorted) minimizes mean absolute error. The quick check below verifies both claims numerically.
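Here's that check – a small sketch, not from the text, using the five-value and six-value datasets from above:

import numpy as np

def best_constant_abs(y, grid):
    """Brute-force the constant prediction(s) that minimize mean absolute error."""
    risks = np.array([np.mean(np.abs(y - w)) for w in grid])
    return np.round(grid[np.isclose(risks, risks.min())], 1)

grid = np.linspace(50, 110, 601)  # candidate predictions, spaced 0.1 apart

# Odd n: the unique minimizer is the median.
y_odd = np.array([72, 90, 61, 85, 92])
print(best_constant_abs(y_odd, grid), np.median(y_odd))  # [85.] 85.0

# Even n: every value between the middle two sorted values (78 and 85) ties for the minimum.
flat = best_constant_abs(np.array([72, 90, 61, 85, 92, 78]), grid)
print(flat.min(), flat.max())  # 78.0 85.0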
We’ve just made a second pass through the three-step modeling recipe:
Choose a model.
$$h(x_i) = w$$
Choose a loss function.
$$L_\text{abs}(y_i, w) = |y_i - w|$$
Minimize average loss to find optimal model parameters.
$$w^* = \text{Median}(y_1, y_2, \ldots, y_n)$$
Let's compare the behavior of the mean and median, and reason about how their differences are related to the differences in their loss functions.
We can see that compared to the median, the mean is being "pulled" towards the outlier. We often say the median is robust to outliers. This makes sense when you consider that the mean and median balance different aspects of the distribution. The median is the point where the number of points to its left and right are equal, while the mean is the point where the total distance to the points on its left equals the total distance to the points on its right – it balances sums of distances, not counts of points. This gives us a trade-off between the loss functions: do we want to choose one with a minimizer that's easier to calculate, or one that's less sensitive to outliers?
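As a concrete illustration (a sketch, not from the text), here's what happens to the mean and median when the 92 in our dataset is replaced by an outlier of 292 – the variant of the dataset that also appears in the next graph:

import numpy as np

y = np.array([61, 72, 85, 90, 92])
y_outlier = np.array([61, 72, 85, 90, 292])  # the 92 replaced by an outlier

print(np.mean(y), np.median(y))                  # 80.0 85.0
print(np.mean(y_outlier), np.median(y_outlier))  # 120.0 85.0 -- the mean moves, the median doesn't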
Now that we've seen absolute and squared loss, let's explore minimizing empirical risk for other loss functions. For any $p \geq 1$, define the $L_p$ loss as follows:
$$L_p(y_i, w) = |y_i - w|^p$$
Note that we need the absolute value to avoid negative loss. Given an $L_p$ loss, the corresponding empirical risk is:
$$R_p(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|^p$$
When $p = 1$, or absolute loss, $w^* = \text{Median}(y_1, y_2, \ldots, y_n)$.
When $p = 2$, or squared loss, $w^* = \text{Mean}(y_1, y_2, \ldots, y_n)$.
What about as $p \to \infty$ – how can we find the minimizer $w^*$? Let's visualize the dataset from before, but with the 92 replaced by the outlier 292 – that is, $61, 72, 85, 90, 292$ – on a different kind of graph. The $x$-axis is $p$, and the $y$-axis is the optimal constant prediction $w^*$ for that $L_p$ loss function.
import numpy as np
import plotly.graph_objects as go
# Dataset with the outlier: the 92 from before replaced by 292
y = np.array([61, 72, 85, 90, 292])
# Range of p values
p_values = np.linspace(1, 20, 100)
w_stars = []
# For each p, numerically find the minimizer of average L_p loss
for p in p_values:
ws = np.linspace(np.min(y), np.max(y), 1000)
risks = np.array([np.mean(np.abs(y - w)**p) for w in ws])
min_idx = np.argmin(risks)
w_stars.append(ws[min_idx])
# For p = infinity, the minimizer is the midpoint of min and max
w_star_inf = (np.min(y) + np.max(y)) / 2
# Calculate median and mean for labels
median_val = np.median(y)
mean_val = np.mean(y)
# Plot
fig = go.Figure()
fig.add_trace(go.Scatter(
x=p_values,
y=w_stars,
mode='lines',
line=dict(color='#3D81F6', width=3),
name='Minimizer $w^*$',
showlegend=False
))
# Add horizontal line for p = infinity minimizer
fig.add_trace(go.Scatter(
x=[p_values[0], p_values[-1]],
y=[w_star_inf, w_star_inf],
mode='lines',
line=dict(color='red', width=2, dash='dash'),
name='$w^*$ as $p \\to \\infty$',
showlegend=False
))
# Add labels at p=1 (median) and p=2 (mean)
fig.add_annotation(
x=1,
y=median_val,
text=f"(1, {median_val}), median",
showarrow=True,
arrowhead=2,
ax=100,
ay=0,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black",
size=12
),
bgcolor="white",
bordercolor="black",
borderwidth=1
)
fig.add_annotation(
x=2,
y=mean_val,
text=f"(2, {mean_val}), mean",
showarrow=True,
arrowhead=2,
ax=100,
ay=0,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black",
size=12
),
bgcolor="white",
bordercolor="black",
borderwidth=1
)
fig.update_layout(
xaxis=dict(
title='$p$',
showticklabels=True,
visible=True,
ticks='outside',
tickcolor='black',
ticklen=8,
gridcolor='#f0f0f0',
linecolor="black",
linewidth=1,
range=[1, 20], # <-- Set x-axis to start from 1
),
yaxis=dict(
title='$w^*$',
showticklabels=True,
visible=True,
range=[np.min(y)-10, np.max(y)+10],
gridcolor='#f0f0f0',
linecolor="black",
linewidth=1,
),
plot_bgcolor='white',
paper_bgcolor='white',
margin=dict(l=60, r=60, t=60, b=60),
width=700,
font=dict(
family="Palatino Linotype, Palatino, serif",
color="black"
),
)
fig.show(renderer='png', scale=3)
As $p \to \infty$, $w^*$ approaches 176.5, which is actually the midpoint of the minimum and maximum values of the dataset. We call this the midrange, and it's the best prediction when the measure of a "good" prediction is not being too far from any point in the dataset – that is, when we want to minimize the worst-case (largest) absolute error.
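Here's a quick sketch (not from the text) checking that interpretation: among the midrange, the median, and the mean, the midrange has the smallest worst-case distance to the data.

import numpy as np

y = np.array([61, 72, 85, 90, 292])
candidates = {'midrange': (y.min() + y.max()) / 2, 'median': np.median(y), 'mean': np.mean(y)}

# Worst-case absolute distance from each candidate prediction to the data.
for name, w in candidates.items():
    print(name, w, np.max(np.abs(y - w)))
# midrange 176.5 115.5  <- smallest worst-case distance
# median 85.0 207.0
# mean 120.0 172.0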
Let's consider another loss function, this time 0-1 loss:
$$L_{0,1}(y_i, w) = \begin{cases} 0 & y_i = w \\ 1 & y_i \neq w \end{cases}$$
The corresponding empirical risk is:
$$R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n L_{0,1}(y_i, w)$$
The empirical risk is the proportion of points not equal to the prediction $w$. To minimize risk, we want there to be as many $y_i$'s as possible which are equal to our prediction, so $w^*$ is the mode of the dataset.
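Here's a minimal sketch of that idea (not from the text, and using a hypothetical dataset with a repeated value, since our five example values are all distinct):

import numpy as np

# Hypothetical dataset with a repeated value, so the mode is interesting.
y = np.array([61, 72, 72, 85, 90])

def zero_one_risk(w):
    """Proportion of data points not equal to the prediction w."""
    return np.mean(y != w)

# Minimizing 0-1 risk over the observed values picks out the mode.
risks = {int(w): float(zero_one_risk(w)) for w in np.unique(y)}
print(min(risks, key=risks.get), risks)  # 72 is the mode; its risk is 3/5 = 0.6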
To recap, choosing a loss function is important because it determines what the best prediction will be! Each optimal prediction $w^*$ is a different summary statistic that measures the center of the dataset.
We know that the $w^*$ which minimizes empirical risk is some measure of the center of the dataset, but what does this tell us about the minimum risk itself? Consider the empirical risk for squared loss:
$$R_\text{sq}(w) = \frac{1}{n} \sum_{i = 1}^n (y_i - w)^2$$
$R_\text{sq}(w)$ is minimized when $w^*$ is the mean, so let's try plugging that value back into $R_\text{sq}$ to find the minimum value:
$$R_\text{sq}(w^*) = \frac{1}{n} \sum_{i = 1}^n \left(y_i - \text{Mean}(y_1, y_2, \ldots, y_n)\right)^2$$
The minimum value comes from taking each data point's deviation from the mean of the dataset, squaring those deviations, and averaging them. There's a special name for this average: variance. You may be familiar with its square root, known as standard deviation. Both of these measure how far the data is spread out around the mean of the dataset.
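As a quick numerical check (a sketch, not from the text), the minimum value of $R_\text{sq}$ for our five example values is exactly their variance:

import numpy as np

y = np.array([72, 90, 61, 85, 92])

min_risk_sq = np.mean((y - np.mean(y)) ** 2)  # R_sq evaluated at w* = mean
print(min_risk_sq, np.var(y), np.std(y))      # 138.8 138.8 11.78...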
Let’s try the same process on absolute loss next:
$$R_\text{abs}(w) = \frac{1}{n} \sum_{i = 1}^n |y_i - w|$$
We'll plug $w^* = \text{Median}(y_1, y_2, \ldots, y_n)$ into $R_\text{abs}(w)$:
$$R_\text{abs}(w^*) = \frac{1}{n} \sum_{i = 1}^n \left|y_i - \text{Median}(y_1, y_2, \ldots, y_n)\right|$$
Our minimum value is the mean absolute deviation from the median. Similarly to variance, it also measures the average distance from each data point to a center, this time the median. Both of these are a way of measuring the spread in our dataset, using their respective centers.
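And the corresponding check for absolute loss (again a sketch, not from the text):

import numpy as np

y = np.array([72, 90, 61, 85, 92])

min_risk_abs = np.mean(np.abs(y - np.median(y)))  # R_abs evaluated at w* = median
print(np.median(y), min_risk_abs)                 # 85.0 9.8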
What if we try using 0-1 loss?
$$R_{0,1}(w) = \frac{1}{n} \sum_{i = 1}^n \begin{cases} 0 & y_i = w \\ 1 & y_i \neq w \end{cases}$$
The minimizer is the mode of the dataset, so the minimum value is the proportion of values not equal to a mode of the dataset. However, this doesn't tell us much about how spread out the data is. If $R_{0,1}$ is higher, that only means that there's less data clustered at exactly a mode. The mode is a very basic way to measure the center, so it follows that the corresponding measure of spread is also basic and rather uninformative.
Ultimately, choosing a loss function comes down to understanding which measures of center and spread are most important for your prediction problem and dataset.