
Commit db15d27

apbose authored and gs-olive committed
Converter reorg and rsub
Rsub error fixes and linting errors fixed; Rsub test case updated to include different inputs
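For reference, torch.rsub(input, other, alpha) computes other - alpha * input, i.e. subtraction with the operand order reversed; this is the contract the new converter and tests below exercise. A minimal PyTorch check of those semantics:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0])
print(torch.rsub(a, b, alpha=2))  # tensor([ 8., 16.]), i.e. b - 2 * a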
1 parent ce3fa67 commit db15d27

File tree

3 files changed: +133 -1 lines changed


py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 15 additions & 1 deletion
@@ -24,6 +24,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
+from torch_tensorrt.fx.converters.impl.elementwise import rsub

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -452,6 +453,20 @@ def aten_ops_reshape(
     return layer.get_output(0)


+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    alpha = None
+    if "alpha" in kwargs:
+        alpha = kwargs["alpha"]
+    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
+
+
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
@@ -460,7 +475,6 @@ def aten_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
     return activation.tanh(
         network,
         target,
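One small readability note on aten_ops_rsub above: its two-line membership test for "alpha" is equivalent to a dict.get lookup, e.g.:

alpha = kwargs.get("alpha")  # None when the node carries no alpha keyword argument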

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 38 additions & 0 deletions
@@ -141,6 +141,44 @@ def rsqrt(
     return output


+def rsub(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    other: TRTTensor,
+) -> TRTTensor:
+    # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it
+    trunc_div_value = trunc_div(
+        network,
+        target,
+        source_ir,
+        name + "_trunc_div",
+        input,
+        other,
+    )
+    prod_value = convert_binary_elementwise(
+        network,
+        target,
+        source_ir,
+        name + "_prod",
+        trt.ElementWiseOperation.PROD,
+        trunc_div_value,
+        other,
+    )
+    sub_value = convert_binary_elementwise(
+        network,
+        target,
+        SourceIR.ACC,
+        name + "_sub",
+        trt.ElementWiseOperation.SUB,
+        input,
+        prod_value,
+    )
+    return sub_value
+
+
 def fmod(
     network: TRTNetwork,
     target: Target,
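Four details in this new rsub helper look like copy-paste leftovers from the neighboring fmod implementation: the NOTE comment still refers to fmod; the trunc_div, PROD, SUB sequence computes input - trunc(input / other) * other, which is fmod rather than a reversed subtraction; the signature accepts no alpha even though aten_ops_rsub passes one as a seventh positional argument; and the final SUB hardcodes SourceIR.ACC where the other calls forward source_ir. Since torch.rsub(input, other, alpha) is defined as other - alpha * input, a corrected sketch might look as follows (an assumed fix, not part of this commit; it reuses the module's existing helpers and assumes convert_binary_elementwise accepts a scalar alpha operand, as other helpers in this file do):

def rsub(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    other: TRTTensor,
    alpha=None,
) -> TRTTensor:
    # torch.rsub(input, other, alpha) == other - alpha * input
    scaled_input = input
    if alpha is not None:
        # scale the subtrahend by alpha first
        scaled_input = convert_binary_elementwise(
            network,
            target,
            source_ir,
            name + "_alpha",
            trt.ElementWiseOperation.PROD,
            input,
            alpha,
        )
    # reversed subtraction: other - alpha * input
    return convert_binary_elementwise(
        network,
        target,
        source_ir,
        name + "_sub",
        trt.ElementWiseOperation.SUB,
        other,
        scaled_input,
    )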
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.dynamo import compile
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSubConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsub_same(self, _, x, alpha):
+        class rsub(nn.Module):
+            def forward(self, input):
+                return torch.rsub(input, input, alpha=alpha)
+
+        inputs = [torch.randn(x)]
+        self.run_test(
+            rsub(),
+            inputs,
+            expected_ops={torch.ops.aten.rsub.Tensor},
+        )
+
+    # @parameterized.expand(
+    #     [
+    #         ("2d_dim_alpha", (2, 1), 2),
+    #         ("3d_dim_alpha", (2, 1, 2), 2),
+    #     ]
+    # )
+    # def test_rsub_diff(self, _, x, alpha):
+    #     class rsub(nn.Module):
+    #         def forward(self, inputOne, inputTwo):
+    #             return torch.rsub(inputOne, inputTwo, alpha=alpha)
+
+    #     inputOne = [torch.randn(x)]
+    #     inputTwo = [torch.randn(x)]
+    #     inputs = (inputOne, inputTwo)
+    #     self.run_test(
+    #         rsub(),
+    #         inputs,
+    #         expected_ops={torch.ops.aten.rsub.Tensor},
+    #     )
+
+class TestRSubDiff(TestCase):
+    def test_rsub_diff(self):
+        class rsub_diff(nn.Module):
+            def forward(self, inputOne, inputTwo):
+                return torch.rsub(inputOne, inputTwo, alpha=2)
+        inputOne = torch.randn(2, 1).cuda()
+        inputTwo = torch.randn(2, 1).cuda()
+        alpha = 2
+        inputs = [inputOne, inputTwo]
+        fx_graph = torch.fx.symbolic_trace(rsub_diff())
+        torch._dynamo.reset()
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            5,
+            "Rsub TRT outputs don't match with the original model.",
+        )
+
+if __name__ == "__main__":
+    run_tests()
+# Test two
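Two notes on the test module above: because it ends with the standard run_tests() guard, it can be executed directly as a script; and assertAlmostEqual(max_diff, 0, 5) compares to five decimal places, so the maximum elementwise deviation between the TRT and eager outputs must stay below roughly 5e-6 for the test to pass.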
