Skip to content

Commit 2460e15

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
New URL for developer tools tutorial (#5384)
Summary: Pull Request resolved: #5384 ***I want to put the new URL in the PTC presentation slides.*** Old URL: https://pytorch.org/executorch/main/tutorials/sdk-integration-tutorial.html New URL (replaced "sdk" with "developer-tools"): https://pytorch.org/executorch/main/tutorials/developer-tools-integration-tutorial.html Reviewed By: dbort Differential Revision: D62727902 fbshipit-source-id: ecc13886f8649e5cede6dae3a00b05619aa456b1
1 parent a9ffb3a commit 2460e15

8 files changed

+311
-295
lines changed

docs/source/devtools-tutorial.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Developer Tools Usage Tutorial
2+
3+
Please refer to the [Developer Tools tutorial](./tutorials/devtools-integration-tutorial) for a walkthrough on how to profile a model in ExecuTorch using the Developer Tools.

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ Topics in this section will help you get started with ExecuTorch.
9494
tutorials/export-to-executorch-tutorial
9595
running-a-model-cpp-tutorial
9696
extension-module
97-
tutorials/sdk-integration-tutorial
97+
tutorials/devtools-integration-tutorial
9898
apple-runtime
9999
demo-apps-ios
100100
demo-apps-android
@@ -204,7 +204,7 @@ Topics in this section will help you get started with ExecuTorch.
204204
sdk-debugging
205205
sdk-inspector
206206
sdk-delegate-integration
207-
sdk-tutorial
207+
devtools-tutorial
208208

209209
.. toctree::
210210
:glob:
@@ -247,7 +247,7 @@ ExecuTorch tutorials.
247247
:header: Using the ExecuTorch Developer Tools to Profile a Model
248248
:card_description: A tutorial for using the ExecuTorch Developer Tools to profile and analyze a model with linkage back to source code.
249249
:image: _static/img/generic-pytorch-logo.png
250-
:link: tutorials/sdk-integration-tutorial.html
250+
:link: tutorials/devtools-integration-tutorial.html
251251
:tags: devtools
252252

253253
.. customcarditem::

docs/source/native-delegates-executorch-xnnpack-delegate.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ Since weight packing creates an extra copy of the weights inside XNNPACK, We fre
7474
When executing the XNNPACK subgraphs, we prepare the tensor inputs and outputs and feed them to the XNNPACK runtime graph. After executing the runtime graph, the output pointers are filled with the computed tensors.
7575

7676
#### **Profiling**
77-
We have enabled basic profiling for XNNPACK delegate that can be enabled with the following compiler flag `-DENABLE_XNNPACK_PROFILING`. With ExecuTorch's Developer Tools integration, you can also now use the Developer Tools to profile the model. You can follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](./tutorials/sdk-integration-tutorial) on how to profile ExecuTorch models and use Developer Tools' Inspector API to view XNNPACK's internal profiling information.
77+
We have enabled basic profiling for XNNPACK delegate that can be enabled with the following compiler flag `-DENABLE_XNNPACK_PROFILING`. With ExecuTorch's Developer Tools integration, you can also now use the Developer Tools to profile the model. You can follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](./tutorials/devtools-integration-tutorial) on how to profile ExecuTorch models and use Developer Tools' Inspector API to view XNNPACK's internal profiling information.
7878

7979

8080
[comment]: <> (TODO: Refactor quantizer to a more official quantization doc)

docs/source/sdk-inspector.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ APIs:
1717
* By accessing the `public attributes <#inspector-attributes>`__ of the ``Inspector``, ``EventBlock``, and ``Event`` classes.
1818
* By using a `CLI <#cli>`__ tool for basic functionalities.
1919

20-
Please refer to the `e2e use case doc <tutorials/sdk-integration-tutorial.html>`__ get an understanding of how to use these in a real world example.
20+
Please refer to the `e2e use case doc <tutorials/devtools-integration-tutorial.html>`__ get an understanding of how to use these in a real world example.
2121

2222

2323
Inspector Methods

docs/source/sdk-profiling.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ We provide access to all the profiling data via the Python [Inspector API](./sdk
2020
- Through the Inspector API, users can do a wide range of analysis varying from printing out performance details to doing more finer granular calculation on module level.
2121

2222

23-
Please refer to the [Developer Tools tutorial](./tutorials/sdk-integration-tutorial.rst) for a step-by-step walkthrough of the above process on a sample model.
23+
Please refer to the [Developer Tools tutorial](./tutorials/devtools-integration-tutorial.rst) for a step-by-step walkthrough of the above process on a sample model.

docs/source/sdk-tutorial.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
## Developer Tools Usage Tutorial
22

3-
Please refer to the [Developer Tools tutorial](./tutorials/sdk-integration-tutorial) for a walkthrough on how to profile a model in ExecuTorch using the Developer Tools.
3+
Please update your link to <https://pytorch.org/executorch/main/devtools-tutorial.html>. This URL will be deleted after v0.4.0.
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
Using the ExecuTorch Developer Tools to Profile a Model
10+
========================
11+
12+
**Author:** `Jack Khuu <https://github.com/Jack-Khuu>`__
13+
"""
14+
15+
######################################################################
16+
# The `ExecuTorch Developer Tools <../sdk-overview.html>`__ is a set of tools designed to
17+
# provide users with the ability to profile, debug, and visualize ExecuTorch
18+
# models.
19+
#
20+
# This tutorial will show a full end-to-end flow of how to utilize the Developer Tools to profile a model.
21+
# Specifically, it will:
22+
#
23+
# 1. Generate the artifacts consumed by the Developer Tools (`ETRecord <../sdk-etrecord.html>`__, `ETDump <../sdk-etdump.html>`__).
24+
# 2. Create an Inspector class consuming these artifacts.
25+
# 3. Utilize the Inspector class to analyze the model profiling result.
26+
27+
######################################################################
28+
# Prerequisites
29+
# -------------
30+
#
31+
# To run this tutorial, you’ll first need to
32+
# `Set up your ExecuTorch environment <../getting-started-setup.html>`__.
33+
#
34+
35+
######################################################################
36+
# Generate ETRecord (Optional)
37+
# ----------------------------
38+
#
39+
# The first step is to generate an ``ETRecord``. ``ETRecord`` contains model
40+
# graphs and metadata for linking runtime results (such as profiling) to
41+
# the eager model. This is generated via ``executorch.devtools.generate_etrecord``.
42+
#
43+
# ``executorch.devtools.generate_etrecord`` takes in an output file path (str), the
44+
# edge dialect model (``EdgeProgramManager``), the ExecuTorch dialect model
45+
# (``ExecutorchProgramManager``), and an optional dictionary of additional models.
46+
#
47+
# In this tutorial, an example model (shown below) is used to demonstrate.
48+
49+
import copy
50+
51+
import torch
52+
import torch.nn as nn
53+
import torch.nn.functional as F
54+
from executorch.devtools import generate_etrecord
55+
56+
from executorch.exir import (
57+
EdgeCompileConfig,
58+
EdgeProgramManager,
59+
ExecutorchProgramManager,
60+
to_edge,
61+
)
62+
from torch.export import export, ExportedProgram
63+
64+
65+
# Generate Model
66+
class Net(nn.Module):
67+
def __init__(self):
68+
super(Net, self).__init__()
69+
# 1 input image channel, 6 output channels, 5x5 square convolution
70+
# kernel
71+
self.conv1 = nn.Conv2d(1, 6, 5)
72+
self.conv2 = nn.Conv2d(6, 16, 5)
73+
# an affine operation: y = Wx + b
74+
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
75+
self.fc2 = nn.Linear(120, 84)
76+
self.fc3 = nn.Linear(84, 10)
77+
78+
def forward(self, x):
79+
# Max pooling over a (2, 2) window
80+
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
81+
# If the size is a square, you can specify with a single number
82+
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
83+
x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
84+
x = F.relu(self.fc1(x))
85+
x = F.relu(self.fc2(x))
86+
x = self.fc3(x)
87+
return x
88+
89+
90+
model = Net()
91+
92+
aten_model: ExportedProgram = export(
93+
model,
94+
(torch.randn(1, 1, 32, 32),),
95+
)
96+
97+
edge_program_manager: EdgeProgramManager = to_edge(
98+
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
99+
)
100+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
101+
et_program_manager: ExecutorchProgramManager = edge_program_manager.to_executorch()
102+
103+
104+
# Generate ETRecord
105+
etrecord_path = "etrecord.bin"
106+
generate_etrecord(etrecord_path, edge_program_manager_copy, et_program_manager)
107+
108+
# sphinx_gallery_start_ignore
109+
from unittest.mock import patch
110+
111+
# sphinx_gallery_end_ignore
112+
113+
######################################################################
114+
#
115+
# .. warning::
116+
# Users should do a deepcopy of the output of ``to_edge()`` and pass in the
117+
# deepcopy to the ``generate_etrecord`` API. This is needed because the
118+
# subsequent call, ``to_executorch()``, does an in-place mutation and will
119+
# lose debug data in the process.
120+
#
121+
122+
######################################################################
123+
# Generate ETDump
124+
# ---------------
125+
#
126+
# Next step is to generate an ``ETDump``. ``ETDump`` contains runtime results
127+
# from executing a `Bundled Program Model <../sdk-bundled-io.html>`__.
128+
#
129+
# In this tutorial, a `Bundled Program` is created from the example model above.
130+
131+
import torch
132+
from executorch.devtools import BundledProgram
133+
134+
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
135+
from executorch.devtools.bundled_program.serialize import (
136+
serialize_from_bundled_program_to_flatbuffer,
137+
)
138+
139+
from executorch.exir import to_edge
140+
from torch.export import export
141+
142+
# Step 1: ExecuTorch Program Export
143+
m_name = "forward"
144+
method_graphs = {m_name: export(model, (torch.randn(1, 1, 32, 32),))}
145+
146+
# Step 2: Construct Method Test Suites
147+
inputs = [[torch.randn(1, 1, 32, 32)] for _ in range(2)]
148+
149+
method_test_suites = [
150+
MethodTestSuite(
151+
method_name=m_name,
152+
test_cases=[
153+
MethodTestCase(inputs=inp, expected_outputs=getattr(model, m_name)(*inp))
154+
for inp in inputs
155+
],
156+
)
157+
]
158+
159+
# Step 3: Generate BundledProgram
160+
executorch_program = to_edge(method_graphs).to_executorch()
161+
bundled_program = BundledProgram(executorch_program, method_test_suites)
162+
163+
# Step 4: Serialize BundledProgram to flatbuffer.
164+
serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(
165+
bundled_program
166+
)
167+
save_path = "bundled_program.bp"
168+
with open(save_path, "wb") as f:
169+
f.write(serialized_bundled_program)
170+
171+
######################################################################
172+
# Use CMake (follow `these instructions <../runtime-build-and-cross-compilation.html#configure-the-cmake-build>`__ to set up cmake) to execute the Bundled Program to generate the ``ETDump``::
173+
#
174+
# cd executorch
175+
# ./examples/devtools/build_example_runner.sh
176+
# cmake-out/examples/devtools/example_runner --bundled_program_path="bundled_program.bp"
177+
178+
######################################################################
179+
# Creating an Inspector
180+
# ---------------------
181+
#
182+
# Final step is to create the ``Inspector`` by passing in the artifact paths.
183+
# Inspector takes the runtime results from ``ETDump`` and correlates them to
184+
# the operators of the Edge Dialect Graph.
185+
#
186+
# Recall: An ``ETRecord`` is not required. If an ``ETRecord`` is not provided,
187+
# the Inspector will show runtime results without operator correlation.
188+
#
189+
# To visualize all runtime events, call Inspector's ``print_data_tabular``.
190+
191+
from executorch.devtools import Inspector
192+
193+
# sphinx_gallery_start_ignore
194+
inspector_patch = patch.object(Inspector, "__init__", return_value=None)
195+
inspector_patch_print = patch.object(Inspector, "print_data_tabular", return_value="")
196+
inspector_patch.start()
197+
inspector_patch_print.start()
198+
# sphinx_gallery_end_ignore
199+
etdump_path = "etdump.etdp"
200+
inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
201+
# sphinx_gallery_start_ignore
202+
inspector.event_blocks = []
203+
# sphinx_gallery_end_ignore
204+
inspector.print_data_tabular()
205+
206+
# sphinx_gallery_start_ignore
207+
inspector_patch.stop()
208+
inspector_patch_print.stop()
209+
# sphinx_gallery_end_ignore
210+
211+
######################################################################
212+
# Analyzing with an Inspector
213+
# ---------------------------
214+
#
215+
# ``Inspector`` provides 2 ways of accessing ingested information: `EventBlocks <../sdk-inspector#eventblock-class>`__
216+
# and ``DataFrames``. These mediums give users the ability to perform custom
217+
# analysis about their model performance.
218+
#
219+
# Below are examples usages, with both ``EventBlock`` and ``DataFrame`` approaches.
220+
221+
# Set Up
222+
import pprint as pp
223+
224+
import pandas as pd
225+
226+
pd.set_option("display.max_colwidth", None)
227+
pd.set_option("display.max_columns", None)
228+
229+
######################################################################
230+
# If a user wants the raw profiling results, they would do something similar to
231+
# finding the raw runtime data of an ``addmm.out`` event.
232+
233+
for event_block in inspector.event_blocks:
234+
# Via EventBlocks
235+
for event in event_block.events:
236+
if event.name == "native_call_addmm.out":
237+
print(event.name, event.perf_data.raw)
238+
239+
# Via Dataframe
240+
df = event_block.to_dataframe()
241+
df = df[df.event_name == "native_call_addmm.out"]
242+
print(df[["event_name", "raw"]])
243+
print()
244+
245+
######################################################################
246+
# If a user wants to trace an operator back to their model code, they would do
247+
# something similar to finding the module hierarchy and stack trace of the
248+
# slowest ``convolution.out`` call.
249+
250+
for event_block in inspector.event_blocks:
251+
# Via EventBlocks
252+
slowest = None
253+
for event in event_block.events:
254+
if event.name == "native_call_convolution.out":
255+
if slowest is None or event.perf_data.p50 > slowest.perf_data.p50:
256+
slowest = event
257+
if slowest is not None:
258+
print(slowest.name)
259+
print()
260+
pp.pprint(slowest.stack_traces)
261+
print()
262+
pp.pprint(slowest.module_hierarchy)
263+
264+
# Via Dataframe
265+
df = event_block.to_dataframe()
266+
df = df[df.event_name == "native_call_convolution.out"]
267+
if len(df) > 0:
268+
slowest = df.loc[df["p50"].idxmax()]
269+
print(slowest.event_name)
270+
print()
271+
pp.pprint(slowest.stack_traces)
272+
print()
273+
pp.pprint(slowest.module_hierarchy)
274+
275+
######################################################################
276+
# If a user wants the total runtime of a module, they can use
277+
# ``find_total_for_module``.
278+
279+
print(inspector.find_total_for_module("L__self__"))
280+
print(inspector.find_total_for_module("L__self___conv2"))
281+
282+
######################################################################
283+
# Note: ``find_total_for_module`` is a special first class method of
284+
# `Inspector <../sdk-inspector.html>`__
285+
286+
######################################################################
287+
# Conclusion
288+
# ----------
289+
#
290+
# In this tutorial, we learned about the steps required to consume an ExecuTorch
291+
# model with the ExecuTorch Developer Tools. It also showed how to use the Inspector APIs
292+
# to analyze the model run results.
293+
#
294+
# Links Mentioned
295+
# ^^^^^^^^^^^^^^^
296+
#
297+
# - `ExecuTorch Developer Tools Overview <../sdk-overview.html>`__
298+
# - `ETRecord <../sdk-etrecord.html>`__
299+
# - `ETDump <../sdk-etdump.html>`__
300+
# - `Inspector <../sdk-inspector.html>`__

0 commit comments

Comments
 (0)