Skip to content

Commit c0167c5

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Fix tests. (#2277)
Summary: Pull Request resolved: #2277 as title. bypass-github-export-checks Reviewed By: angelayi Differential Revision: D54588001 fbshipit-source-id: df805d81b7bc2f91c8440d41703a02971f2d365e
1 parent 8ad8a2e commit c0167c5

File tree

2 files changed

+43
-23
lines changed

2 files changed

+43
-23
lines changed

exir/backend/test/test_backends_lifted.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,17 +1012,19 @@ def false_fn(x, y):
10121012
x = x - y
10131013
return x
10141014

1015-
def f(x, y):
1016-
x = x + y
1017-
x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
1018-
x = x - y
1019-
return x
1015+
class Module(torch.nn.Module):
1016+
def forward(self, x, y):
1017+
x = x + y
1018+
x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
1019+
x = x - y
1020+
return x
10201021

1022+
f = Module()
10211023
inputs = (torch.ones(2, 2), torch.ones(2, 2))
10221024
orig_res = f(*inputs)
10231025
orig = to_edge(
10241026
export(
1025-
torch.export.WrapperModule(f),
1027+
f,
10261028
inputs,
10271029
)
10281030
)
@@ -1066,15 +1068,17 @@ def map_fn(x, y):
10661068
x = x + y
10671069
return x
10681070

1069-
def f(xs, y):
1070-
y = torch.mm(y, y)
1071-
return control_flow.map(map_fn, xs, y)
1071+
class Module(torch.nn.Module):
1072+
def forward(self, xs, y):
1073+
y = torch.mm(y, y)
1074+
return control_flow.map(map_fn, xs, y)
10721075

1076+
f = Module()
10731077
inputs = (torch.ones(2, 2), torch.ones(2, 2))
10741078
orig_res = f(*inputs)
10751079
orig = to_edge(
10761080
export(
1077-
torch.export.WrapperModule(f),
1081+
f,
10781082
inputs,
10791083
)
10801084
)
@@ -1132,9 +1136,10 @@ def map_fn(x, pred1, pred2, y):
11321136
x = x + y
11331137
return x.sin()
11341138

1135-
def f(xs, pred1, pred2, y):
1136-
y = torch.mm(y, y)
1137-
return control_flow.map(map_fn, xs, pred1, pred2, y)
1139+
class Module(torch.nn.Module):
1140+
def forward(self, xs, pred1, pred2, y):
1141+
y = torch.mm(y, y)
1142+
return control_flow.map(map_fn, xs, pred1, pred2, y)
11381143

11391144
inputs = (
11401145
torch.ones(2, 2),
@@ -1143,10 +1148,11 @@ def f(xs, pred1, pred2, y):
11431148
torch.ones(2, 2),
11441149
)
11451150

1151+
f = Module()
11461152
orig_res = f(*inputs)
11471153
orig = to_edge(
11481154
export(
1149-
torch.export.WrapperModule(f),
1155+
f,
11501156
inputs,
11511157
)
11521158
)
@@ -1205,12 +1211,14 @@ def f(xs, pred1, pred2, y):
12051211
)
12061212

12071213
def test_list_input(self):
1208-
def f(x: List[torch.Tensor]):
1209-
y = x[0] + x[1]
1210-
return y
1214+
class Module(torch.nn.Module):
1215+
def forward(self, x: List[torch.Tensor]):
1216+
y = x[0] + x[1]
1217+
return y
12111218

1219+
f = Module()
12121220
inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
1213-
edge_prog = to_edge(export(torch.export.WrapperModule(f), inputs))
1221+
edge_prog = to_edge(export(f, inputs))
12141222
lowered_gm = to_backend(
12151223
BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
12161224
)
@@ -1227,12 +1235,14 @@ def forward(self, x: List[torch.Tensor]):
12271235
gm.exported_program().module()(*inputs)
12281236

12291237
def test_dict_input(self):
1230-
def f(x: Dict[str, torch.Tensor]):
1231-
y = x["a"] + x["b"]
1232-
return y
1238+
class Module(torch.nn.Module):
1239+
def forward(self, x: Dict[str, torch.Tensor]):
1240+
y = x["a"] + x["b"]
1241+
return y
12331242

1243+
f = Module()
12341244
inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
1235-
edge_prog = to_edge(export(torch.export.WrapperModule(f), inputs))
1245+
edge_prog = to_edge(export(f, inputs))
12361246
lowered_gm = to_backend(
12371247
BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
12381248
)

exir/program/test/test_program.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030

3131
from torch.library import impl, Library
3232

33+
34+
class WrapperModule(torch.nn.Module):
35+
def __init__(self, fn):
36+
super().__init__()
37+
self.fn = fn
38+
39+
def forward(self, *args, **kwargs):
40+
return self.fn(*args, **kwargs)
41+
42+
3343
lib = Library("test_op", "DEF")
3444

3545
# Fake a operator for testing.
@@ -374,7 +384,7 @@ def _test_edge_dialect_verifier(self, callable, validate_ir=True):
374384
two,
375385
)
376386
if not isinstance(callable, torch.nn.Module):
377-
callable = torch.export.WrapperModule(callable)
387+
callable = WrapperModule(callable)
378388

379389
exported_foo = export(callable, inputs)
380390
_ = to_edge(exported_foo, compile_config=edge_compile_config)

0 commit comments

Comments
 (0)