The whole point of Nesterov optimization is to calculate the gradient not at the current parameter values $\theta$, but at the look-ahead point $\theta + \beta m$, where $\beta$ is the momentum coefficient and $m$ is the momentum. The update steps are:
$$
\begin{aligned}
m &\gets \beta m - \eta \nabla L(\theta + \beta m) \\
\theta &\gets \theta + m
\end{aligned}
$$
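In plain Python, I picture one such step as follows (just a sketch; `grad_fn`, `lr`, and `beta` are my own names):

```python
def nesterov_step(theta, m, grad_fn, lr=0.01, beta=0.9):
    # The gradient is evaluated at the look-ahead point theta + beta*m,
    # not at theta itself.
    m = beta * m - lr * grad_fn(theta + beta * m)
    theta = theta + m
    return theta, m
```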
Looking at TensorFlow's docs, the update steps are:
velocity = momentum * velocity - learning_rate * g
w = w + momentum * velocity - learning_rate * g
Looking at PyTorch's docs, the update steps (with nesterov=True and no dampening or weight decay) are:
velocity = momentum * velocity + g
w = w - learning_rate * (g + momentum * velocity)
Correct me if I am wrong, but in both cases the gradient is not calculated at a "look ahead" parameter position (as Nesterov optimization necessitates). Are these implementations approximations of the original method?
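For reference, this is how I read the two rules as plain Python step functions (a sketch with my own names; dampening and weight decay are left out on the PyTorch side):

```python
def tf_style_step(w, velocity, grad_fn, lr=0.01, momentum=0.9):
    # TensorFlow-style rule from above: the gradient g is taken at the current w
    g = grad_fn(w)
    velocity = momentum * velocity - lr * g
    w = w + momentum * velocity - lr * g
    return w, velocity


def pytorch_style_step(w, buf, grad_fn, lr=0.01, momentum=0.9):
    # PyTorch-style rule (nesterov=True): again, the gradient g is taken at the current w
    g = grad_fn(w)
    buf = momentum * buf + g
    w = w - lr * (g + momentum * buf)
    return w, buf
```

In neither function is grad_fn ever evaluated at a shifted point such as w + momentum * velocity.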
1 Answer
Indeed, TensorFlow and PyTorch implement a variant of Nesterov momentum that approximates this behavior by adjusting the parameters using the momentum and the current gradient, without explicitly computing the gradient at the look-ahead position. For more detail, see the article "Is PyTorch’s Nesterov Momentum Implementation Wrong?", from which the following excerpts are taken:
Most notably, PyTorch’s implementation evaluates the gradient at the current parameters, whereas the whole point of Nesterov momentum is to evaluate the gradient at shifted parameters... Ultimately, we will see how PyTorch’s implementation is not wrong, but rather an approximation, and speculate about the benefit of their implementation.
The big remaining question is: Why does PyTorch bother at all to reformulate Nesterov momentum from equations 3 and 4 to equations 8 and 9? One possible explanation is that the reformulation might provide some savings in the number of arithmetic operations required.
The update rules implemented in the PyTorch SGD algorithm (8, 9) are an approximation to the update rules stated in the documentation note (3, 4) after a simple change of variables. Although the "actual" parameters are easily recoverable from the current parameters at each time step, the PyTorch implementation does not make any such correction at the end of the algorithm, and so the final parameters technically remain an approximation of the "actual" final parameters.
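To make that change of variables concrete, here is a small NumPy sketch (a toy example of my own, not from the article) that runs the textbook recursion from the question next to a PyTorch-style recursion on a simple quadratic loss, assuming a constant learning rate and no dampening or weight decay:

```python
import numpy as np

def grad(w):
    # Gradient of the toy loss L(w) = 0.5 * ||w||^2
    return w

beta, lr, steps = 0.9, 0.1, 100
w0 = np.array([1.0, -2.0])

# Textbook Nesterov: the gradient is taken at the look-ahead point theta + beta*m
theta, m = w0.copy(), np.zeros_like(w0)
for _ in range(steps):
    m = beta * m - lr * grad(theta + beta * m)
    theta = theta + m

# PyTorch-style recursion: the gradient is taken at the current iterate w
w, buf = w0.copy(), np.zeros_like(w0)
for _ in range(steps):
    g = grad(w)
    buf = beta * buf + g
    w = w - lr * (g + beta * buf)

# The PyTorch-style iterate coincides with the textbook look-ahead point ...
print(np.max(np.abs(w - (theta + beta * m))))        # ~0, up to floating-point error
# ... and the "actual" textbook parameters are recoverable at every step:
print(np.max(np.abs(w + lr * beta * buf - theta)))   # ~0, up to floating-point error
```

With a constant learning rate the correspondence is exact: the stored parameters are the look-ahead point $\theta + \beta m$, and the "actual" parameters can be recovered at any step via the correction shown in the last line. Since that correction is never applied at the end of training, the returned parameters are the look-ahead ones, which is the sense in which the final parameters remain an approximation.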