Skip to content

Commit 8303cd5

Browse files
committed
aten::matmul, aten::slice, aten::select converters
1 parent 4f18c0f commit 8303cd5

File tree

5 files changed

+255
-26
lines changed

5 files changed

+255
-26
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,37 @@ def aten_ops_select(
586586
"index": args[2],
587587
}
588588
return add_select(network, target, kwargs_new, name)
589+
590+
591+
@tensorrt_converter(torch.ops.aten.slice.Tensor)
592+
def aten_ops_slice(
593+
network: TRTNetwork,
594+
target: Target,
595+
args: Tuple[Argument, ...],
596+
kwargs: Dict[str, Argument],
597+
name: str,
598+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
599+
kwargs_new = {
600+
"input": args[0],
601+
"dim": args[1],
602+
"start": args[2],
603+
"stop": args[3],
604+
"step": args[4],
605+
}
606+
return add_slice(network, target, kwargs_new, name)
607+
608+
609+
@tensorrt_converter(torch.ops.aten.matmul)
610+
@tensorrt_converter(torch.ops.aten.mm.default)
611+
def aten_ops_matmul(
612+
network: TRTNetwork,
613+
target: Target,
614+
args: Tuple[Argument, ...],
615+
kwargs: Dict[str, Argument],
616+
name: str,
617+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
618+
kwargs_new = {
619+
"input": args[0],
620+
"other": args[1],
621+
}
622+
return add_matmul(network, target, kwargs_new, name)

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,11 +1165,92 @@ def add_select(network, target, kwargs, name):
11651165
output_shape = get_shape_with_dynamic_shape(
11661166
network, output_shape, input_val, target, name
11671167
)
1168-
input_shape = network.add_shape(input_val).get_output(0)
1169-
dim_value = torch.tensor(dim, dtype=torch.int32)
1170-
axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0)
1171-
layer = network.add_gather(input_shape, axis, index)
1168+
index_value = torch.tensor(index, dtype=torch.int32)
1169+
indices_tensor = network.add_constant(
1170+
index_value.shape, to_numpy(index_value)
1171+
).get_output(0)
1172+
layer = network.add_gather(input_val, indices_tensor, dim)
11721173
out = layer.get_output(0)
11731174
if len(out.shape) != 1:
11741175
layer = network.add_shuffle(out)
11751176
return layer.get_output(0)
1177+
1178+
1179+
def add_slice(network, target, kwargs, name):
1180+
input_val = kwargs["input"]
1181+
1182+
if not isinstance(input_val, TRTTensor):
1183+
raise RuntimeError(
1184+
f"slice_tensor received input {input_val} that is not part "
1185+
"of the TensorRT region!"
1186+
)
1187+
1188+
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
1189+
dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
1190+
dynamic_shape = has_dynamic_shape(input_val.shape)
1191+
if network.has_implicit_batch_dimension:
1192+
if dim == 0:
1193+
raise RuntimeError(
1194+
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
1195+
)
1196+
dim = dim - 1
1197+
else:
1198+
if dynamic_shape:
1199+
# Check whether slice target dim is dynamic shape dim
1200+
assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
1201+
1202+
start_int = cast(int, kwargs["start"])
1203+
stop_int = cast(int, kwargs["stop"])
1204+
step_int = cast(int, kwargs["step"])
1205+
start = [0] * len(input_val.shape)
1206+
start[dim] = start_int
1207+
stride = [1] * len(start)
1208+
stride[dim] = step_int
1209+
output_shape = list(input_val.shape)
1210+
output_shape[dim] = (stop_int - start_int) // step_int + 1
1211+
1212+
if dynamic_shape > 0:
1213+
output_shape = get_shape_with_dynamic_shape(
1214+
network, output_shape, input_val, target, name
1215+
)
1216+
layer = network.add_slice(
1217+
input_val,
1218+
start=start,
1219+
shape=[] if dynamic_shape else output_shape,
1220+
stride=stride,
1221+
)
1222+
if dynamic_shape:
1223+
layer.set_input(2, output_shape)
1224+
set_layer_name(layer, target, name)
1225+
return layer.get_output(0)
1226+
1227+
1228+
def add_matmul(network, target, kwargs, name):
1229+
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
1230+
other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
1231+
1232+
for i in [input_val, other_val]:
1233+
if not isinstance(i, TRTTensor):
1234+
raise RuntimeError(
1235+
f"matmul received input {i} that is not part of the TensorRT region!"
1236+
)
1237+
1238+
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
1239+
preset_diff = 0
1240+
1241+
if len(input_val.shape) == 1:
1242+
preset_diff -= 1
1243+
input_matrix_op = trt.MatrixOperation.VECTOR
1244+
1245+
if len(other_val.shape) == 1:
1246+
preset_diff += 1
1247+
other_matrix_op = trt.MatrixOperation.VECTOR
1248+
1249+
input_val, other_val = broadcast(
1250+
network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
1251+
)
1252+
layer = network.add_matrix_multiply(
1253+
input_val, input_matrix_op, other_val, other_matrix_op
1254+
)
1255+
set_layer_name(layer, target, name)
1256+
return layer.get_output(0)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
10+
class TestMatMulConverter(DispatchTestCase):
11+
def test_matmul(self):
12+
class TestModule(torch.nn.Module):
13+
def forward(self, x, y):
14+
return torch.matmul(x, y)
15+
16+
inputOne = torch.randn(2, 32)
17+
inputTwo = torch.randn(32, 2)
18+
inputs = [inputOne, inputTwo]
19+
self.run_test(
20+
TestModule(),
21+
inputs,
22+
expected_ops={torch.ops.aten.mm.default},
23+
)
24+
25+
26+
if __name__ == "__main__":
27+
run_tests()

py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
88

99

10-
class TestSelectConverter(DispatchTestCase):
10+
class TestSelectConverterOne(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13-
("select_dim_index", 2, 1),
13+
("select_dim_index", 1, 0),
1414
]
1515
)
1616
def test_select(self, _, dim, index):
@@ -21,33 +21,62 @@ def __init__(self):
2121
def forward(self, input):
2222
return torch.select(input, dim, index)
2323

24-
input = [torch.randn(1, 3, 32)]
24+
input = [torch.randn(1, 2)]
2525
self.run_test(
2626
TestModule(),
2727
input,
2828
expected_ops={torch.ops.aten.select.int},
2929
test_explicit_precision=True,
3030
)
3131

32-
# def test_select_with_dynamic_shape(self, _, dim_test, index_test):
33-
# class TestModule(torch.nn.Module):
34-
# def __init__(self, dim, index):
35-
# super().__init__()
36-
# self.dim = dim
37-
# self.index = index
38-
# def forward(self, input):
39-
# return torch.select(input, self.dim, self.index)
40-
41-
# input_spec = [
42-
# InputTensorSpec(
43-
# shape=(-1, 3, 32),
44-
# dtype=torch.float32,
45-
# shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))],
46-
# ),
47-
# ]
48-
# self.run_test_with_dynamic_shape(
49-
# TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select}
50-
# )
32+
33+
class TestSelectConverterTwo(DispatchTestCase):
34+
@parameterized.expand(
35+
[
36+
("select_dim_index", 1, 0),
37+
]
38+
)
39+
def test_select(self, _, dim, index):
40+
class TestModule(torch.nn.Module):
41+
def __init__(self):
42+
super().__init__()
43+
44+
def forward(self, input):
45+
return torch.select(input, dim, index)
46+
47+
input = [torch.randn(4, 4, 4, 4)]
48+
self.run_test(
49+
TestModule(),
50+
input,
51+
expected_ops={torch.ops.aten.select.int},
52+
test_explicit_precision=True,
53+
)
54+
55+
56+
class TestSelectConverterWithDynamicShape(DispatchTestCase):
57+
@parameterized.expand(
58+
[
59+
("select_dim_index", 1, 0),
60+
]
61+
)
62+
def test_select_with_dynamic_shape(self, _, dim, index):
63+
class TestModule(torch.nn.Module):
64+
def __init__(self):
65+
super().__init__()
66+
67+
def forward(self, input):
68+
return torch.select(input, dim, index)
69+
70+
input_spec = [
71+
InputTensorSpec(
72+
shape=(-1, 3, 3),
73+
dtype=torch.float32,
74+
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
75+
),
76+
]
77+
self.run_test_with_dynamic_shape(
78+
TestModule(), input_spec, expected_ops={torch.ops.aten.select.int}
79+
)
5180

5281

5382
if __name__ == "__main__":
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
10+
class TestSelectConverterImplicitBatch(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
("select_dim_start_stop_step", 0, 0, 7, 2),
14+
]
15+
)
16+
def test_slice(self, _, dim, start, stop, step):
17+
class TestModule(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, input):
22+
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
23+
return out
24+
25+
input = [torch.randn(10, 2, 3, 1)]
26+
self.run_test(
27+
TestModule(),
28+
input,
29+
expected_ops={torch.ops.aten.slice.Tensor},
30+
)
31+
32+
33+
class TestSelectConverterExplicitBatch(DispatchTestCase):
34+
@parameterized.expand(
35+
[
36+
("select_dim_start_stop_step", 1, 0, 7, 2),
37+
]
38+
)
39+
def test_slice(self, _, dim, start, stop, step):
40+
class TestModule(torch.nn.Module):
41+
def __init__(self):
42+
super().__init__()
43+
44+
def forward(self, input):
45+
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
46+
return out
47+
48+
input = [torch.randn(10, 10, 3, 1)]
49+
self.run_test(
50+
TestModule(),
51+
input,
52+
expected_ops={torch.ops.aten.slice.Tensor},
53+
test_explicit_precision=True,
54+
)
55+
56+
57+
if __name__ == "__main__":
58+
run_tests()

0 commit comments

Comments
 (0)