
Commit b23c980

Reorg for converters in (FX Converter Refactor [1/N]) (#1867)
Signed-off-by: Naren Dasan <[email protected]>
1 parent c60070b
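
This first step of the FX converter refactor moves the shared activation-conversion logic into a new torch_tensorrt.fx.converters.impl package. The acc and aten converter entry points below now delegate to activation.relu and activation.convert_activation in place of the removed converter_utils.add_activation_layer, and a new SourceIR enum is threaded through set_layer_name so that TensorRT layer names record which frontend IR (nn, acc, aten, prim) produced each op.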

File tree: 7 files changed, +217 −104 lines


py/torch_tensorrt/fx/converters/acc_ops_converters.py (59 additions, 17 deletions)

@@ -26,6 +26,7 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
+from torch_tensorrt.fx.converters.impl import activation
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -1007,9 +1008,14 @@ def acc_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.RELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+    )
 
 
 @tensorrt_converter(acc_ops.leaky_relu)
@@ -1023,8 +1029,14 @@ def acc_ops_leaky_relu(
     input_val = kwargs["input"]
     negative_slope = kwargs["negative_slope"]
     operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+        alpha=negative_slope,
     )
 
 
@@ -1039,7 +1051,9 @@ def acc_ops_elu(
     input_val = kwargs["input"]
     alpha = kwargs["alpha"]
     operation_type = trt.ActivationType.ELU
-    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+    return activation.convert_activation(
+        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+    )
 
 
 @tensorrt_converter(acc_ops.selu)
@@ -1052,7 +1066,14 @@ def acc_ops_selu(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.softsign)
@@ -1065,7 +1086,14 @@ def acc_ops_softsign(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SOFTSIGN
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.sin)
@@ -1143,7 +1171,14 @@ def acc_ops_tanh(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.TANH
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.asin)
@@ -3140,12 +3175,13 @@ def acc_ops_hard_sigmoid(
         "of the TensorRT region!"
     )
 
-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.HARD_SIGMOID,
         target,
+        SourceIR.ACC,
         name,
+        trt.ActivationType.HARD_SIGMOID,
+        input_val,
         alpha=1 / 6,
         beta=0.5,
     )
@@ -3167,8 +3203,13 @@ def acc_ops_sigmoid(
         "of the TensorRT region!"
     )
 
-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        trt.ActivationType.SIGMOID,
+        input_val,
     )
 
 
@@ -3560,12 +3601,13 @@ def acc_ops_hardtanh(
         "of the TensorRT region!"
     )
 
-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.CLIP,
         target,
+        SourceIR.ACC,
         name,
+        trt.ActivationType.CLIP,
+        input_val,
         alpha=kwargs["min_val"],
         beta=kwargs["max_val"],
     )
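
The delegation target above, py/torch_tensorrt/fx/converters/impl/activation.py, is not displayed in this diff view. As a reading aid, here is a minimal sketch of what the call sites imply, assuming convert_activation carries over the body of the removed add_activation_layer and relu is a thin wrapper around it; the signatures, the import of TRTNetwork/TRTTensor from torch_tensorrt.fx.types, and all other names are inferred from the calls, not confirmed by this page:

# Hypothetical reconstruction of impl/activation.py, inferred from call sites.
from typing import Any, Optional

import tensorrt as trt
from torch.fx.node import Target

from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def convert_activation(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    operation_type: trt.ActivationType,
    input_val: TRTTensor,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
) -> TRTTensor:
    """Add a TensorRT Activation layer; mirrors the removed add_activation_layer."""
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha
    if beta is not None:
        layer.beta = beta
    # Pass source_ir through so the layer name records the originating IR.
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def relu(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    # Convenience wrapper matching the activation.relu(...) call sites above.
    return convert_activation(
        network, target, source_ir, name, trt.ActivationType.RELU, input_val
    )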

py/torch_tensorrt/fx/converters/activation.py (0 additions, 39 deletions)

@@ -9,45 +9,6 @@
 from .converter_utils import mark_as_int8_layer
 
 
-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
-
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-
-    return layer.get_output(0)
-
-
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
-
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
-    )
-
-
 @tensorrt_converter(torch.nn.modules.activation.Sigmoid)
 def sigmoid(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
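
Note that the torch.nn ReLU converter deleted here receives no replacement in this file; the SourceIR.NN variant introduced in converter_utils.py below suggests the nn path is also meant to flow through the shared implementation in a later step of the refactor.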

py/torch_tensorrt/fx/converters/aten_ops_converters.py (9 additions, 4 deletions)

@@ -22,6 +22,7 @@
 
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from torch_tensorrt.fx.converters.impl import activation
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -290,10 +291,14 @@ def aten_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
 
 
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
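
With this change the aten path no longer repacks its positional args into acc-style kwargs and detours through acc_ops_converters.acc_ops_relu; both frontends call the shared activation.relu directly and differ only in the SourceIR tag they pass.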

py/torch_tensorrt/fx/converters/converter_utils.py (33 additions, 44 deletions)

@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+from enum import Enum, auto
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
 from ..utils import torch_dtype_from_trt
 
 
+class SourceIR(Enum):
+    NN = auto()
+    ACC = auto()
+    ATEN = auto()
+    PRIM = auto()
+    UNKNOWN = auto()
+
+    def __str__(self):
+        if self == SourceIR.NN:
+            return "nn"
+        elif self == SourceIR.ACC:
+            return "acc"
+        elif self == SourceIR.ATEN:
+            return "aten"
+        elif self == SourceIR.PRIM:
+            return "prim"
+        else:
+            return "unknown_ir"
+
+
 def get_trt_plugin(
     plugin_name: str,
     field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim
 
 
-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(
+    layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
+) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
 
@@ -86,8 +109,16 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
         target (Target): A fx node.target. For call_function node, it's the function that
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
+        source_ir: (Optional[SourceIR]): The IR producing the op.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
+
+    source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+
+    target_name = (
+        f"{source_ir}_ops.{target}"
+        if isinstance(target, str)
+        else f"{source_ir}_ops.{target.__name__}"
+    )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
 
 
@@ -560,48 +591,6 @@ def add_unary_layer(
     return layer.get_output(0)
 
 
-def add_activation_layer(
-    network: TRTNetwork,
-    input_val: TRTTensor,
-    operation_type: trt.ActivationType,
-    target: Target,
-    name: str,
-    alpha: Optional[Any] = None,
-    beta: Optional[Any] = None,
-) -> TRTTensor:
-    """
-    Add a TensorRT Activation layer to `network`.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): Input to the activation op.
-            Must be a TensorRT tensor.
-        op_type (trt.ElementWiseOperation): Type of the TensorRT activation
-            operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-        alpha (Optional[Any]): If not None, we will use it to set the alpha
-            attribute of the created TensorRT activation layer.
-        beta (Optional[Any]): If not None, we will use it to set the beta
-            attribute of the created TensorRT activation layer.
-
-    Returns:
-        The output of TensorRT Activation layer.
-    """
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{operation_type} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    layer = network.add_activation(input_val, operation_type)
-    if alpha is not None:
-        layer.alpha = alpha
-    if beta is not None:
-        layer.beta = beta
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
-
 def add_reduce_layer(
     network: TRTNetwork,
     target: Target,
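
To illustrate the naming scheme introduced by the new set_layer_name above, a hypothetical call (the layer and node names are invented for the example):

    # An activation layer produced by the acc frontend for an FX node "relu_1":
    set_layer_name(layer, "relu", "relu_1", SourceIR.ACC)
    assert layer.name == "[ACTIVATION]-[acc_ops.relu]-[relu_1]"

Omitting source_ir falls back to SourceIR.UNKNOWN, giving an "unknown_ir_ops" prefix; previously, string targets were used verbatim and non-string targets were hard-coded to the "acc_ops." prefix.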

py/torch_tensorrt/fx/converters/impl/__init__.py (whitespace-only changes)
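
The initializer itself carries no code; it exists so that torch_tensorrt.fx.converters.impl is importable as a package and can host the activation module pulled in by the new imports above.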
