Skip to content

Commit c7fc9f4

Browse files
alexbeloiWei Wei
authored andcommitted
[fx] add acc_ops.expand(_as) and shape inference (#31)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/31 1. acc_ops a. 'call_method' `torch.Tensor.expand` will map to `acc_op.expand` b. 'call_method' `torch.Tensor.expand_as` will map to `acc_op.expand` 2. shape inference will run for `acc_op.expand` 3. acc_normalizer will map `acc_op.expand` to `acc_op.tile` using the required shape information Reviewed By: 842974287 Differential Revision: D34565863 fbshipit-source-id: ad94694e548f5b03af9cf3ef8e7c35e236bb0745
1 parent 730645d commit c7fc9f4

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tracer/acc_tracer/acc_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,3 +2192,27 @@ def gather(*, input, dim, index, sparse_grad=False):
21922192
@register_acc_op
21932193
def index_select(*, input, dim, index):
21942194
return torch.index_select(input, dim, index)
2195+
2196+
2197+
@register_custom_acc_mapper_fn(
2198+
op_and_target=("call_method", "expand_as"),
2199+
arg_replacement_tuples=[
2200+
("input", "input"),
2201+
("other", "other"),
2202+
]
2203+
)
2204+
def expand_as_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
2205+
"""
2206+
Maps expand_as(other) to expand(other.size())
2207+
"""
2208+
with node.graph.inserting_before(node):
2209+
size_node = node.graph.call_function(
2210+
size, kwargs={"input": node.kwargs["other"]}
2211+
)
2212+
size_node.meta["type"] = torch.Size
2213+
2214+
expand_node = node.graph.call_function(
2215+
expand, kwargs={"input": node.kwargs["input"], "sizes": size_node}
2216+
)
2217+
expand_node.meta = node.meta.copy()
2218+
return expand_node

0 commit comments

Comments
 (0)