Skip to content

Commit 68f4039

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from dynamic_shape_propogation test
Summary: title Differential Revision: D56216416
1 parent c73bfc0 commit 68f4039

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

exir/tests/models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
import itertools
10-
from typing import List, Optional, Tuple, Union
10+
from typing import Any, List, Optional, Tuple, Union
1111

1212
import executorch.exir as exir
1313

@@ -34,6 +34,11 @@ def forward(
3434
def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
3535
return (torch.rand(4), torch.rand(5))
3636

37+
def get_dynamic_shape(self) -> Any: # pyre-ignore[3]
38+
dim = torch.export.Dim("dim", max=10)
39+
dim2 = torch.export.Dim("dim2", max=10)
40+
return ({0: dim}, {0: dim2})
41+
3742

3843
class ModelWithUnusedArg(nn.Module):
3944
def __init__(self) -> None:

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from unittest import TestCase
88

99
from executorch import exir
10+
from executorch.exir import to_edge
1011
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
1112
from executorch.exir.tests.models import Repeat
13+
from torch.export import export
1214

1315

1416
class TestDynamicShapeProp(TestCase):
@@ -17,15 +19,14 @@ def test_repeat(self):
1719
inputs = eager_model.get_random_inputs()
1820
inputs = inputs[0], inputs[1]
1921

20-
prog = exir.capture(
21-
eager_model,
22-
inputs,
23-
exir.CaptureConfig(enable_dynamic_shape=True),
24-
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
22+
prog = to_edge(
23+
export(eager_model, inputs, dynamic_shapes=eager_model.get_dynamic_shape()),
24+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
25+
)
2526

26-
new_prog = prog.transform(SpecPropPass(), HintBasedSymShapeEvalPass())
27+
new_prog = prog.transform([SpecPropPass(), HintBasedSymShapeEvalPass()])
2728

28-
gm = new_prog.exported_program.graph_module
29+
gm = new_prog.exported_program().graph_module
2930

3031
DebugPass(show_spec=True)(gm)
3132
*_, return_node = gm.graph.nodes

0 commit comments

Comments
 (0)