Skip to content

Commit dc7df4f

Browse files
guangy10facebook-github-bot
authored andcommitted
Fix source for export to executorch tutorial (#2400)
Summary: Pull Request resolved: #2400 imported-using-ghimport Test Plan: Imported from OSS Reviewed By: angelayi Differential Revision: D54862906 Pulled By: guangy10 fbshipit-source-id: c26f03f56d58697f155b8201a0711da7dd5b86da
1 parent d5f898d commit dc7df4f

File tree

2 files changed

+38
-33
lines changed

2 files changed

+38
-33
lines changed

docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
130130
aten_dialect: ExportedProgram = export(f, example_args)
131131

132132
# Works correctly
133-
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
133+
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
134134

135135
# Errors
136136
try:
137-
print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
137+
print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
138138
except Exception:
139139
tb.print_exc()
140140

@@ -175,18 +175,18 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
175175
# Now let's try running the model with different shapes:
176176

177177
# Works correctly
178-
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
179-
print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
178+
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
179+
print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
180180

181181
# Errors because it violates our constraint that input 0, dim 1 <= 10
182182
try:
183-
print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
183+
print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15)))
184184
except Exception:
185185
tb.print_exc()
186186

187187
# Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
188188
try:
189-
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
189+
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2)))
190190
except Exception:
191191
tb.print_exc()
192192

@@ -287,23 +287,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
287287
# there is only one program, it will by default be saved to the name "forward".
288288

289289

290-
def encode(x):
291-
return torch.nn.functional.linear(x, torch.randn(5, 10))
290+
class Encode(torch.nn.Module):
291+
def forward(self, x):
292+
return torch.nn.functional.linear(x, torch.randn(5, 10))
292293

293294

294-
def decode(x):
295-
return torch.nn.functional.linear(x, torch.randn(10, 5))
295+
class Decode(torch.nn.Module):
296+
def forward(self, x):
297+
return torch.nn.functional.linear(x, torch.randn(10, 5))
296298

297299

298300
encode_args = (torch.randn(1, 10),)
299301
aten_encode: ExportedProgram = export(
300-
capture_pre_autograd_graph(encode, encode_args),
302+
capture_pre_autograd_graph(Encode(), encode_args),
301303
encode_args,
302304
)
303305

304306
decode_args = (torch.randn(1, 5),)
305307
aten_decode: ExportedProgram = export(
306-
capture_pre_autograd_graph(decode, decode_args),
308+
capture_pre_autograd_graph(Decode(), decode_args),
307309
decode_args,
308310
)
309311

@@ -486,17 +488,18 @@ def forward(self, x):
486488
# ``LoweredBackendModule`` for each of those subgraphs.
487489

488490

489-
def f(a, x, b):
490-
y = torch.mm(a, x)
491-
z = y + b
492-
a = z - a
493-
y = torch.mm(a, x)
494-
z = y + b
495-
return z
491+
class Foo(torch.nn.Module):
492+
def forward(self, a, x, b):
493+
y = torch.mm(a, x)
494+
z = y + b
495+
a = z - a
496+
y = torch.mm(a, x)
497+
z = y + b
498+
return z
496499

497500

498501
example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
499-
pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
502+
pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
500503
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
501504
edge_program: EdgeProgramManager = to_edge(aten_dialect)
502505
exported_program = edge_program.exported_program()
@@ -520,17 +523,18 @@ def f(a, x, b):
520523
# call ``to_backend`` on it:
521524

522525

523-
def f(a, x, b):
524-
y = torch.mm(a, x)
525-
z = y + b
526-
a = z - a
527-
y = torch.mm(a, x)
528-
z = y + b
529-
return z
526+
class Foo(torch.nn.Module):
527+
def forward(self, a, x, b):
528+
y = torch.mm(a, x)
529+
z = y + b
530+
a = z - a
531+
y = torch.mm(a, x)
532+
z = y + b
533+
return z
530534

531535

532536
example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
533-
pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
537+
pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
534538
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
535539
edge_program: EdgeProgramManager = to_edge(aten_dialect)
536540
exported_program = edge_program.exported_program()

examples/apple/coreml/scripts/export.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
import pathlib
99
import sys
1010

11-
import torch
1211
import executorch.exir as exir
1312

14-
from executorch.backends.apple.coreml.compiler import CoreMLBackend
13+
import torch
1514

16-
from executorch.exir.backend.backend_api import to_backend
17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
18-
from executorch.sdk.etrecord import generate_etrecord
15+
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1916

2017
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
2118
CoreMLPartitioner,
2219
)
2320

21+
from executorch.exir.backend.backend_api import to_backend
22+
from executorch.exir.backend.compile_spec_schema import CompileSpec
23+
from executorch.sdk.etrecord import generate_etrecord
24+
2425
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent.parent
2526
EXAMPLES_DIR = REPO_ROOT / "examples"
2627
sys.path.append(str(EXAMPLES_DIR.absolute()))

0 commit comments

Comments
 (0)