
Commit 37d1168

apbose authored and gs-olive committed
Converter reorg and gelu
Linting error
1 parent 1ba6d13 commit 37d1168


5 files changed: +140 -27 lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 7 additions & 26 deletions
@@ -3401,33 +3401,14 @@ def acc_ops_gelu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    approximate = kwargs["approximate"]
-    if approximate != "none":
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
-
-    plugin_name = "CustomGeluPluginDynamic"
-    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    return activation.gelu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["approximate"],
     )
-    field_collection = TRTPluginFieldCollection([type_id])
-    plugin_version = "1"
-
-    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
-
-    layer = network.add_plugin_v2([input_val], plugin)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)


 @tensorrt_converter(acc_ops.chunk)
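
Side note (not part of the commit): the `approximate` kwarg forwarded above mirrors PyTorch's GELU signature, where `approximate="none"` selects the exact erf-based GELU and `approximate="tanh"` selects the fast approximation that this converter rejects. A minimal plain-PyTorch sketch of the two modes:

    import torch

    x = torch.randn(4, 8)
    # exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))) -- the only mode the converter accepts
    y_exact = torch.nn.functional.gelu(x, approximate="none")
    # tanh approximation ("fast gelu") -- the converter raises a RuntimeError for this
    y_tanh = torch.nn.functional.gelu(x, approximate="tanh")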

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
@@ -242,6 +242,23 @@ def aten_ops_hardtanh(
     )


+@tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return activation.gelu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
     network: TRTNetwork,
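
Side note (not part of the commit): `aten_ops_gelu` is registered for `torch.ops.aten.gelu.default`, which is the overload a functional GELU lowers to when a model is traced down to ATen. A rough sketch, assuming `make_fx` is available for the trace:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.nn.functional.gelu(x)

    # tracing to the ATen level should surface the overload the converter registers for
    gm = make_fx(f)(torch.randn(2, 3))
    print(any(node.target == torch.ops.aten.gelu.default for node in gm.graph.nodes))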

py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 50 additions & 1 deletion
@@ -2,6 +2,7 @@
 import operator
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+import math

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -11,11 +12,15 @@

 from torch_tensorrt.fx.converters.converter_utils import mark_as_int8_layer
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.converters.converter_utils import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    get_trt_plugin,
+)

 from torch_tensorrt.fx.types import (
     TRTNetwork,
     TRTTensor,
+    TRTPluginFieldCollection,
 )


@@ -250,3 +255,47 @@ def elu_dyn_range_fn(dyn_range):
         input_val,
         dyn_range_fn=elu_dyn_range_fn,
     )
+
+
+def gelu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any] = None,
+):
+    approximate = alpha
+    if approximate is not None:
+        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"GELU received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support implicit batch dimension"
+        )
+    plugin_name = "CustomGeluPluginDynamic"
+    # type_id 0 for float32, 1 for float16
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    field_collection = TRTPluginFieldCollection([type_id])
+    plugin_version = "1"
+
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
+
+    layer = network.add_plugin_v2([input_val], plugin)
+
+    def gelu_dyn_range_fn(dyn_range):
+        return (
+            dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))
+        ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))))
+
+    if input_val.dynamic_range is not None:
+        dyn_range = gelu_dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
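
For reference (not part of the commit): the exact GELU that `CustomGeluPluginDynamic` computes is 0.5 * x * (1 + erf(x / sqrt(2))), and `gelu_dyn_range_fn` maps each INT8 dynamic-range bound through that same formula (as written, the second bound reuses `dyn_range[0]` inside the erf, which presumably was meant to be `dyn_range[1]`). A scalar sketch of the mapping:

    import math

    def gelu_scalar(x: float) -> float:
        # exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    # per-bound dynamic-range mapping, applying the formula to each bound
    def gelu_dyn_range(dyn_range):
        low, high = dyn_range
        return gelu_scalar(low), gelu_scalar(high)

    print(gelu_scalar(1.0))  # ~0.8413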

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 15 additions & 0 deletions
@@ -68,6 +68,21 @@ def tanh(network, submod, args, kwargs, layer_name):
     )


+@tensorrt_converter(torch.nn.functional.gelu)
+@tensorrt_converter(torch.nn.modules.activation.GELU)
+def gelu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.gelu(
+        network=network,
+        target="torch.nn.functional.gelu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
+
+
 @tensorrt_converter(torch.nn.functional.leaky_relu)
 @tensorrt_converter(torch.nn.modules.activation.LeakyReLU)
 def leaky_relu(network, submod, args, kwargs, layer_name):
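
Side note (not part of the commit): the two decorators above cover the two ways GELU appears in an nn-level graph, the functional call and the module, which compute the same result, so a single shared converter handles both:

    import torch
    import torch.nn as nn

    x = torch.randn(4, 8)
    # functional form and module form are the same op, hence one converter for both
    y_functional = nn.functional.gelu(x)
    y_module = nn.GELU()(x)
    assert torch.allclose(y_functional, y_module)
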
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestGeLUConverter(DispatchTestCase):
+    def test_gelu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.gelu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default})
+
+    def test_gelu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.gelu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default}
+        )
+
+    def test_gelu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.gelu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
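
Side note (not part of the commit): in these specs, each `shape_ranges` entry follows TensorRT's optimization-profile convention of (min, opt, max) shapes for the dynamic (-1) dimensions. A minimal sketch reusing the 3-D spec from the first dynamic-shape test (treating the middle shape as the profile's preferred "opt" shape is an assumption):

    import torch
    from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec

    # allows any shape between (1, 1, 1) and (3, 3, 3); (1, 2, 3) is the opt shape
    spec = InputTensorSpec(
        shape=(-1, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
    )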
