Commit 24364dd

Olivia-liu authored and facebook-github-bot committed
rename tutorial (#5384)
Summary: Pull Request resolved: #5384 Differential Revision: D62727902
1 parent eecf74f commit 24364dd

5 files changed, +311 -290 lines changed
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
## Developer Tools Usage Tutorial

Please refer to the [Developer Tools tutorial](./tutorials/developer-tools-integration-tutorial) for a walkthrough on how to profile a model in ExecuTorch using the Developer Tools.

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
@@ -94,7 +94,7 @@ Topics in this section will help you get started with ExecuTorch.
 tutorials/export-to-executorch-tutorial
 running-a-model-cpp-tutorial
 extension-module
-tutorials/sdk-integration-tutorial
+tutorials/developer-tools-integration-tutorial
 apple-runtime
 demo-apps-ios
 demo-apps-android
@@ -204,7 +204,7 @@ Topics in this section will help you get started with ExecuTorch.
 sdk-debugging
 sdk-inspector
 sdk-delegate-integration
-sdk-tutorial
+developer-tools-tutorial

 .. toctree::
 :glob:
@@ -247,7 +247,7 @@ ExecuTorch tutorials.
 :header: Using the ExecuTorch Developer Tools to Profile a Model
 :card_description: A tutorial for using the ExecuTorch Developer Tools to profile and analyze a model with linkage back to source code.
 :image: _static/img/generic-pytorch-logo.png
-:link: tutorials/sdk-integration-tutorial.html
+:link: tutorials/developer-tools-integration-tutorial.html
 :tags: devtools

 .. customcarditem::

docs/source/sdk-tutorial.md

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Using the ExecuTorch Developer Tools to Profile a Model
=======================================================

**Author:** `Jack Khuu <https://github.com/Jack-Khuu>`__
"""

######################################################################
# The `ExecuTorch Developer Tools <../sdk-overview.html>`__ is a set of tools designed to
# provide users with the ability to profile, debug, and visualize ExecuTorch
# models.
#
# This tutorial walks through an end-to-end flow for profiling a model with the Developer Tools.
# Specifically, it will:
#
# 1. Generate the artifacts consumed by the Developer Tools (`ETRecord <../sdk-etrecord.html>`__, `ETDump <../sdk-etdump.html>`__).
# 2. Create an ``Inspector`` instance that consumes these artifacts.
# 3. Use the ``Inspector`` to analyze the model's profiling results.

######################################################################
# Prerequisites
# -------------
#
# To run this tutorial, you’ll first need to
# `Set up your ExecuTorch environment <../getting-started-setup.html>`__.
#

######################################################################
# Generate ETRecord (Optional)
# ----------------------------
#
# The first step is to generate an ``ETRecord``. ``ETRecord`` contains model
# graphs and metadata for linking runtime results (such as profiling) to
# the eager model. This is generated via ``executorch.devtools.generate_etrecord``.
#
# ``executorch.devtools.generate_etrecord`` takes in an output file path (str), the
# edge dialect model (``EdgeProgramManager``), the ExecuTorch dialect model
# (``ExecutorchProgramManager``), and an optional dictionary of additional models.
#
# This tutorial uses the example model shown below.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.devtools import generate_etrecord

from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
)
from torch.export import export, ExportedProgram


# Generate Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net()

aten_model: ExportedProgram = export(
    model,
    (torch.randn(1, 1, 32, 32),),
)

edge_program_manager: EdgeProgramManager = to_edge(
    aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
)
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
et_program_manager: ExecutorchProgramManager = edge_program_manager.to_executorch()


# Generate ETRecord
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_program_manager_copy, et_program_manager)

# sphinx_gallery_start_ignore
from unittest.mock import patch

# sphinx_gallery_end_ignore

######################################################################
#
# .. warning::
#    Users should do a deepcopy of the output of ``to_edge()`` and pass in the
#    deepcopy to the ``generate_etrecord`` API. This is needed because the
#    subsequent call, ``to_executorch()``, does an in-place mutation and will
#    lose debug data in the process.
#

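######################################################################
# The hazard described in the warning can be sketched with plain Python
# objects. ``ToyProgram`` and ``lower_in_place`` are hypothetical stand-ins
# (they are not ExecuTorch APIs) for an edge program and an in-place
# lowering step. The sketch only illustrates why the copy has to be taken
# before the mutating call runs.

import copy
from dataclasses import dataclass, field


@dataclass
class ToyProgram:
    # Hypothetical stand-in for a program that carries debug metadata.
    debug_info: dict = field(default_factory=dict)


def lower_in_place(program: ToyProgram) -> ToyProgram:
    # Assumed behavior for illustration only: lowering mutates its input
    # and drops the debug metadata.
    program.debug_info.clear()
    return program


toy_edge = ToyProgram(debug_info={"conv1": "model.py:12"})
toy_edge_copy = copy.deepcopy(toy_edge)  # snapshot taken before lowering
toy_lowered = lower_in_place(toy_edge)

print(toy_edge.debug_info)  # the original lost its metadata
print(toy_edge_copy.debug_info)  # the deep copy still has it
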
######################################################################
# Generate ETDump
# ---------------
#
# The next step is to generate an ``ETDump``. ``ETDump`` contains runtime results
# from executing a `Bundled Program Model <../sdk-bundled-io.html>`__.
#
# In this tutorial, a Bundled Program is created from the example model above.

import torch
from executorch.devtools import BundledProgram

from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)

from executorch.exir import to_edge
from torch.export import export

# Step 1: ExecuTorch Program Export
m_name = "forward"
method_graphs = {m_name: export(model, (torch.randn(1, 1, 32, 32),))}

# Step 2: Construct Method Test Suites
inputs = [[torch.randn(1, 1, 32, 32)] for _ in range(2)]

method_test_suites = [
    MethodTestSuite(
        method_name=m_name,
        test_cases=[
            MethodTestCase(inputs=inp, expected_outputs=getattr(model, m_name)(*inp))
            for inp in inputs
        ],
    )
]

# Step 3: Generate BundledProgram
executorch_program = to_edge(method_graphs).to_executorch()
bundled_program = BundledProgram(executorch_program, method_test_suites)

# Step 4: Serialize BundledProgram to flatbuffer.
serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(
    bundled_program
)
save_path = "bundled_program.bp"
with open(save_path, "wb") as f:
    f.write(serialized_bundled_program)

######################################################################
# Use CMake (follow `these instructions <../runtime-build-and-cross-compilation.html#configure-the-cmake-build>`__ to set up cmake) to execute the Bundled Program to generate the ``ETDump``::
#
#       cd executorch
#       ./examples/devtools/build_example_runner.sh
#       cmake-out/examples/devtools/example_runner --bundled_program_path="bundled_program.bp"

######################################################################
# Creating an Inspector
# ---------------------
#
# The final step is to create the ``Inspector`` by passing in the artifact paths.
# The ``Inspector`` takes the runtime results from ``ETDump`` and correlates them to
# the operators of the Edge Dialect Graph.
#
# Recall: An ``ETRecord`` is not required. If an ``ETRecord`` is not provided,
# the ``Inspector`` will show runtime results without operator correlation.
#
# To visualize all runtime events, call the ``Inspector``'s ``print_data_tabular``.

from executorch.devtools import Inspector

# sphinx_gallery_start_ignore
inspector_patch = patch.object(Inspector, "__init__", return_value=None)
inspector_patch_print = patch.object(Inspector, "print_data_tabular", return_value="")
inspector_patch.start()
inspector_patch_print.start()
# sphinx_gallery_end_ignore
etdump_path = "etdump.etdp"
inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
# sphinx_gallery_start_ignore
inspector.event_blocks = []
# sphinx_gallery_end_ignore
inspector.print_data_tabular()

# sphinx_gallery_start_ignore
inspector_patch.stop()
inspector_patch_print.stop()
# sphinx_gallery_end_ignore

######################################################################
# Analyzing with an Inspector
# ---------------------------
#
# ``Inspector`` provides two ways of accessing ingested information: `EventBlocks <../sdk-inspector.html#eventblock-class>`__
# and ``DataFrames``. Both give users the ability to perform custom
# analysis of their model's performance.
#
# Below are example usages, with both ``EventBlock`` and ``DataFrame`` approaches.

# Set Up
import pprint as pp

import pandas as pd

pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", None)

######################################################################
# To get the raw profiling results, iterate over the events and read their
# performance data. The example below finds the raw runtime data of every
# ``addmm.out`` event.

for event_block in inspector.event_blocks:
    # Via EventBlocks
    for event in event_block.events:
        if event.name == "native_call_addmm.out":
            print(event.name, event.perf_data.raw)

    # Via Dataframe
    df = event_block.to_dataframe()
    df = df[df.event_name == "native_call_addmm.out"]
    print(df[["event_name", "raw"]])
    print()

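######################################################################
# Once raw samples are in hand, summary statistics are straightforward to
# derive. The following is a minimal sketch using only the standard library.
# ``toy_raw`` holds made-up sample values for illustration, not output from
# a real run, and the units are arbitrary.

import statistics

toy_raw = [1.2, 0.9, 1.1, 1.0, 3.5]  # hypothetical per-run latencies

toy_summary = {
    "min": min(toy_raw),
    "max": max(toy_raw),
    "avg": statistics.fmean(toy_raw),
    "median": statistics.median(toy_raw),
}
print(toy_summary)

######################################################################
# A single outlier (3.5 here) pulls the average well above the median, which
# is one reason to look at raw samples instead of a single aggregate.
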
######################################################################
# To trace an operator back to the model code, look up the module hierarchy
# and stack trace of the event of interest. The example below does this for
# the slowest ``convolution.out`` call, ranked by ``p50``.

for event_block in inspector.event_blocks:
    # Via EventBlocks
    slowest = None
    for event in event_block.events:
        if event.name == "native_call_convolution.out":
            if slowest is None or event.perf_data.p50 > slowest.perf_data.p50:
                slowest = event
    if slowest is not None:
        print(slowest.name)
        print()
        pp.pprint(slowest.stack_traces)
        print()
        pp.pprint(slowest.module_hierarchy)

    # Via Dataframe
    df = event_block.to_dataframe()
    df = df[df.event_name == "native_call_convolution.out"]
    if len(df) > 0:
        slowest = df.loc[df["p50"].idxmax()]
        print(slowest.event_name)
        print()
        pp.pprint(slowest.stack_traces)
        print()
        pp.pprint(slowest.module_hierarchy)

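######################################################################
# The "slowest call" selection above can be tried without an ``ETDump``.
# In this sketch, ``ToyEvent`` and ``slowest_by_p50`` are hypothetical names
# introduced for illustration (they are not part of the Developer Tools),
# and the timings are made-up values.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyEvent:
    name: str
    p50: float  # median latency; units are arbitrary in this sketch


def slowest_by_p50(events: List[ToyEvent], name: str) -> Optional[ToyEvent]:
    # Keep only events with the requested name, then take the largest p50.
    matches = [e for e in events if e.name == name]
    return max(matches, key=lambda e: e.p50) if matches else None


toy_events = [
    ToyEvent("native_call_convolution.out", 0.8),
    ToyEvent("native_call_addmm.out", 0.3),
    ToyEvent("native_call_convolution.out", 1.4),
]
toy_slowest = slowest_by_p50(toy_events, "native_call_convolution.out")
print(toy_slowest)

######################################################################
# Returning ``None`` when nothing matches mirrors the ``slowest is not None``
# guard in the loop above.
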
######################################################################
# To get the total runtime of a module, use ``find_total_for_module``.

print(inspector.find_total_for_module("L__self__"))
print(inspector.find_total_for_module("L__self___conv2"))

######################################################################
# Note: ``find_total_for_module`` is a first-class method of
# `Inspector <../sdk-inspector.html>`__.

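######################################################################
# Conceptually, a per-module total is an aggregation over the events that are
# attributed to that module. The sketch below is NOT the ``Inspector``
# implementation. ``toy_module_events`` and ``toy_total_for_module`` are
# hypothetical, and it assumes each event carries the set of module names it
# belongs to plus one made-up timing value.

from typing import List, Set, Tuple

# (modules the event is attributed to, timing in arbitrary units)
toy_module_events: List[Tuple[Set[str], float]] = [
    ({"L__self__", "L__self___conv1"}, 0.5),
    ({"L__self__", "L__self___conv2"}, 1.5),
    ({"L__self__", "L__self___fc1"}, 0.25),
]


def toy_total_for_module(events: List[Tuple[Set[str], float]], module: str) -> float:
    # Sum the timing of every event attributed to the requested module.
    return sum(timing for modules, timing in events if module in modules)


print(toy_total_for_module(toy_module_events, "L__self__"))
print(toy_total_for_module(toy_module_events, "L__self___conv2"))
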
######################################################################
# Conclusion
# ----------
#
# This tutorial covered the steps required to use an ExecuTorch model with the
# ExecuTorch Developer Tools and showed how to use the ``Inspector`` APIs
# to analyze the results of a model run.
#
# Links Mentioned
# ^^^^^^^^^^^^^^^
#
# - `ExecuTorch Developer Tools Overview <../sdk-overview.html>`__
# - `ETRecord <../sdk-etrecord.html>`__
# - `ETDump <../sdk-etdump.html>`__
# - `Inspector <../sdk-inspector.html>`__
