
8.5. Convexity

As we saw in Chapter 8.3, gradient descent is prone to getting stuck in local minima.

When I use gradient descent to minimize an empirical risk function, getting stuck at a local minimum that is not global means landing on a parameter vector $\vec w$ that is not actually optimal. That is a real problem.

In modern machine learning, the loss surfaces people care about are often huge and messy, and dealing with bad local minima is part of the game. I am not going to get into those fixes here. Instead, I want to focus on a special family of functions where life is much nicer: convex functions.

The following video, recorded in an earlier semester, summarizes the key ideas of this section.


Formal Definition of Convexity

Let me start with the picture, because the picture is the whole point.

If you need a refresher on what a secant line is, take a quick look at the appendix. For a scalar-to-scalar function, I want every secant line between two points on the graph to lie on or above the graph itself. If that happens no matter which two points I pick, the function has the familiar bowl-shaped behavior I want.

The figure below is the geometry I have in mind.


That picture turns almost directly into algebra. Suppose I pick two inputs $x$ and $y$, and some $t \in [0,1]$.

  • The input that is a fraction $t$ of the way from $x$ to $y$ is $(1-t)x + ty$.

  • The height of the secant line at that same horizontal location is $(1-t)f(x) + tf(y)$.

So saying “the secant line lies above the graph” is the same as saying

$$f((1-t)x + ty) \le (1-t)f(x) + tf(y).$$

That is the formal definition in one dimension. In higher dimensions, nothing really changes: I just replace the scalar inputs $x$ and $y$ with vectors $\vec x$ and $\vec y$, and think about the line segment connecting them.

For $d = 1$, this is exactly the secant-line picture above. For $d > 1$, it says the same thing along every line segment in the domain.
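The one-dimensional inequality is easy to sandbox numerically. Here is a minimal sketch, using $f(x) = x^2$ (my choice of test function, not one from the text) and randomly sampled pairs of points:

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x):
    # A familiar convex function to test the definition against.
    return x ** 2

# Check f((1-t)x + ty) <= (1-t)f(x) + t f(y) at many random points.
for _ in range(1000):
    x, y = rng.uniform(-10, 10, size=2)
    t = rng.uniform(0, 1)
    lhs = f((1 - t) * x + t * y)        # graph height at the blended input
    rhs = (1 - t) * f(x) + t * f(y)     # secant-line height at the same input
    assert lhs <= rhs + 1e-12
print("secant-line inequality holds at all sampled points")
```

Replacing `f` with a non-convex function such as `np.sin` makes the assertion fail for some sampled pairs, which is a quick way to see the definition doing real work.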


This is one of those definitions that is much more useful than it first looks. I do not just want a fancy way to say “bowl-shaped.” I want an inequality that I can actually plug into proofs.

Once I know that every point on every line segment satisfies a weighted-average inequality, I can stop arguing from a picture and start choosing $\vec x$, $\vec y$, and $t$ strategically. That is what makes the next result work.


Local Minima are Global Minima

This is the payoff.

Suppose $f$ is convex, and suppose $\vec x^*$ is a local minimum of $f$. I want to show that $\vec x^*$ is automatically a global minimum.

Take any other point $\vec z$. Since $\vec x^*$ is a local minimum, points on the line segment from $\vec x^*$ toward $\vec z$ that stay close enough to $\vec x^*$ cannot have smaller function value. So for all sufficiently small $t > 0$,

$$f((1-t)\vec x^* + t \vec z) \ge f(\vec x^*).$$

But convexity also gives

$$f((1-t)\vec x^* + t \vec z) \le (1-t)f(\vec x^*) + t f(\vec z).$$

Putting those together,

$$f(\vec x^*) \le (1-t)f(\vec x^*) + t f(\vec z).$$

Subtract $(1-t)f(\vec x^*)$ from both sides:

$$t f(\vec x^*) \le t f(\vec z).$$

Since $t > 0$, I can divide by $t$ and get

$$f(\vec x^*) \le f(\vec z).$$

And since $\vec z$ was arbitrary, $\vec x^*$ does at least as well as every other point in the domain. So $\vec x^*$ is a global minimum.

This is why convexity matters so much in optimization. It does not magically make minimization easy, but it does remove the possibility of bad local minima.

There is one important caveat, though: convex functions do not need to have global minima in the first place. A standard example is $f(x) = e^x$. It is convex, and it keeps getting smaller as I move left, but it never actually attains its infimum of 0.


Strict Convexity

Convexity rules out bad local minima, but it does not rule out ties.

A convex function can have a completely flat bottom, in which case every point in that flat region is a local minimum and a global minimum. The example below does exactly that.


So if I want a guarantee of a single best point, I need to strengthen the definition a little.

A function $f$ is strictly convex if for all $\vec x \ne \vec y$ and all $t \in (0, 1)$,

$$f((1-t)\vec x + t \vec y) < (1-t)f(\vec x) + t f(\vec y).$$

The only difference is that the inequality is now strict once I move away from the endpoints. Geometrically, the secant line is allowed to touch the graph at the two endpoints, but not in between.

Now I can prove the uniqueness statement I really want. Suppose a strictly convex function has a global minimum, but that there are two different minimizers, $\vec x^*$ and $\vec y^*$. Let their common minimum value be $m$.

For any $t \in (0,1)$, strict convexity says

$$f((1-t)\vec x^* + t \vec y^*) < (1-t)f(\vec x^*) + t f(\vec y^*) = (1-t)m + tm = m.$$

But that says there is a point with function value smaller than the global minimum value $m$, which is impossible. So a strictly convex function can have at most one global minimum. In other words: if a global minimum exists, it is unique.

Strict Convexity and Mean Squared Error

This is exactly the distinction I want you to keep in mind for mean squared error.

Recall that

$$R_{\mathrm{sq}}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2.$$

When the columns of $X$ are linearly independent, this risk surface curves upward in every direction, so there is a single best parameter vector. When the columns of $X$ are linearly dependent, there are flat directions: I can move in some directions without changing the predictions, and the minimum need not be unique.

The contour plots below show both behaviors.


The left-hand plot is the nice case: one bowl, one bottom, one minimizer. The right-hand plot still describes a convex function, but not a strictly convex one, because there is a whole line of minimizers.

I am not proving the full criterion for $R_{\mathrm{sq}}$ here just yet. The clean explanation comes from the Hessian, and one chapter from now we will phrase that explanation using eigenvalues of $X^T X$. But the geometry is already visible: full rank gives curvature in every direction; linear dependence creates flat directions.
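The flat-valley case is easy to reproduce numerically. Here is a minimal sketch with a design matrix of my own construction, whose second column duplicates the first, so the columns are linearly dependent:

```python
import numpy as np

rng = np.random.default_rng(1)

# Design matrix with a duplicated column: the columns are linearly
# dependent, so the risk surface has a flat direction.
n = 50
x = rng.normal(size=n)
X = np.column_stack([x, x])            # second column repeats the first
y = 3 * x + rng.normal(scale=0.1, size=n)

def risk(w):
    # Mean squared error R_sq(w) = (1/n) * ||y - Xw||^2
    return np.mean((y - X @ w) ** 2)

# Any split of the total weight across the two identical columns gives
# the same predictions, hence the same risk: a whole line of minimizers.
w1 = np.array([3.0, 0.0])
w2 = np.array([0.0, 3.0])
w3 = np.array([1.5, 1.5])
print(risk(w1), risk(w2), risk(w3))    # all three values agree
```

Every $\vec w$ with $w_1 + w_2 = 3$ produces the identical prediction vector $3\vec x$, which is exactly the line of minimizers in the right-hand contour plot.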


Second Derivative Test

The formal definition is the ground truth, but it is not always the fastest way to check that a function is convex.

For scalar-to-scalar functions, you may already know the second derivative test from calculus: if a twice-differentiable function satisfies

$$\frac{d^2 f}{dx^2}(x) > 0$$

for all $x$ in its domain, then $f$ is convex. In fact, the slightly weaker condition $\frac{d^2 f}{dx^2}(x) \ge 0$ is enough for convexity, while $> 0$ points toward strict convexity.
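A quick sanity check of the scalar test, using $f(x) = x^4$ as my own example: its second derivative $12x^2$ is nonnegative everywhere but vanishes at $x = 0$, so it passes the weak test. It is in fact strictly convex, which shows that $f'' > 0$ everywhere is sufficient for strict convexity but not necessary.

```python
import numpy as np

def f(x):
    return x ** 4        # convex

def f2(x):
    return 12 * x ** 2   # second derivative, computed by hand; f2(0) = 0

# The weak second-derivative test certifies convexity.
xs = np.linspace(-5, 5, 101)
assert np.all(f2(xs) >= 0)

# x**4 is nevertheless strictly convex: the secant inequality is strict.
t, x, y = 0.5, -1.0, 1.0
assert f((1 - t) * x + t * y) < (1 - t) * f(x) + t * f(y)
print("x**4 passes the weak test and satisfies the strict secant inequality")
```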

What does this look like for vector-to-scalar functions? Now there is no single second derivative. There are many of them.

For example, if

$$f(x_1, x_2) = x_1^2 + x_1 x_2 + 2x_2^2,$$

then the first partial derivatives are

$$\frac{\partial f}{\partial x_1} = 2x_1 + x_2, \qquad \frac{\partial f}{\partial x_2} = x_1 + 4x_2,$$

and the second partial derivatives are

$$\frac{\partial^2 f}{\partial x_1^2} = 2, \qquad \frac{\partial^2 f}{\partial x_1 \partial x_2} = 1, \qquad \frac{\partial^2 f}{\partial x_2^2} = 4.$$

The natural thing to do is collect all of those second partial derivatives into a matrix: the Hessian $H_f(\vec x)$, whose $(i, j)$ entry is $\frac{\partial^2 f}{\partial x_i \partial x_j}$.

For the example above,

$$H_f(\vec x) = \begin{bmatrix} 2 & 1 \\ 1 & 4 \end{bmatrix},$$

which does not even depend on $\vec x$.
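One way to double-check a hand-computed Hessian is a central finite-difference estimate. A small sketch (the helper `hessian_fd` is my own, not from the text):

```python
import numpy as np

def f(x):
    # f(x1, x2) = x1^2 + x1 x2 + 2 x2^2, the example from the text
    return x[0] ** 2 + x[0] * x[1] + 2 * x[1] ** 2

def hessian_fd(f, x, h=1e-5):
    # Central finite-difference estimate of the Hessian of f at x.
    d = len(x)
    H = np.empty((d, d))
    for i in range(d):
        for j in range(d):
            ei = np.zeros(d); ei[i] = h
            ej = np.zeros(d); ej[j] = h
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * h ** 2)
    return H

H = hessian_fd(f, np.array([0.7, -1.3]))
print(np.round(H, 4))   # matches [[2, 1], [1, 4]] at any base point
```

Trying several base points confirms the claim in the text: for this quadratic, the Hessian is the same constant matrix everywhere.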

The vector-valued second derivative test says: a twice-differentiable function $f$ is convex exactly when its Hessian is positive semidefinite everywhere, meaning

$$\vec v^T H_f(\vec x) \vec v \ge 0 \qquad \text{for all } \vec x \in \mathbb{R}^d \text{ and all } \vec v \in \mathbb{R}^d.$$

If the Hessian is positive definite everywhere, then I get strict convexity.
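For a symmetric matrix, positive definiteness can be checked by confirming that every eigenvalue is positive, or by sampling the quadratic form directly. A small sketch for the example Hessian above:

```python
import numpy as np

H = np.array([[2.0, 1.0],
              [1.0, 4.0]])   # the constant Hessian from the example

# A symmetric matrix is positive definite iff all eigenvalues are > 0.
eigs = np.linalg.eigvalsh(H)
print(eigs)                  # both eigenvalues are positive: 3 -/+ sqrt(2)

# Equivalent spot check: v^T H v > 0 along random nonzero directions.
rng = np.random.default_rng(0)
for _ in range(1000):
    v = rng.normal(size=2)
    assert v @ H @ v > 0
```

So this Hessian is positive definite, and the quadratic $x_1^2 + x_1 x_2 + 2x_2^2$ is strictly convex.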

This brings us back to mean squared error. For

$$R_{\mathrm{sq}}(\vec w) = \frac{1}{n} \lVert \vec y - X \vec w \rVert^2,$$

we already computed

$$\nabla R_{\mathrm{sq}}(\vec w) = \frac{2}{n}\left(X^T X \vec w - X^T \vec y\right).$$

Differentiate once more, and the Hessian is

$$H_{R_{\mathrm{sq}}}(\vec w) = \frac{2}{n} X^T X.$$

That is a really nice conclusion: the Hessian is constant, and it is always positive semidefinite, since $\vec v^T X^T X \vec v = \lVert X \vec v \rVert^2 \ge 0$ for every $\vec v$. So mean squared error is always convex. If the columns of $X$ are linearly independent, then $X^T X$ is positive definite, which is why $R_{\mathrm{sq}}$ becomes strictly convex and has a unique minimizer. If the columns are dependent, there is a zero-curvature direction, which is exactly the flat valley we saw above.
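Both regimes can be checked directly by computing eigenvalues of $\frac{2}{n} X^T X$. A small sketch with randomly generated data of my own:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 40

# Independent columns: the Hessian (2/n) X^T X is positive definite.
X_ind = rng.normal(size=(n, 2))
H_ind = (2 / n) * X_ind.T @ X_ind
print(np.linalg.eigvalsh(H_ind))   # both eigenvalues positive for generic data

# Dependent columns (second column = twice the first): one eigenvalue
# is 0, the zero-curvature direction along the flat valley.
x = rng.normal(size=n)
X_dep = np.column_stack([x, 2 * x])
H_dep = (2 / n) * X_dep.T @ X_dep
print(np.linalg.eigvalsh(H_dep))   # smallest eigenvalue is (numerically) 0
```

The eigenvector attached to the zero eigenvalue of `H_dep` is the direction in which $\vec w$ can move without changing the predictions $X\vec w$.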


Aside: Tangent Hyperplanes

For scalar-to-scalar functions, there is another way to recognize convexity: a differentiable convex function always lies above each of its tangent lines.

For vector-to-scalar functions, the same story holds, except the tangent line becomes a tangent hyperplane. When the input has two coordinates, that hyperplane is literally just a plane in 3D.

The next figure shows a convex surface together with its tangent plane at one point.


Suppose I fix a point $\vec a$. The tangent hyperplane to $f$ at $\vec a$ is the linear approximation

$$L_{\vec a}(\vec x) = f(\vec a) + \nabla f(\vec a)^T (\vec x - \vec a).$$

This is the vector-valued version of the tangent-line formula from calculus.

In words: the graph of a differentiable convex function lies above every tangent hyperplane.

I like this characterization because it tells me that the local linear information in the gradient is globally trustworthy. On a non-convex function, a tangent line or tangent plane can point you in a misleading direction. On a convex function, the tangent hyperplane never overshoots the graph, which is a big part of why gradient-based optimization behaves so nicely in this setting.
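The "never overshoots" claim can be sampled numerically. A minimal sketch, using the convex function $f(\vec x) = \lVert \vec x \rVert^2$ and a base point of my own choosing:

```python
import numpy as np

rng = np.random.default_rng(3)

def f(x):
    # A convex function of two variables (my choice for illustration).
    return x @ x          # f(x) = ||x||^2

def grad_f(x):
    return 2 * x          # gradient of ||x||^2, computed by hand

a = np.array([1.0, -2.0])

def tangent(x):
    # L_a(x) = f(a) + grad f(a)^T (x - a)
    return f(a) + grad_f(a) @ (x - a)

# The graph should lie on or above the tangent hyperplane everywhere.
for _ in range(1000):
    x = rng.uniform(-10, 10, size=2)
    assert f(x) >= tangent(x) - 1e-9
print("f(x) >= L_a(x) at all sampled points")
```

For this particular $f$ the gap is exactly $\lVert \vec x - \vec a \rVert^2 \ge 0$, so the inequality holds with equality only at $\vec a$ itself.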