Autograd in C++ Frontend
Created On: Apr 01, 2020 | Last Updated: Jan 21, 2025 | Last Verified: Not Verified
The autograd package is crucial for building highly flexible and dynamic neural
networks in PyTorch. Most of the autograd APIs in the PyTorch Python frontend are also available
in the C++ frontend, allowing easy translation of autograd code from Python to C++.
In this tutorial, we explore several examples of doing autograd in the PyTorch C++ frontend. Note that this tutorial assumes that you already have a basic understanding of autograd in the Python frontend. If that’s not the case, please first read Autograd: Automatic Differentiation.
Basic autograd operations
(Adapted from this tutorial)
Create a tensor and set torch::requires_grad() to track computation with it
```cpp
auto x = torch::ones({2, 2}, torch::requires_grad());
std::cout << x << std::endl;
```
Out:
```
 1  1
 1  1
[ CPUFloatType{2,2} ]
```
Do a tensor operation:
```cpp
auto y = x + 2;
std::cout << y << std::endl;
```
Out:
```
 3  3
 3  3
[ CPUFloatType{2,2} ]
```
y was created as a result of an operation, so it has a grad_fn.
```cpp
std::cout << y.grad_fn()->name() << std::endl;
```
Out:
```
AddBackward1
```
Do more operations on y
```cpp
auto z = y * y * 3;
auto out = z.mean();

std::cout << z << std::endl;
std::cout << z.grad_fn()->name() << std::endl;
std::cout << out << std::endl;
std::cout << out.grad_fn()->name() << std::endl;
```
Out:
```
 27  27
 27  27
[ CPUFloatType{2,2} ]
MulBackward1
27
[ CPUFloatType{} ]
MeanBackward0
```
.requires_grad_( ... ) changes an existing tensor’s requires_grad flag in-place.
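```cpp
auto a = torch::randn({2, 2});
a = ((a * 3) / (a - 1));
std::cout << std::boolalpha;  // print bools as true/false instead of 1/0
std::cout << a.requires_grad() << std::endl;

a.requires_grad_(true);
std::cout << a.requires_grad() << std::endl;

auto b = (a * a).sum();
std::cout << b.grad_fn()->name() << std::endl;
```
Out:
```
false
true
SumBackward0
```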
Let’s backprop now. Because out contains a single scalar, out.backward()
is equivalent to out.backward(torch::tensor(1.)).
```cpp
out.backward();
```
Print gradients d(out)/dx
```cpp
std::cout << x.grad() << std::endl;
```
Out:
```
 4.5000  4.5000
 4.5000  4.5000
[ CPUFloatType{2,2} ]
```
You should have got a matrix of 4.5. For explanations on how we arrive at this value,
please see the corresponding section in this tutorial.
Now let’s take a look at an example of vector-Jacobian product:
```cpp
x = torch::randn(3, torch::requires_grad());

y = x * 2;
while (y.norm().item<double>() < 1000) {
  y = y * 2;
}

std::cout << y << std::endl;
std::cout << y.grad_fn()->name() << std::endl;
```
Out:
```
-1021.4020
  314.6695
 -613.4944
[ CPUFloatType{3} ]
MulBackward1
```
If we want the vector-Jacobian product, pass the vector to backward as an argument:
```cpp
auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat);
y.backward(v);

std::cout << x.grad() << std::endl;
```
Out:
```
  102.4000
 1024.0000
    0.1024
[ CPUFloatType{3} ]
```
You can also stop autograd from tracking history on tensors that require gradients,
either by putting torch::NoGradGuard in a code block:
```cpp
std::cout << std::boolalpha;  // print bools as true/false instead of 1/0
std::cout << x.requires_grad() << std::endl;
std::cout << x.pow(2).requires_grad() << std::endl;

{
  torch::NoGradGuard no_grad;
  std::cout << x.pow(2).requires_grad() << std::endl;
}
```
Out:
```
true
true
false
```
Or by using .detach() to get a new tensor with the same content but that does
not require gradients:
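```cpp
std::cout << std::boolalpha;  // print bools as true/false instead of 1/0
std::cout << x.requires_grad() << std::endl;
y = x.detach();
std::cout << y.requires_grad() << std::endl;
std::cout << x.eq(y).all().item<bool>() << std::endl;
```
Out:
```
true
false
true
```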
For more information on C++ tensor autograd APIs such as grad / requires_grad /
is_leaf / backward / detach / detach_ / register_hook / retain_grad,
please see the corresponding C++ API docs.
Computing higher-order gradients in C++
One of the applications of higher-order gradients is calculating a gradient penalty.
Let’s see an example of it using torch::autograd::grad:
```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto model = torch::nn::Linear(4, 3);

  auto input = torch::randn({3, 4}).requires_grad_(true);
  auto output = model(input);

  // Calculate loss
  auto target = torch::randn({3, 3});
  auto loss = torch::nn::MSELoss()(output, target);

  // Use norm of gradients as penalty
  auto grad_output = torch::ones_like(output);
  // In the C++ signature, retain_graph comes before create_graph
  auto gradient = torch::autograd::grad({output}, {input}, /*grad_outputs=*/{grad_output},
                                        /*retain_graph=*/true, /*create_graph=*/true)[0];
  auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean();

  // Add gradient penalty to loss
  auto combined_loss = loss + gradient_penalty;
  combined_loss.backward();

  std::cout << input.grad() << std::endl;
}
```
Out:
```
-0.1042 -0.0638  0.0103  0.0723
-0.2543 -0.1222  0.0071  0.0814
-0.1683 -0.1052  0.0355  0.1024
[ CPUFloatType{3,4} ]
```
Please see the documentation for torch::autograd::backward
and torch::autograd::grad
for more information on how to use them.
Using custom autograd function in C++
(Adapted from this tutorial)
Adding a new elementary operation to torch::autograd requires implementing a new torch::autograd::Function
subclass for each operation. torch::autograd::Functions are what torch::autograd
uses to compute the results and gradients, and to encode the operation history. Every
new function requires you to implement two methods, forward and backward; please see this link
for the detailed requirements.
Below you can find code for a Linear function from torch::nn:
```cpp
#include <torch/torch.h>

using namespace torch::autograd;

// Inherit from Function
class LinearFunction : public Function<LinearFunction> {
 public:
  // Note that both forward and backward are static functions

  // bias is an optional argument
  static torch::Tensor forward(
      AutogradContext *ctx, torch::Tensor input, torch::Tensor weight, torch::Tensor bias = torch::Tensor()) {
    ctx->save_for_backward({input, weight, bias});
    auto output = input.mm(weight.t());
    if (bias.defined()) {
      output += bias.unsqueeze(0).expand_as(output);
    }
    return output;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto bias = saved[2];

    auto grad_output = grad_outputs[0];
    auto grad_input = grad_output.mm(weight);
    auto grad_weight = grad_output.t().mm(input);
    auto grad_bias = torch::Tensor();
    if (bias.defined()) {
      grad_bias = grad_output.sum(0);
    }

    return {grad_input, grad_weight, grad_bias};
  }
};
```
Then, we can use the LinearFunction in the following way:
```cpp
auto x = torch::randn({2, 3}).requires_grad_();
auto weight = torch::randn({4, 3}).requires_grad_();
auto y = LinearFunction::apply(x, weight);
y.sum().backward();

std::cout << x.grad() << std::endl;
std::cout << weight.grad() << std::endl;
```
Out:
```
 0.5314  1.2807  1.4864
 0.5314  1.2807  1.4864
[ CPUFloatType{2,3} ]
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
[ CPUFloatType{4,3} ]
```
Here, we give an additional example of a function that is parametrized by non-tensor arguments:
```cpp
#include <torch/torch.h>

using namespace torch::autograd;

class MulConstant : public Function<MulConstant> {
 public:
  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor tensor, double constant) {
    // ctx is a context object that can be used to stash information
    // for backward computation
    ctx->saved_data["constant"] = constant;
    return tensor * constant;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    // We return as many input gradients as there were arguments.
    // Gradients of non-tensor arguments to forward must be `torch::Tensor()`.
    return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), torch::Tensor()};
  }
};
```
Then, we can use the MulConstant in the following way:
```cpp
auto x = torch::randn({2}).requires_grad_();
auto y = MulConstant::apply(x, 5.5);
y.sum().backward();

std::cout << x.grad() << std::endl;
```
Out:
```
 5.5000
 5.5000
[ CPUFloatType{2} ]
```
For more information on torch::autograd::Function, please see
its documentation.
Translating autograd code from Python to C++
On a high level, the easiest way to use autograd in C++ is to have working autograd code in Python first, and then translate your autograd code from Python to C++ using the following table:
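| Python | C++ |
|---|---|
| torch.autograd.backward | torch::autograd::backward |
| torch.autograd.grad | torch::autograd::grad |
| torch.Tensor.detach | torch::Tensor::detach |
| torch.Tensor.detach_ | torch::Tensor::detach_ |
| torch.Tensor.backward | torch::Tensor::backward |
| torch.Tensor.register_hook | torch::Tensor::register_hook |
| torch.Tensor.requires_grad | torch::Tensor::requires_grad (read), torch::Tensor::requires_grad_ (set) |
| torch.Tensor.retain_grad | torch::Tensor::retain_grad |
| torch.Tensor.grad | torch::Tensor::grad |
| torch.Tensor.grad_fn | torch::Tensor::grad_fn |
| torch.Tensor.data | torch::Tensor::data |
| torch.Tensor.output_nr | torch::Tensor::output_nr |
| torch.Tensor.is_leaf | torch::Tensor::is_leaf |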
After translation, most of your Python autograd code should just work in C++. If that’s not the case, please file a bug report at GitHub issues and we will fix it as soon as possible.
Conclusion
You should now have a good overview of PyTorch’s C++ autograd API. You can find the code examples displayed in this note here. As always, if you run into any problems or have questions, you can use our forum or GitHub issues to get in touch.