Skip to content

Commit 69b9457

Browse files
author
Wei Wei
committed
[fx2trt] matmul, softmax, expand (#28)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/28 as titled Reviewed By: wushirong Differential Revision: D34997650 fbshipit-source-id: 62ce11b4ca0605f78b9022cb1271582e049f2327
1 parent 2e7e26d commit 69b9457

File tree

4 files changed

+101
-3
lines changed

4 files changed

+101
-3
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,36 @@ def acc_ops_slice_tensor(
16541654
return layer.get_output(0)
16551655

16561656

@tensorrt_converter(acc_ops.expand)
def acc_ops_expand_tensor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """
    Convert acc_ops.expand to a TensorRT slice layer.

    Broadcasting is implemented via stride: a dimension whose input size
    differs from the requested output size gets stride 0, so the slice layer
    re-reads the same input element along that dimension.

    Raises:
        AssertionError: if the requested shape has a different rank than the
            input (TRT slice cannot change the number of dimensions).
    """
    input_t = kwargs["input"]
    # list(...) instead of .copy(): "sizes" may arrive as a tuple, which has
    # no .copy() method; we need a mutable copy either way.
    shape = list(kwargs["sizes"])

    input_val = get_trt_tensor(network, input_t, f"{name}_input")

    if network.has_implicit_batch_dimension:
        # The batch dimension is implicit in this mode, so it must not
        # appear in the slice shape.
        shape = shape[1:]

    ranks = len(input_val.shape)
    # TRT does not support expanding to a different number of dimensions.
    assert len(shape) == ranks, (
        f"expand for {name}: requested {len(shape)} dims but input has {ranks}"
    )
    # -1 in sizes means "keep the input's size for this dimension".
    shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]

    inshape = tuple(input_val.shape)
    shape = tuple(shape)
    start = tuple([0] * ranks)
    # stride == 1 if dimensions match, 0 otherwise (0-stride repeats the element)
    stride = tuple([int(i == o) for i, o in zip(inshape, shape)])
    layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
16571687
@tensorrt_converter(acc_ops.split, no_explicit_batch_dim=True)
16581688
def acc_ops_split(
16591689
network: TRTNetwork,

test/converters/acc_op/test_expand.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
# Owner(s): ["oncall: aiacc"]

import torch
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
import torch.nn as nn
from torch.testing._internal.common_fx2trt import AccTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests


class TestExpandConverter(AccTestCase):
    """Checks that Tensor.expand is traced to acc_ops.expand and lowered to TRT."""

    @parameterized.expand(
        [
            # (case name, sizes passed to expand, input tensor shape)
            ("2d_dim", (2, 3), (2, 1)),
            ("3d_dim", (2, 3, 4), (2, 1, 1)),
            ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
            # -1 in sizes keeps the corresponding input dimension unchanged
            ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
        ]
    )
    def test_expand(self, _, sizes, init_size):
        # Minimal module wrapping a single expand call so the tracer
        # produces exactly one acc_ops.expand node.
        class Expand(nn.Module):
            def forward(self, x):
                return x.expand(*sizes)

        inputs = [torch.randn(*init_size)]
        self.run_test(
            Expand(),
            inputs,
            expected_ops={acc_ops.expand},
        )

if __name__ == '__main__':
    run_tests()

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,5 +2153,6 @@ def test_all_acc_ops_registered(self):
21532153
acc_ops.rescale_quantize_per_tensor,
21542154
acc_ops.rescale_quantize_per_channel,
21552155
acc_ops.nan_to_num,
2156+
acc_ops.expand,
21562157
},
21572158
)

tracer/acc_tracer/acc_ops.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,24 @@ def contiguous(*, input):
354354

355355

356356
@register_acc_op_properties(AccOpProperty.unary)
357-
@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax))
357+
@register_acc_op_mapping(
358+
op_and_target=("call_method", "softmax"),
359+
arg_replacement_tuples=[
360+
("input", "input"),
361+
("dim", "dim"),
362+
("dtype", "dtype", this_arg_is_optional),
363+
],
364+
)
365+
@register_acc_op_mapping(
366+
op_and_target=("call_function", torch.nn.functional.softmax),
367+
arg_replacement_tuples=[
368+
("input", "input"),
369+
("dim", "dim"),
370+
("dtype", "dtype", this_arg_is_optional),
371+
],
372+
)
358373
@register_acc_op
359-
def softmax(*, input, dim, dtype):
374+
def softmax(*, input, dim, dtype=None):
360375
"""
361376
_stacklevel are ignored here.
362377
"""
@@ -471,7 +486,13 @@ def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
471486
new_node.meta = node.meta.copy()
472487
return new_node
473488

474-
489+
@register_acc_op_mapping(
490+
op_and_target=("call_function", operator.matmul),
491+
arg_replacement_tuples=[
492+
("input", "input"),
493+
("mat2", "other"),
494+
],
495+
)
475496
@register_acc_op_mapping(
476497
op_and_target=("call_function", torch.bmm),
477498
arg_replacement_tuples=[
@@ -1614,6 +1635,19 @@ def nan_to_num(*, input, nan=0.0, posinf=None, neginf=None):
16141635
return torch.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf)
16151636

16161637

1638+
@register_acc_op_properties(AccOpProperty.unary)
1639+
@register_acc_op_mapping(
1640+
op_and_target=("call_method", "expand"),
1641+
arg_replacement_tuples=[
1642+
("input", "input"),
1643+
("*", "sizes"),
1644+
],
1645+
)
1646+
@register_acc_op
1647+
def expand(*, input, sizes):
1648+
return input.expand(*sizes)
1649+
1650+
16171651
@register_acc_op_properties(AccOpProperty.unary)
16181652
@register_acc_op
16191653
def slice_tensor(*, input, dim, start, stop, step):

0 commit comments

Comments
 (0)