(beta) Building a Simple CPU Performance Profiler with FX#
Created On: Mar 04, 2021 | Last Updated: Jul 14, 2025 | Last Verified: Not Verified
Author: James Reed
In this tutorial, we are going to use FX to do the following:
1. Capture PyTorch Python code in a way that we can inspect and gather statistics about the structure and execution of the code.
2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.
import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model take the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the difference between those timestamps to see how long the regions between the timestamps take.
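For instance, a minimal sketch of that manual approach, applied by hand to just the first two layers of rn18 (reusing the rn18 and input defined above), might look like this:

import time

# Manually instrument the first two stages of the model: record a
# wall-clock timestamp around each region of interest and compare
# the differences to see how long each region took.
t0 = time.time()
x = rn18.conv1(input)
t1 = time.time()
x = rn18.bn1(x)
t2 = time.time()
print(f"conv1: {t1 - t0:.6f}s, bn1: {t2 - t1:.6f}s")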
That technique is certainly applicable to PyTorch code; however, it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.
First, let’s get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
Note

tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source (for example, pip install tabulate).
Capturing the Model with Symbolic Tracing#
Next, we are going to use FX’s symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.org/docs/master/fx.html.
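As a quick illustration of this structure, here is a small sketch that walks the first few nodes of traced_rn18's graph and prints each node's opcode, target, and arguments:

# Each ``Node`` records an opcode (``op``), a callee (``target``), and
# the values flowing into it (``args``); the graph's edges are exactly
# these ``args``/``kwargs`` references to other nodes.
for node in list(traced_rn18.graph.nodes)[:5]:
    print(node.op, node.target, node.args)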
Creating a Profiling Interpreter#
Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when you call a GraphModule, an alternative way to run a GraphModule is by executing each Node in the Graph one by one. That is the functionality that Interpreter provides: it interprets the graph node-by-node.
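To make that equivalence concrete, here is a small sketch (reusing traced_rn18 and the Interpreter import from above; the sample input is our own addition) showing that a plain Interpreter run matches calling the GraphModule directly:

# Run the same input through the GraphModule directly and through a
# vanilla Interpreter; with the model in eval mode, the node-by-node
# execution produces the same result as the compiled forward.
sample = torch.randn(1, 3, 224, 224)
assert torch.allclose(traced_rn18(sample), Interpreter(traced_rn18).run(sample))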
By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each of its parts took during those runs.

Let's define our ProfilingInterpreter class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.
    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.
    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.
    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
Note

We use Python's time.time function to pull wall clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.
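If you want a somewhat better timer, Python's time.perf_counter (a monotonic, higher-resolution clock) is a reasonable drop-in replacement for time.time here. A minimal sketch of the substitution, written as a variant of ProfilingInterpreter.run_node (our variation, not part of the tutorial's class above):

    # Inside ``ProfilingInterpreter``: same logic as ``run_node`` above,
    # but with ``time.perf_counter()``, which is monotonic and better
    # suited to timing short intervals than ``time.time()``.
    def run_node(self, n : torch.fx.Node) -> Any:
        t_start = time.perf_counter()
        return_val = super().run_node(n)
        t_end = time.perf_counter()
        self.runtimes_sec.setdefault(n, []).append(t_end - t_start)
        return return_val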
Investigating the Performance of ResNet18#
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                     Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                0.00458145             8.15724
call_module    conv1                  0.00455642             8.11266
call_module    layer4_0_conv2         0.00312781             5.56905
call_module    layer1_0_conv1         0.00304317             5.41835
call_module    layer4_1_conv1         0.00294089             5.23624
call_module    layer4_1_conv2         0.00291324             5.18699
call_module    layer1_0_conv2         0.00278592             4.96031
call_module    layer1_1_conv2         0.00265574             4.72853
call_module    layer1_1_conv1         0.0025475              4.53581
call_module    layer2_0_conv2         0.00242448             4.31676
call_module    layer2_1_conv2         0.00233269             4.15333
call_module    layer3_1_conv2         0.00219154             3.90202
call_module    layer2_1_conv1         0.00214982             3.82774
call_module    layer3_0_conv2         0.00213575             3.80269
call_module    layer3_1_conv1         0.00207639             3.69699
call_module    layer4_0_conv1         0.00192642             3.42998
call_module    bn1                    0.00134301             2.39122
call_module    layer3_0_conv1         0.00125623             2.2367
call_module    layer2_0_conv1         0.00122809             2.18661
call_module    layer2_0_downsample_0  0.000622749            1.1088
call_module    layer3_0_downsample_0  0.000458956            0.817167
call_module    layer4_0_downsample_0  0.000439644            0.782782
call_function  add                    0.000433683            0.77217
call_function  add_1                  0.000389099            0.692788
call_module    layer1_0_bn1           0.000309706            0.551428
call_module    layer1_1_bn2           0.00030756             0.547608
call_module    layer1_0_bn2           0.000283003            0.503884
call_module    relu                   0.000269175            0.479263
call_function  add_3                  0.000226974            0.404126
call_module    fc                     0.000194788            0.346818
call_module    layer2_1_bn2           0.00017643             0.314132
call_module    layer2_0_bn1           0.000175953            0.313283
call_module    layer1_1_bn1           0.000154257            0.274653
call_module    layer1_0_relu          0.000147343            0.262342
call_module    layer2_0_downsample_1  0.000122786            0.218619
call_module    avgpool                0.000118732            0.211402
call_module    layer3_1_bn1           0.000112057            0.199516
call_module    layer3_1_bn2           0.000111341            0.198243
call_module    layer3_0_bn2           9.05991e-05            0.161311
call_module    layer1_0_relu_1        8.70228e-05            0.154943
call_module    layer4_0_bn2           8.63075e-05            0.15367
call_module    layer2_0_bn2           8.58307e-05            0.152821
call_module    layer4_1_bn2           8.55923e-05            0.152396
call_module    layer4_1_bn1           8.27312e-05            0.147302
call_module    layer1_1_relu_1        8.13007e-05            0.144755
call_module    layer4_0_downsample_1  8.05855e-05            0.143482
call_module    layer2_1_bn1           7.98702e-05            0.142208
call_function  add_2                  7.72476e-05            0.137539
call_function  add_5                  7.4625e-05             0.132869
call_module    layer1_1_relu          7.24792e-05            0.129049
call_module    layer4_0_bn1           7.24792e-05            0.129049
call_module    layer3_0_downsample_1  7.10487e-05            0.126502
call_module    layer3_0_bn1           6.62804e-05            0.118012
call_function  add_7                  6.46114e-05            0.11504
call_function  add_6                  5.79357e-05            0.103154
call_function  add_4                  5.62668e-05            0.100183
call_module    layer4_1_relu          5.17368e-05            0.092117
call_module    layer4_0_relu          4.8399e-05             0.086174
call_module    layer2_0_relu_1        4.673e-05              0.0832024
call_module    layer2_1_relu_1        4.62532e-05            0.0823534
call_module    layer2_0_relu          4.29153e-05            0.0764104
call_module    layer4_0_relu_1        4.24385e-05            0.0755614
call_module    layer2_1_relu          4.1008e-05             0.0730144
call_module    layer4_1_relu_1        3.95775e-05            0.0704674
call_module    layer3_1_relu          3.62396e-05            0.0645243
call_module    layer3_0_relu          3.60012e-05            0.0640998
call_module    layer3_0_relu_1        3.55244e-05            0.0632508
call_module    layer3_1_relu_1        3.38554e-05            0.0602793
call_function  flatten                2.59876e-05            0.0462707
placeholder    x                      2.28882e-05            0.0407522
output         output                 9.29832e-06            0.0165556
There are two things we should call out here:

1. MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
2. BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Building a Convolution/Batch Norm Fuser in FX tutorial.
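If you would rather extract these numbers programmatically than read them off the table, here is a small sketch (reusing interp and the statistics import from earlier) that finds the node with the largest mean runtime:

# Find the single slowest node, on average, across all recorded runs.
slowest_node, times = max(
    interp.runtimes_sec.items(), key=lambda kv: statistics.mean(kv[1]))
print(f"Slowest node: {slowest_node} ({statistics.mean(times):.6f} s average)")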
Conclusion#
As we can see, using FX we can easily capture PyTorch programs (even ones we don’t have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we’ve done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.
Total running time of the script: (0 minutes 0.313 seconds)