Commit c393b2f

suo authored and pytorchmergebot committed
[export] require Module to be passed to export (pytorch#117528)
This PR changes torch.export to require an nn.Module as input, rather than accepting an arbitrary callable. The rationale is that several invariants of the ExportedProgram are ambiguous if the top-level object being traced is a function:

1. We "guarantee" that every call_function node has an `nn_module_stack` populated.
2. We offer ways to access the state_dict/parameters/buffers of the exported program.

We'd like torch.export to offer strong invariants—the value proposition of export is that you trade flexibility for stronger guarantees about your model.

An alternative design would be to implicitly convert the top-level function into a module, rather than require that the user provide one. I think that's reasonable (it's what we did in TorchScript), but in the spirit of being explicit (another design tenet of export) I avoid that here.

Differential Revision: [D52789321](https://our.internmc.facebook.com/intern/diff/D52789321/)

Pull Request resolved: pytorch#117528

Approved by: https://github.com/thiagocrepaldi, https://github.com/zhxchen17, https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
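For code that previously exported a bare function, the migration is to wrap the function body in a module's `forward`, as the test changes below do. A minimal sketch of that pattern (the `Wrapper` name is illustrative, not part of the API):

    import torch

    # Before this PR: torch.export.export(fn, (x,)) accepted a plain callable.
    # After: the first argument must be an nn.Module, so wrap the logic in forward().
    class Wrapper(torch.nn.Module):
        def forward(self, x):
            return x + x

    ep = torch.export.export(Wrapper(), (torch.zeros(4, 4),))
    # Module-based tracing keeps nn_module_stack populated on call_function nodes
    # and gives the ExportedProgram a well-defined state_dict/parameters/buffers.
    print(ep)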
1 parent 3ee092f commit c393b2f

File tree

4 files changed (+139, -88 lines)


test/export/test_export.py

Lines changed: 5 additions & 2 deletions
@@ -2876,8 +2876,11 @@ def test_lift_custom_obj(self):
 
         custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])
 
-        def f(x):
-            return x + x
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        f = Foo()
 
         inputs = (torch.zeros(4, 4),)
         ep = export(f, inputs)

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 104 additions & 73 deletions
@@ -83,12 +83,15 @@ def setUp(self):
         self.ort_version = onnxruntime.__version__
 
     def test_simple_function(self):
-        def func(x):
-            # TODO(justinchuby): Replicate torch's type casting policy
-            # in the exporter for type promotion support
-            y = x + 1.0
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                # TODO(justinchuby): Replicate torch's type casting policy
+                # in the exporter for type promotion support
+                y = x + 1.0
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()
 
         tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
 
@@ -118,10 +121,13 @@ def test_func_with_args_and_tensor_kwargs(self):
         # practice to set mutable default values.
         # `DynamoOptimizeExporter` applies a workaround by binding args and kwargs to
         # model signature and fill in the default values of unprovided optional arguments.
-        def func(x, b=torch.tensor(1.0)):
-            y = x + b
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x, b=torch.tensor(1.0)):
+                y = x + b
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()
 
         tensor_x = torch.randn(1, 2, 3, dtype=torch.float32)
 
@@ -140,21 +146,24 @@ def func(x, b=torch.tensor(1.0)):
         "sympy operation tests don't need dynamic shape"
     )
     def test_sympy_operatons_return_numeric(self):
-        def func(x, y):
-            # TODO: add boolean tests when SymBool is supported
-            # to infer types
-            return (
-                torch.tensor([operator.add(x.item(), y.item())]),
-                torch.tensor([operator.sub(x.item(), y.item())]),
-                torch.tensor([operator.mul(x.item(), y.item())]),
-                torch.tensor([operator.truediv(x.item(), y.item())]),
-                torch.tensor([operator.floordiv(x.item(), y.item())]),
-                torch.tensor([operator.pow(x.item(), y.item())]),
-                torch.tensor([operator.abs(x.item())]),
-                torch.tensor([operator.neg(x.item())]),
-                torch.tensor([math.ceil(x.item())]),
-                torch.tensor([math.floor(x.item())]),
-            )
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                # TODO: add boolean tests when SymBool is supported
+                # to infer types
+                return (
+                    torch.tensor([operator.add(x.item(), y.item())]),
+                    torch.tensor([operator.sub(x.item(), y.item())]),
+                    torch.tensor([operator.mul(x.item(), y.item())]),
+                    torch.tensor([operator.truediv(x.item(), y.item())]),
+                    torch.tensor([operator.floordiv(x.item(), y.item())]),
+                    torch.tensor([operator.pow(x.item(), y.item())]),
+                    torch.tensor([operator.abs(x.item())]),
+                    torch.tensor([operator.neg(x.item())]),
+                    torch.tensor([math.ceil(x.item())]),
+                    torch.tensor([math.floor(x.item())]),
+                )
+
+        func = Foo()
 
         x = torch.randn(1, dtype=torch.float32)
         y = torch.randn(1, dtype=torch.float32)
@@ -171,10 +180,13 @@ def func(x, y):
         reason="https://github.com/pytorch/pytorch/issues/99534",
     )
     def test_xfail_func_with_non_tensor_args(self):
-        def func(x, b=1.0):
-            y = x + b
-            z = y.relu()
-            return (y, z)
+        class Foo(torch.nn.Module):
+            def forward(self, x, b=1.0):
+                y = x + b
+                z = y.relu()
+                return (y, z)
+
+        func = Foo()
 
         tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
 
@@ -202,25 +214,29 @@ def func(x, b=1.0):
         torch.testing.assert_close(ref_output, torch.tensor(ort_output))
 
     def test_func_with_nested_input_structure(self):
-        def func(
-            x_dict: Dict[str, torch.Tensor],
-            y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-            z_list: List[List[torch.Tensor]],
-        ):
-            if "a" in x_dict:
-                x = x_dict["a"]
-            elif "b" in x_dict:
-                x = x_dict["b"]
-            else:
-                x = torch.randn(3)
+        class Foo(torch.nn.Module):
+            def forward(
+                self,
+                x_dict: Dict[str, torch.Tensor],
+                y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+                z_list: List[List[torch.Tensor]],
+            ):
+                if "a" in x_dict:
+                    x = x_dict["a"]
+                elif "b" in x_dict:
+                    x = x_dict["b"]
+                else:
+                    x = torch.randn(3)
 
-            y1, (y2, y3) = y_tuple
+                y1, (y2, y3) = y_tuple
 
-            z = x + y1 + y2 + y3
-            for z_sub_list in z_list:
-                z = z + torch.stack(z_sub_list).sum()
+                z = x + y1 + y2 + y3
+                for z_sub_list in z_list:
+                    z = z + torch.stack(z_sub_list).sum()
 
-            return z
+                return z
+
+        func = Foo()
 
         x_dict = {"a": torch.randn(3), "c": torch.randn(3)}
         y_tuple = (torch.randn(3), (torch.randn(3), torch.randn(3)))
@@ -233,14 +249,17 @@ def func(
         )
 
     def test_func_with_nested_output_structure(self):
-        def func(x, y, z):
-            x = x + y
-            y = y + z
-            z = x + y
-            out1 = (x, (y, z))
-            out2 = [[x, y], [y, z]]
-            out3 = {"z": z, "x": x}
-            return out1, out2, out3
+        class Foo(torch.nn.Module):
+            def forward(self, x, y, z):
+                x = x + y
+                y = y + z
+                z = x + y
+                out1 = (x, (y, z))
+                out2 = [[x, y], [y, z]]
+                out3 = {"z": z, "x": x}
+                return out1, out2, out3
+
+        func = Foo()
 
         x = torch.randn(3)
         y = torch.randn(3)
@@ -535,19 +554,22 @@ def forward(self, x):
 
     @pytorch_test_common.skipIfNoCuda
    def test__scaled_dot_product_flash_attention(self):
-        def func(x):
-            (
-                output,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-                _,
-            ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x)
-            return output
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                (
+                    output,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = torch.ops.aten._scaled_dot_product_flash_attention(x, x, x)
+                return output
+
+        func = Foo()
 
         x = torch.randn(1, 1, 1, 32, device=torch.device("cuda"))
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,))
@@ -597,9 +619,12 @@ def forward(
         )
 
     def test_operator_with_data_dependent_output(self):
-        def func(x):
-            # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`.
-            return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min))
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                # Repro from llama. Emits `torch.ops.aten._local_scalar_dense`.
+                return x + torch.full(x.shape, torch.tensor(torch.finfo(x.dtype).min))
+
+        func = Foo()
 
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.randn(3, 4),)
@@ -610,8 +635,11 @@ def func(x):
         reason="https://github.com/pytorch/pytorch/issues/112622",
     )
     def test_operator_with_scalar_output(self):
-        def func(x, y):
-            return x.item() + y
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                return x.item() + y
+
+        func = Foo()
 
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.tensor([1]), torch.randn(3, 4))
@@ -622,8 +650,11 @@ def func(x, y):
         reason="https://github.com/pytorch/pytorch/issues/112622",
     )
     def test_operator_with_dynamic_output_shape(self):
-        def func(x):
-            return x.nonzero()
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x.nonzero()
+
+        func = Foo()
 
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             func, (torch.randn(3, 4),)

test/onnx/torch_export/test_torch_export_with_onnxruntime.py

Lines changed: 21 additions & 9 deletions
@@ -84,8 +84,11 @@ def forward(self, x):
         )
 
     def test_exported_program_with_specialized_input_during_tracing(self):
-        def f(x, y):
-            return x + y
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        f = Foo()
 
         tensor_input = torch.ones(7, 5)
         dim0_x = torch.export.Dim("dim0_x", min=6)
@@ -131,7 +134,7 @@ def forward(self, x):
         # NOTE: If input is ExportedProgram, we need to specify dynamic_shapes
         # as a tuple.
         reexported_program = torch.export.export(
-            exported_program, (tensor_input,), dynamic_shapes=({0: dim0_x},)
+            exported_program.module(), (tensor_input,), dynamic_shapes=({0: dim0_x},)
         )
         reexported_onnx_program = torch.onnx.dynamo_export(
             reexported_program, tensor_input
@@ -145,8 +148,11 @@ def forward(self, x):
         )
 
     def test_onnx_program_supports_none_arg_name_in_dynamic(self):
-        def foo(a, b):
-            return a.sum() + b.sum()
+        class Foo(torch.nn.Module):
+            def forward(self, a, b):
+                return a.sum() + b.sum()
+
+        foo = Foo()
 
         dim = torch.export.Dim("dim")
         exported_program = torch.export.export(
@@ -165,8 +171,11 @@ def foo(a, b):
         )
 
     def test_onnx_program_suppors_non_arg_name_with_kwarg(self):
-        def foo(a, b, kw1, kw2):
-            return a.sum() + b.sum() + kw1.sum() - kw2.sum()
+        class Foo(torch.nn.Module):
+            def forward(self, a, b, kw1, kw2):
+                return a.sum() + b.sum() + kw1.sum() - kw2.sum()
+
+        foo = Foo()
 
         dim = torch.export.Dim("dim")
         dim_for_kw1 = torch.export.Dim("dim_for_kw1")
@@ -238,8 +247,11 @@ def forward(self, x, b):
         )
 
     def test_onnx_program_supports_non_arg_name_with_container_type(self):
-        def foo(a, b):
-            return a[0].sum() + a[1].sum() + b.sum()
+        class Foo(torch.nn.Module):
+            def forward(self, a, b):
+                return a[0].sum() + a[1].sum() + b.sum()
+
+        foo = Foo()
 
         inp_a = (torch.randn(4, 4), torch.randn(4, 4))
         inp_b = torch.randn(4, 4)
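Note that the `exported_program.module()` hunk above follows from the same requirement: an ExportedProgram is no longer a valid first argument to export, so re-export goes through the module it produces. A small sketch of that pattern, assuming a trivial single-tensor model (`M` is illustrative, not taken from the diff):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.relu()

    x = torch.randn(3, 4)
    ep = torch.export.export(M(), (x,))

    # Re-export by passing the unlifted nn.Module, not the ExportedProgram itself.
    reexported = torch.export.export(ep.module(), (x,))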

torch/export/__init__.py

Lines changed: 9 additions & 4 deletions
@@ -74,7 +74,7 @@
 
 
 def export(
-    f: Callable,
+    mod: torch.nn.Module,
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
    *,
@@ -124,7 +124,7 @@ def export(
     ``dynamic_shapes`` argument to your :func:`export` call.
 
     Args:
-        f: The callable to trace.
+        mod: We will trace the forward method of this module.
 
         args: Example positional inputs.
 
@@ -179,6 +179,11 @@ def export(
     from ._trace import _export
     from .dynamic_shapes import _process_dynamic_shapes
 
+    if not isinstance(mod, torch.nn.Module):
+        raise ValueError(
+            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
+        )
+
     if constraints is not None:
         warnings.warn(
             "Using `constraints` to specify dynamic shapes for export is DEPRECATED "
@@ -188,10 +193,10 @@ def export(
             stacklevel=2,
         )
     else:
-        constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
+        constraints = _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes)
 
     return _export(
-        f,
+        mod,
         args,
         kwargs,
         constraints,
