
Commit 3f3a925

committed
aten converter - matmul, tanh, gelu, slice, select
1 parent 840de6b commit 3f3a925

10 files changed: 400 additions & 124 deletions


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 6 additions & 98 deletions
@@ -1980,53 +1980,8 @@ def acc_ops_slice_tensor(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"slice_tensor received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
-    dynamic_shape = has_dynamic_shape(input_val.shape)
-    if network.has_implicit_batch_dimension:
-        if dim == 0:
-            raise RuntimeError(
-                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
-            )
-        dim = dim - 1
-    else:
-        if dynamic_shape:
-            # Check whether slice target dim is dynamic shape dim
-            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-
-    start_int = cast(int, kwargs["start"])
-    stop_int = cast(int, kwargs["stop"])
-    step_int = cast(int, kwargs["step"])
-    start = [0] * len(input_val.shape)
-    start[dim] = start_int
-    stride = [1] * len(start)
-    stride[dim] = step_int
-    output_shape = list(input_val.shape)
-    output_shape[dim] = (stop_int - start_int) // step_int
-
-    if dynamic_shape > 0:
-        output_shape = get_shape_with_dynamic_shape(
-            network, output_shape, input_val, target, name
-        )
-    layer = network.add_slice(
-        input_val,
-        start=start,
-        shape=[] if dynamic_shape else output_shape,
-        stride=stride,
-    )
-    if dynamic_shape:
-        layer.set_input(2, output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
+    return operator.add_slice(network, target, kwargs, name)
+
 
 @tensorrt_converter(acc_ops.expand)
 def acc_ops_expand_tensor(
@@ -2036,29 +1991,8 @@ def acc_ops_expand_tensor(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_t = kwargs["input"]
-    shape = list(kwargs["sizes"])
-
-    input_val = get_trt_tensor(network, input_t, f"{name}_input")
-
-    if network.has_implicit_batch_dimension:
-        shape = shape[1:]
-
-    ranks = len(input_val.shape)
-    # TRT does not support different dimension size
-    assert len(shape) == ranks
-    shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
-
-    inshape = tuple(input_val.shape)
-    shape = tuple(shape)
-    start = tuple([0] * ranks)
-    stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
-    )  # stride == 1 if dimensions match, 0 otherwise
-    layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return operator.add_expand(network, target, kwargs, name)
+
 
 @tensorrt_converter(acc_ops.where)
 def acc_ops_where(
@@ -2754,34 +2688,8 @@ def acc_ops_gelu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    approximate = kwargs["approximate"]
-    if approximate != "none":
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
-
-    plugin_name = "CustomGeluPluginDynamic"
-    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
-    field_collection = TRTPluginFieldCollection([type_id])
-    plugin_version = "1"
-
-    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
-
-    layer = network.add_plugin_v2([input_val], plugin)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
+    return activation.add_gelu(network, target, kwargs, name)
+
 
 @tensorrt_converter(acc_ops.chunk)
 def acc_ops_chunk(

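The three converters above now just forward their kwargs to shared helpers (operator.add_slice, operator.add_expand, activation.add_gelu), so the same lowering code can serve both the acc and aten converter front ends. The operator module itself is not shown in this excerpt; the sketch below is an assumed reconstruction of its add_slice helper, based on the acc_ops body deleted above and omitting the implicit-batch and dynamic-shape branches.

def add_slice(network, target, kwargs, name):
    # Assumed reconstruction -- mirrors the deleted acc_ops_slice_tensor logic
    # for static shapes and explicit batch only.
    input_val = kwargs["input"]
    dim = kwargs["dim"]
    start = [0] * len(input_val.shape)
    start[dim] = kwargs["start"]
    stride = [1] * len(start)
    stride[dim] = kwargs["step"]
    output_shape = list(input_val.shape)
    output_shape[dim] = (kwargs["stop"] - kwargs["start"]) // kwargs["step"]
    layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride)
    set_layer_name(layer, target, name)
    return layer.get_output(0)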
py/torch_tensorrt/fx/converters/activation.py

Lines changed: 30 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 from .converter_utils import mark_as_int8_layer
 from .converter_utils import set_layer_name
+from .converter_utils import get_trt_plugin
 
 from ..types import (
     Shape,
@@ -106,6 +107,35 @@ def add_tanh(network, target, kwargs, name):
     operation_type = trt.ActivationType.TANH
     return add_activation_layer(network, input_val, operation_type, target, name)
 
+def add_gelu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    approximate = kwargs["approximate"]
+    if approximate != "none":
+        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"GELU received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support implicit batch dimension"
+        )
+
+    plugin_name = "CustomGeluPluginDynamic"
+    # type_id 0 for float32, 1 for float16
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    field_collection = TRTPluginFieldCollection([type_id])
+    plugin_version = "1"
+
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
+
+    layer = network.add_plugin_v2([input_val], plugin)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
 def add_hard_sigmoid(network, target, kwargs, name):
     input_val = kwargs["input"]

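As the new add_gelu makes explicit, only the exact GELU (approximate == "none") on an explicit-batch network is supported, and it is lowered to TensorRT's CustomGeluPluginDynamic plugin rather than to native layers. A hypothetical call site, with the tensor, node, and layer names being illustrative only:

# `network` is the INetworkDefinition being built, `input_trt` a TRTTensor
# produced by an earlier layer, and `node.target` the FX target being converted.
out = add_gelu(
    network,
    node.target,
    {"input": input_trt, "approximate": "none"},  # anything other than "none" raises
    "gelu_1",
)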
py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 51 additions & 25 deletions
@@ -254,19 +254,6 @@ def aten_ops_mul(
     }
     return operator.add_mul(network, target, None, kwargs_new, name)
 
-@tensorrt_converter(torch.ops.aten.mul.Tensor)
-def aten_ops_mul(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    return operator.add_mul(network, target, None, kwargs_new, name)
 
 @tensorrt_converter(torch.ops.aten.matmul.Tensor)
 def aten_ops_matmul(
@@ -283,7 +270,6 @@ def aten_ops_matmul(
     return operator.add_matmul(network, target, None, kwargs_new, name)
 
 
-
 @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
 @tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
 def aten_ops_pow(
@@ -392,10 +378,8 @@ def aten_ops_expand(
         "input": args[0],
         "sizes": args[1],
     }
-    return acc_ops_converters.acc_ops_expand_tensor(
-        network, target, None, kwargs_new, name
-    )
-
+    return operator.add_expand(network, target, kwargs_new, name)
+
 
 @tensorrt_converter(operator.floordiv)
 def aten_ops_operator_floordiv(
@@ -497,8 +481,42 @@ def aten_ops_sym_size(
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
 
+
+@tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "start": args[2],
+        "stop": args[3],
+        "step": args[4],
+    }
+    return operator.add_slice(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.select.Tensor)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "index": args[2],
+    }
+    return operator.add_select(network, target, kwargs_new, name)
+
+
 @tensorrt_converter(torch.ops.aten.leaky_relu.default)
-def aten_ops_relu(
+def aten_ops_leaky_relu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -510,8 +528,9 @@ def aten_ops_relu(
     }
     return activation.add_leaky_relu(network, target, kwargs_new, name)
 
+
 @tensorrt_converter(torch.ops.aten.elu.default)
-def aten_ops_relu(
+def aten_ops_elu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -521,7 +540,8 @@ def aten_ops_relu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.elu(network, target, kwargs_new, name)
+    return activation.add_elu(network, target, kwargs_new, name)
+
 
 @tensorrt_converter(torch.ops.aten.selu.default)
 def aten_ops_selu(
@@ -534,10 +554,11 @@ def aten_ops_selu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_selu(network, target, kwargs_new, name)
+    return activation.selu(network, target, kwargs_new, name)
 
-@tensorrt_converter(torch.ops.aten.selu.default)
-def aten_ops_selu(
+
+@tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -547,7 +568,8 @@ def aten_ops_selu(
     kwargs_new = {
        "input": args[0],
     }
-    return activation.add_selu(network, target, kwargs_new, name)
+    return activation.add_gelu(network, target, kwargs_new, name)
+
 
 @tensorrt_converter(torch.ops.aten.softsign.default)
 def aten_ops_softsign(
@@ -562,6 +584,7 @@ def aten_ops_softsign(
     }
     return activation.add_softsign(network, target, kwargs_new, name)
 
+
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
@@ -588,6 +611,7 @@ def aten_ops_softsign(
     }
     return activation.add_softsign(network, target, kwargs_new, name)
 
+
 @tensorrt_converter(torch.ops.aten.softsign.default)
 def aten_ops_hard_sigmoid(
     network: TRTNetwork,
@@ -601,6 +625,7 @@ def aten_ops_hard_sigmoid(
     }
     return activation.add_hard_sigmoid(network, target, kwargs_new, name)
 
+
 @tensorrt_converter(torch.ops.aten.sigmoid.default)
 def aten_ops_hard_tanh(
     network: TRTNetwork,
@@ -614,6 +639,7 @@ def aten_ops_hard_tanh(
     }
     return activation.add_hard_tanh(network, target, kwargs_new, name)
 
+
 @tensorrt_converter(torch.ops.aten.sigmoid.default)
 def aten_ops_sigmoid(
     network: TRTNetwork,

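Taken together, this file now registers aten converters for matmul, gelu, tanh, slice, and select, alongside the aten_ops_relu renames to aten_ops_leaky_relu and aten_ops_elu. A hypothetical module whose traced graph should exercise the new converters; the exact aten overloads produced can vary with the tracer and decomposition settings:

import torch

class TinyBlock(torch.nn.Module):
    # Ops chosen to hit the converters added in this commit:
    # matmul, gelu, tanh, slice, select.
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        y = torch.matmul(x, w)            # aten matmul
        y = torch.nn.functional.gelu(y)   # aten gelu (exact, approximate="none")
        y = torch.tanh(y)                 # aten tanh
        y = y[:, 1:3]                     # lowers to an aten slice
        return y.select(1, 0)             # lowers to an aten select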
py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 1 addition & 1 deletion
@@ -396,7 +396,7 @@ def get_shape_with_dynamic_shape(
     )
     set_layer_name(zero_layer, target, f"{name}_zeros")
 
-    condition_val = add_binary_elementwise_layer(
+    condition_val = operator.add_binary_elementwise_layer(
         network,
         scale_res,
         zero_layer.get_output(0),
