Skip to content

Commit e488ad1

Browse files
jfix71Wei Wei
authored andcommitted
[acc_ops] Support Tensor.view torch.Shape input correctly (#22)
Summary: Pull Request resolved: pytorch/fx2trt#22 Alternative to D34911848 Reviewed By: frank-wei Differential Revision: D34930874 fbshipit-source-id: 0da1081cac09feb188d2624ba04c86d5f3ffef6c
1 parent fc4177e commit e488ad1

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,7 @@ def test_view(self):
15971597
"""
15981598

15991599
self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view(1, -1))
1600+
self._make_acc_op_function_test(acc_ops.reshape, lambda x: x.view([1, -1]))
16001601

16011602
def test_narrow(self):
16021603
"""

tracer/acc_tracer/acc_normalizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def get_normalized_kwargs(
286286
orig_kwargs_names, new_kwarg_name, is_optional = replacement_tuple
287287

288288
# Check if this is a varg and if so break/process the rest outside the loop.
289-
if len(orig_kwargs_names) == 1 and orig_kwargs_names[0] == "*":
289+
if "*" in orig_kwargs_names:
290+
assert len(orig_kwargs_names) == 1
290291
assert i == len(arg_replacement_tuples) - 1
291292
final_arg_is_varg = True
292293
break

tracer/acc_tracer/acc_ops.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,14 +1668,6 @@ def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
16681668
],
16691669
kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
16701670
)
1671-
@register_acc_op_mapping(
1672-
op_and_target=("call_method", "view"),
1673-
arg_replacement_tuples=[
1674-
("input", "input"),
1675-
("*", "shape"),
1676-
],
1677-
kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
1678-
)
16791671
@register_acc_op
16801672
def reshape(*, input, acc_out_ty=None):
16811673
assert acc_out_ty is not None
@@ -1689,11 +1681,18 @@ def reshape(*, input, acc_out_ty=None):
16891681
("*", "shape"),
16901682
],
16911683
)
1684+
@register_custom_acc_mapper_fn(
1685+
op_and_target=("call_method", "view"),
1686+
arg_replacement_tuples=[
1687+
("input", "input"),
1688+
("*", "shape"),
1689+
],
1690+
)
16921691
def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
16931692
"""
1694-
For Tensor.reshape node, args could be (input, 1, 2, 3) or (input, (1, 2, 3)).
1695-
Here we do some special handling with the `shape` arg in order to map it to
1696-
acc_ops.reshape. It also handles the case when `shape` is a list instead of
1693+
For Tensor.reshape and Tensor.view nodes, args could be (input, 1, 2, 3) or (input,
1694+
(1, 2, 3)). Here we do some special handling with the `shape` arg in order to map
1695+
it to acc_ops.reshape. It also handles the case when `shape` is a list instead of
16971696
tuple.
16981697
"""
16991698
input_node = node.kwargs["input"]

0 commit comments

Comments
 (0)