Skip to content

Commit 2de2db6

Browse files
apbosegs-olive
authored andcommitted
Converter reorg and adding rsqrt converter
1 parent 9611d67 commit 2de2db6

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2525
from torch_tensorrt.fx.converters.impl import activation
2626
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
27+
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
2728

2829
_LOGGER: logging.Logger = logging.getLogger(__name__)
2930

@@ -300,6 +301,42 @@ def aten_ops_relu(
300301
)
301302

302303

304+
@tensorrt_converter(torch.ops.aten.relu.default)
305+
def aten_ops_relu(
306+
network: TRTNetwork,
307+
target: Target,
308+
args: Tuple[Argument, ...],
309+
kwargs: Dict[str, Argument],
310+
name: str,
311+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
312+
313+
return activation.relu(
314+
network,
315+
target,
316+
SourceIR.ATEN,
317+
name,
318+
args[0],
319+
)
320+
321+
322+
@tensorrt_converter(torch.ops.aten.rsqrt.default)
323+
def aten_ops_rsqrt(
324+
network: TRTNetwork,
325+
target: Target,
326+
args: Tuple[Argument, ...],
327+
kwargs: Dict[str, Argument],
328+
name: str,
329+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
330+
331+
return rsqrt(
332+
network,
333+
target,
334+
SourceIR.ATEN,
335+
name,
336+
args[0],
337+
)
338+
339+
303340
@tensorrt_converter(torch.ops.aten.sub.Tensor)
304341
def aten_ops_sub(
305342
network: TRTNetwork,

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,33 @@ def trunc_div(
109109
)
110110

111111
return output
112+
113+
114+
def rsqrt(
115+
network: TRTNetwork,
116+
target: Target,
117+
source_ir: Optional[SourceIR],
118+
name: str,
119+
input: TRTTensor,
120+
other: TRTTensor,
121+
) -> TRTTensor:
122+
123+
sqrt_trt_output = convert_unary(
124+
network,
125+
target,
126+
source_ir,
127+
f"{name}"_sqrt,
128+
trt.UnaryOperation.SQRT,
129+
input,
130+
)
131+
132+
output = convert_binary_elementwise(
133+
network,
134+
1,
135+
sqrt_trt_output,
136+
trt.ElementWiseOperation.DIV,
137+
target,
138+
f"{name}_outpur",
139+
)
140+
141+
return output
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
6+
7+
8+
class TestRSubConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
("2d_dim_alpha", (2, 1), 2),
12+
("3d_dim_alpha", (2, 1, 2), 2),
13+
]
14+
)
15+
def test_rsqrt(self, _, x, alpha):
16+
class rsqrt(nn.Module):
17+
def forward(self, input):
18+
return torch.rsqrt(input)
19+
20+
inputs = [torch.randn(x) + 1]
21+
self.run_test(
22+
rsqrt(),
23+
inputs,
24+
expected_ops={torch.ops.aten.rsqrt.default},
25+
)
26+
27+
28+
if __name__ == "__main__":
29+
run_tests()

0 commit comments

Comments
 (0)