Skip to content

Commit 737b667

Browse files
committed
refactor: Reorganizing to reduce code duplication and separating the TRT implementation, example changes with ReLU
Signed-off-by: Naren Dasan <[email protected]>
1 parent c5cc6e3 commit 737b667

File tree

6 files changed

+147
-48
lines changed

6 files changed

+147
-48
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
trt_transposed_matmul,
2727
)
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29+
from torch_tensorrt.fx.converters.impl import activation
2930

3031
_LOGGER: logging.Logger = logging.getLogger(__name__)
3132

@@ -1004,9 +1005,8 @@ def acc_ops_relu(
10041005
kwargs: Dict[str, Argument],
10051006
name: str,
10061007
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1007-
input_val = kwargs["input"]
1008-
operation_type = trt.ActivationType.RELU
1009-
return add_activation_layer(network, input_val, operation_type, target, name)
1008+
1009+
return activation.convert_relu(network, target, kwargs, name, SourceIR.ACC)
10101010

10111011

10121012
@tensorrt_converter(acc_ops.leaky_relu)

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .converter_utils import * # noqa: F403
2424
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
25+
from torch_tensorrt.fx.converters.impl import activation
2526

2627
_LOGGER: logging.Logger = logging.getLogger(__name__)
2728

@@ -293,7 +294,9 @@ def aten_ops_relu(
293294
kwargs_new = {
294295
"input": args[0],
295296
}
296-
return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
297+
return activation.convert_relu(
298+
network, target, kwargs_new, name, source_ir=SourceIR.ATEN
299+
)
297300

298301

299302
@tensorrt_converter(torch.ops.aten.sub.Tensor)

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 29 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
44

5+
from enum import Enum, auto
56
import numpy as np
67

78
# @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
2223
from ..utils import torch_dtype_from_trt
2324

2425

26+
class SourceIR(Enum):
    """
    Identifies which intermediate representation a converted op came from.

    ``str(member)`` yields the short lowercase tag for that IR ("nn", "acc",
    "aten", "prim"); ``UNKNOWN`` is rendered as "unknown_ir".
    """

    NN = auto()
    ACC = auto()
    ATEN = auto()
    PRIM = auto()
    UNKNOWN = auto()

    def __str__(self) -> str:
        # The table is built inside the method on purpose: a plain dict
        # assigned in an Enum class body would itself be turned into a member.
        tags = {
            SourceIR.NN: "nn",
            SourceIR.ACC: "acc",
            SourceIR.ATEN: "aten",
            SourceIR.PRIM: "prim",
        }
        # Anything without an explicit tag (i.e. UNKNOWN) uses the fallback.
        return tags.get(self, "unknown_ir")
44+
45+
2546
def get_trt_plugin(
2647
plugin_name: str,
2748
field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
7798
return dim
7899

79100

80-
def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
101+
def set_layer_name(
102+
layer: TRTLayer, target: Target, name: str, source_ir: SourceIR = SourceIR.UNKNOWN
103+
) -> None:
81104
"""
82105
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
83106
@@ -87,7 +110,11 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
87110
the node represents.
88111
name (str): Consists of fx node.name with optional suffix.
89112
"""
90-
target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
113+
target_name = (
114+
f"{source_ir}_ops.{target}"
115+
if isinstance(target, str)
116+
else f"{source_ir}_ops.{target.__name__}"
117+
)
91118
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
92119

93120

@@ -560,48 +587,6 @@ def add_unary_layer(
560587
return layer.get_output(0)
561588

562589

563-
def add_activation_layer(
564-
network: TRTNetwork,
565-
input_val: TRTTensor,
566-
operation_type: trt.ActivationType,
567-
target: Target,
568-
name: str,
569-
alpha: Optional[Any] = None,
570-
beta: Optional[Any] = None,
571-
) -> TRTTensor:
572-
"""
573-
Add a TensorRT Activation layer to `network`.
574-
575-
Args:
576-
network (TRTNetwork): TensorRT network object.
577-
input_val (TRTTensor): Input to the activation op.
578-
Must be a TensorRT tensor.
579-
op_type (trt.ElementWiseOperation): Type of the TensorRT activation
580-
operation.
581-
target (Target): Target of fx node.
582-
name (str): The name we want to assign to the created TensorRT layer.
583-
alpha (Optional[Any]): If not None, we will use it to set the alpha
584-
attribute of the created TensorRT activation layer.
585-
beta (Optional[Any]): If not None, we will use it to set the beta
586-
attribute of the created TensorRT activation layer.
587-
588-
Returns:
589-
The output of TensorRT Activation layer.
590-
"""
591-
if not isinstance(input_val, TRTTensor):
592-
raise RuntimeError(
593-
f"{operation_type} received input {input_val} that is not part "
594-
"of the TensorRT region!"
595-
)
596-
layer = network.add_activation(input_val, operation_type)
597-
if alpha is not None:
598-
layer.alpha = alpha
599-
if beta is not None:
600-
layer.beta = beta
601-
set_layer_name(layer, target, name)
602-
return layer.get_output(0)
603-
604-
605590
def add_reduce_layer(
606591
network: TRTNetwork,
607592
target: Target,

py/torch_tensorrt/fx/converters/impl/__init__.py

Whitespace-only changes.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import operator
3+
import warnings
4+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
5+
6+
# @manual=//deeplearning/trt/python:py_tensorrt
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Argument, Target
10+
11+
12+
from torch_tensorrt.fx.converters.converter_utils import mark_as_int8_layer
13+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
14+
from torch_tensorrt.fx.converters.converter_utils import SourceIR
15+
16+
from torch_tensorrt.fx.types import (
17+
TRTNetwork,
18+
TRTTensor,
19+
)
20+
21+
22+
def convert_activation(
    network: TRTNetwork,
    input_val: TRTTensor,
    operation_type: trt.ActivationType,
    target: Target,
    name: str,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
    dyn_range_fn: Optional[Callable[[Tuple[float, float]], Any]] = None,
    source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
    """
    Add a TensorRT Activation layer to `network`.

    Args:
        network (TRTNetwork): TensorRT network object.
        input_val (TRTTensor): Input to the activation op.
            Must be a TensorRT tensor.
        operation_type (trt.ActivationType): Type of the TensorRT activation
            operation.
        target (Target): Target of fx node.
        name (str): The name we want to assign to the created TensorRT layer.
        alpha (Optional[Any]): If not None, we will use it to set the alpha
            attribute of the created TensorRT activation layer.
        beta (Optional[Any]): If not None, we will use it to set the beta
            attribute of the created TensorRT activation layer.
        dyn_range_fn (Optional[Callable[[Tuple[float, float]], Any]]): A
            function which takes the dynamic range of the input TensorRT
            tensor (passed as a single argument) and returns the output
            dynamic range. If None, no dynamic range is propagated to the
            new layer.
        source_ir (SourceIR): IR the op originated from; forwarded to
            `set_layer_name` when naming the layer.

    Returns:
        The output of TensorRT Activation layer.

    Raises:
        RuntimeError: If `input_val` is not a TensorRT tensor.
    """
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha
    if beta is not None:
        layer.beta = beta
    set_layer_name(layer, target, name, source_ir)

    # dyn_range_fn is optional: only derive an output range when the caller
    # supplied a function AND the input actually carries a dynamic range.
    # Without the first check, a None dyn_range_fn would be called and raise
    # TypeError for any input that has a dynamic range set.
    if dyn_range_fn is not None and input_val.dynamic_range is not None:
        dyn_range = dyn_range_fn(input_val.dynamic_range)
        mark_as_int8_layer(layer, dyn_range)
    return layer.get_output(0)
70+
71+
72+
def convert_relu(
    network: TRTNetwork,
    target: Target,
    kwargs: Dict[str, Any],
    name: str,
    source_ir: SourceIR = SourceIR.UNKNOWN,
) -> TRTTensor:
    """
    Add a TensorRT ReLU activation layer to `network`.

    Args:
        network (TRTNetwork): TensorRT network object.
        target (Target): Target of fx node.
        kwargs (Dict[str, Any]): Normalized node arguments; the activation
            input is read from ``kwargs["input"]``.
        name (str): The name we want to assign to the created TensorRT layer.
        source_ir (SourceIR): IR the op originated from; used when naming
            the layer.

    Returns:
        The output of the TensorRT Activation layer.
    """
    input_val = kwargs["input"]
    operation_type = trt.ActivationType.RELU

    def relu_dyn_range_fn(dyn_range):
        # ReLU clamps negatives to zero, so both ends of the range are
        # clamped at zero as well.
        return max(0, dyn_range[0]), max(0, dyn_range[1])

    # These two must be passed by keyword: convert_activation's positional
    # slots after `name` are `alpha` and `beta`, so passing them positionally
    # would set the layer's alpha/beta to a function and a SourceIR.
    return convert_activation(
        network,
        input_val,
        operation_type,
        target,
        name,
        dyn_range_fn=relu_dyn_range_fn,
        source_ir=source_ir,
    )
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
3+
# @manual=//deeplearning/trt/python:py_tensorrt
4+
import tensorrt as trt
5+
import torch
6+
7+
from torch_tensorrt.fx.converter_registry import tensorrt_converter
8+
from torch_tensorrt.fx.converters.impl import activation
9+
from torch_tensorrt.fx.converters.converter_utils import SourceIR
10+
11+
12+
@tensorrt_converter(torch.nn.functional.relu)
@tensorrt_converter(torch.nn.modules.activation.ReLU)
def relu(network, submod, args, kwargs, layer_name):
    """
    Converter registered for both `torch.nn.functional.relu` and the
    `torch.nn.modules.activation.ReLU` module.

    Delegates to `activation.convert_relu`, tagging the layer as coming from
    the NN source IR. `submod` is accepted as part of the converter signature
    but is not used here; the target is always reported as
    "torch.nn.functional.relu".
    """
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0

    relu_target = "torch.nn.functional.relu"
    return activation.convert_relu(
        network, relu_target, kwargs, layer_name, SourceIR.NN
    )

0 commit comments

Comments
 (0)