5.9 Stochastic Gradient Descent

Nearly all of deep learning is powered by one very important algorithm: stochastic gradient descent, or SGD. Stochastic gradient descent is an extension of the gradient descent algorithm introduced in Sec. 4.3.

A recurring problem in machine learning is that large training sets are necessary for good generalization, but large training sets are also more computationally expensive.

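The cost function used by a machine learning algorithm often decomposes as a sum over training examples of some per-example loss function. For example, the negative conditional log-likelihood of the training data can be written as

$$J(\theta) = \mathbb{E}_{x, y \sim \hat{p}_{\text{data}}} L(x, y, \theta) = \frac{1}{m} \sum_{i=1}^{m} L(x^{(i)}, y^{(i)}, \theta),$$

where $L$ is the per-example loss $L(x, y, \theta) = -\log p(y \mid x; \theta)$.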

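For these additive cost functions, gradient descent requires computing the gradient of the cost $J(\theta)$ as an average of per-example gradients:

$$\nabla_\theta J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \nabla_\theta L(x^{(i)}, y^{(i)}, \theta).$$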

The computational cost of this operation is O(m), where m is the number of training examples. As the training set size grows to billions of examples, the time to take a single gradient step becomes prohibitively long.
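As a minimal sketch of why this costs O(m), the Python below computes the exact gradient by visiting every training example once. The one-parameter linear model, the squared-error loss, and the names per_example_grad and full_gradient are illustrative assumptions, not something the text above prescribes.

```python
def per_example_grad(theta, x, y):
    """Gradient w.r.t. theta of the assumed per-example loss 0.5 * (theta * x - y)**2."""
    return (theta * x - y) * x


def full_gradient(theta, xs, ys):
    """Exact gradient of the average loss.

    The loop touches each of the m training examples once, so a single
    gradient evaluation costs O(m).
    """
    m = len(xs)
    total = 0.0
    for x, y in zip(xs, ys):
        total += per_example_grad(theta, x, y)
    return total / m


# Tiny hypothetical dataset generated by y = 2 * x.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]

grad_at_zero = full_gradient(0.0, xs, ys)     # negative: theta should increase
grad_at_optimum = full_gradient(2.0, xs, ys)  # zero at the true parameter
```

With three examples the loop is trivial. With billions, every single parameter update would require a full pass over the data.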

The insight of stochastic gradient descent is that the gradient is an expectation. The expectation may be approximately estimated using a small set of samples. Specifically, on each step of the algorithm, we can sample a minibatch of examples $\mathbb{B} = \{x^{(1)}, \ldots, x^{(m')}\}$ drawn uniformly from the training set. The minibatch size $m'$ (commonly called the batch size, batch_size in code) is typically chosen to be a relatively small number of examples, ranging from 1 to a few hundred. Crucially, $m'$ is usually held fixed as the training set size $m$ grows. We may fit a training set with billions of examples using updates computed on only a hundred examples.
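The claim that the gradient is an expectation can be spelled out in a short derivation. The notation is an assumption made for this sketch: $m$ is the training set size, $m'$ the minibatch size, $L$ the per-example loss, $J$ the average cost, and $i_1, \ldots, i_{m'}$ are indices drawn uniformly from $\{1, \ldots, m\}$.

```latex
% Assumed notation: m = training set size, m' = minibatch size,
% i_1, ..., i_{m'} = indices drawn uniformly from {1, ..., m}.
\begin{aligned}
\nabla_\theta J(\theta)
  &= \frac{1}{m} \sum_{i=1}^{m} \nabla_\theta L(x^{(i)}, y^{(i)}, \theta)
   = \mathbb{E}_{i \sim \mathrm{Uniform}\{1, \dots, m\}}
     \left[ \nabla_\theta L(x^{(i)}, y^{(i)}, \theta) \right], \\
\mathbb{E}\left[ \frac{1}{m'} \sum_{j=1}^{m'}
     \nabla_\theta L(x^{(i_j)}, y^{(i_j)}, \theta) \right]
  &= \frac{1}{m'} \sum_{j=1}^{m'}
     \mathbb{E}\left[ \nabla_\theta L(x^{(i_j)}, y^{(i_j)}, \theta) \right]
   = \nabla_\theta J(\theta).
\end{aligned}
```

The first line rewrites the average over the training set as an expectation over a uniformly chosen index. The second uses linearity of expectation to show that the minibatch average has the full gradient as its mean, whatever the value of $m'$.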

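The estimate of the gradient is formed as

$$g = \frac{1}{m'} \nabla_\theta \sum_{i=1}^{m'} L(x^{(i)}, y^{(i)}, \theta),$$

using the $m'$ examples in the minibatch $\mathbb{B}$.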
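The minibatch estimate can be sketched in a few lines of Python. This is an illustration under stated assumptions rather than a reference implementation: the one-parameter linear model, the squared-error loss, the synthetic data, and the names minibatch_gradient, sgd, batch_size, and learning_rate are all hypothetical. The update theta <- theta - learning_rate * g is the usual SGD step of moving against the estimated gradient, and the step size is chosen arbitrarily.

```python
import random


def per_example_grad(theta, x, y):
    """Gradient w.r.t. theta of the assumed per-example loss 0.5 * (theta * x - y)**2."""
    return (theta * x - y) * x


def minibatch_gradient(theta, xs, ys, batch_size, rng):
    """Estimate the gradient from batch_size examples sampled uniformly.

    Indices are drawn without replacement within one minibatch; the cost
    depends on batch_size, not on the training set size len(xs).
    """
    indices = rng.sample(range(len(xs)), batch_size)
    total = sum(per_example_grad(theta, xs[i], ys[i]) for i in indices)
    return total / batch_size


def sgd(xs, ys, batch_size=8, learning_rate=0.1, steps=500, seed=0):
    """Run a fixed number of SGD steps starting from theta = 0."""
    rng = random.Random(seed)
    theta = 0.0
    for _ in range(steps):
        g = minibatch_gradient(theta, xs, ys, batch_size, rng)
        theta -= learning_rate * g  # step against the estimated gradient
    return theta


# Synthetic, noise-free data generated by y = 3 * x (hypothetical example).
data_rng = random.Random(1)
xs = [data_rng.uniform(-1.0, 1.0) for _ in range(1000)]
ys = [3.0 * x for x in xs]

theta_hat = sgd(xs, ys)  # should end up close to the true slope 3.0
```

Each step reads only 8 of the 1000 examples, and the per-step cost would stay the same if the training set grew much larger.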

Gradient descent in general has often been regarded as slow or unreliable. In the past, the application of gradient descent to non-convex optimization problems was regarded as foolhardy or unprincipled. Today, we know that the machine learning models described in Part II work very well when trained with gradient descent. The optimization algorithm may not be guaranteed to arrive at even a local minimum in a reasonable amount of time, but it often finds a very low value of the cost function quickly enough to be useful.