
Commit 71aedc5

BoyuanFeng authored and facebook-github-bot committed
Replace constraints with dynamic_shapes in export-to-executorch tutorial (#1659)
Summary:
Pull Request resolved: #1659
X-link: pytorch/pytorch#117916

The `constraints` argument of `torch.export` has been deprecated in favor of the `dynamic_shapes` argument. This PR updates the use of the deprecated API in the export-to-executorch tutorial.

Reviewed By: angelayi

Differential Revision: D52932772

fbshipit-source-id: 5f2e2ce02ba1990a90ae3de1c6c95e767e39d298
1 parent 93dd96f commit 71aedc5

File tree

1 file changed: +9, -20 lines


docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 9 additions & 20 deletions
@@ -49,7 +49,7 @@
 #
 # Both APIs take in a model (any callable or ``torch.nn.Module``), a tuple of
 # positional arguments, optionally a dictionary of keyword arguments (not shown
-# in the example), and a list of constraints (covered later).
+# in the example), and a list of dynamic shapes (covered later).

 import torch
 from torch._export import capture_pre_autograd_graph
@@ -134,31 +134,23 @@ def f(x, y):
     tb.print_exc()

 ######################################################################
-# To express that some input shapes are dynamic, we can insert constraints to
-# the exporting flow. This is done through the ``dynamic_dim`` API:
+# To express that some input shapes are dynamic, we can insert dynamic
+# shapes to the exporting flow. This is done through the ``Dim`` API:

-from torch.export import dynamic_dim
+from torch.export import Dim


 def f(x, y):
     return x + y


 example_args = (torch.randn(3, 3), torch.randn(3, 3))
-constraints = [
-    # Input 0, dimension 1 is dynamic
-    dynamic_dim(example_args[0], 1),
-    # Input 0, dimension 1 must be greater than or equal to 1
-    1 <= dynamic_dim(example_args[0], 1),
-    # Input 0, dimension 1 must be less than or equal to 10
-    dynamic_dim(example_args[0], 1) <= 10,
-    # Input 1, dimension 1 is equal to input 0, dimension 1
-    dynamic_dim(example_args[1], 1) == dynamic_dim(example_args[0], 1),
-]
+dim1_x = Dim("dim1_x", min=1, max=10)
+dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
 pre_autograd_aten_dialect = capture_pre_autograd_graph(
-    f, example_args, constraints=constraints
+    f, example_args, dynamic_shapes=dynamic_shapes
 )
-aten_dialect: ExportedProgram = export(f, example_args, constraints=constraints)
+aten_dialect: ExportedProgram = export(f, example_args, dynamic_shapes=dynamic_shapes)
 print("ATen Dialect Graph")
 print(aten_dialect)
@@ -168,10 +160,7 @@ def f(x, y):
 # of values.
 #
 # Additionally, we can see in the **Range constraints** that value of ``s0`` has
-# the range [1, 10], which was specified by our constraints. We also see in the
-# **Equality constraints**, the tuple ``(InputDim(input_name='arg1_1', dim=1),
-# InputDim(input_name='arg0_1', dim=1))``, meaning that input 0's dimension 1
-# is equal to input 1's dimension 1, which was also specified by our constraints.
+# the range [1, 10], which was specified by our dynamic shapes.
 #
 # Now let's try running the model with different shapes:
