
Commit 67e22d6

jiashenC authored and pytorchmergebot committed
[Fix]: Convert operator that does specialization to its symbolic counterpart (pytorch#129578)

#### Issue
During conversion, use the symbolic operator when one exists.

#### Test Plan
`pytest test/export/test_converter.py`

Pull Request resolved: pytorch#129578
Approved by: https://github.com/angelayi
1 parent e8998d6 commit 67e22d6
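
Why this matters: `aten::size`, `aten::numel`, `aten::stride`, and `aten::storage_offset` return concrete Python ints, which specializes the exported graph to the sample input's shape; their `sym_*` counterparts return `SymInt`s, keeping the program valid under dynamic shapes. A minimal illustration using the public `torch.export` API (the module and dim names here are illustrative, not part of this PR):

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        # x.size(0) lowers to aten::size in TorchScript; in an exported
        # graph with a dynamic batch dimension it must become sym_size
        # so the program is not specialized to the sample input's shape.
        return x.reshape(x.size(0), -1)

ep = export(M(), (torch.randn(2, 3, 4),), dynamic_shapes={"x": {0: Dim("batch")}})
# The batch size is captured symbolically (sym_size), so the same
# ExportedProgram also accepts inputs with a different leading dimension.
print(ep.module()(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 12])
```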

File tree

2 files changed: 18 additions & 4 deletions


test/export/test_converter.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -797,7 +797,18 @@ def forward(self, x):
             torch.randn([2, 3, 4]).to(torch.float32),
             torch.randn([2, 3, 4]).to(torch.float64),
         )
-        self._check_equal_ts_ep_converter(func6, inp)
+        ep_list = self._check_equal_ts_ep_converter(func6, inp)
+
+        # TODO: Additional check once dynamic shape is supported.
+        # for ep in ep_list:
+        #     self.assertEqual(
+        #         ep.module()(
+        #             torch.randn([1, 1, 1]).to(torch.int8),
+        #             torch.randn([1, 1, 1]).to(torch.int32),
+        #             torch.randn([1, 1, 1]).to(torch.float32),
+        #             torch.randn([1, 1, 1]).to(torch.float64),
+        #         )[0], 1
+        #     )
 
     def test_prim_tolist(self):
         class Module(torch.nn.Module):
```

torch/_export/converter.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -107,6 +107,12 @@ def get_dtype_as_int(tensor):
     "aten::__contains__": operator.contains,
     "prim::dtype": get_dtype_as_int,
     "aten::len": len,
+    # Mapping from specialized op to its symbolic counterpart.
+    # They currently do not have any other overrides.
+    "aten::numel": torch.ops.aten.sym_numel,
+    "aten::size": torch.ops.aten.sym_size,
+    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
+    "aten::stride": torch.ops.aten.sym_stride,
 }
 
 
@@ -497,9 +503,6 @@ def convert_prim_SetAttr(self, node: torch._C.Node):
     def convert_call_function_op(self, node: torch._C.Node):
         target = get_op_overload(node)
 
-        if target is torch.ops.aten.size.int:
-            target = torch.ops.aten.sym_size.int
-
         args, kwargs = self.get_args_kwargs(node, target._schema)
 
         fx_node = self.fx_graph.call_function(target, args, kwargs)
```
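
With the table doing the substitution, the special case previously hard-coded in `convert_call_function_op` for `aten::size` is no longer needed, and `numel`, `storage_offset`, and `stride` get the same treatment for free. A hypothetical sketch of the lookup pattern (only the mapping entries come from the diff; the surrounding names are illustrative):

```python
import operator
import torch

# Table mapping a TorchScript node kind to the FX call target; for ops
# that would specialize, the target is the symbolic counterpart.
kind_to_target = {
    "aten::__contains__": operator.contains,
    "aten::len": len,
    "aten::numel": torch.ops.aten.sym_numel,
    "aten::size": torch.ops.aten.sym_size,
    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
    "aten::stride": torch.ops.aten.sym_stride,
}

def resolve_target(kind: str, default_target):
    # Prefer the table entry; otherwise fall back to the overload
    # resolved from the node's schema (as get_op_overload does).
    return kind_to_target.get(kind, default_target)

# A node of kind "aten::size" now dispatches to the symbolic op.
print(resolve_target("aten::size", torch.ops.aten.size))  # aten.sym_size
```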
