
Commit 8c8e897

Fixing matmul, select, tanh tests
1 parent 3f3a925 commit 8c8e897

8 files changed (+113, -154 lines)


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 3 additions & 2 deletions

@@ -26,8 +26,9 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-import activation
-import operator
+
+from .activation import *
+from .operator import *
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
py/torch_tensorrt/fx/converters/activation.py

Lines changed: 5 additions & 4 deletions

@@ -33,7 +33,7 @@ def add_activation_layer(
     name: str,
     alpha: Optional[Any] = None,
     beta: Optional[Any] = None,
-    dyn_range_fn: Optional[Callable[Tuple[float, float]]] = None
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None
 ) -> TRTTensor:
     """
     Add a TensorRT Activation layer to `network`.

@@ -109,9 +109,10 @@ def add_tanh(network, target, kwargs, name):
 
 def add_gelu(network, target, kwargs, name):
     input_val = kwargs["input"]
-    approximate = kwargs["approximate"]
-    if approximate != "none":
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
+    if "approximate" in kwargs.keys():
+        approximate = kwargs["approximate"]
+        if approximate != "none":
+            raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"GELU received input {input_val} that is not part "

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 34 additions & 72 deletions

@@ -21,9 +21,11 @@
 from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
 
 from .converter_utils import * # noqa: F403
+from .activation import *
+from .operator import *
+
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-import activation
-import operator
+
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 

@@ -40,7 +42,7 @@ def aten_ops_add(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.mean.dim)

@@ -143,13 +145,13 @@ def aten_ops_div(
     }
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
-        return operator.add_div(network, target, None, kwargs_new, name)
+        return add_div(network, target, None, kwargs_new, name)
     elif rounding_mode == "floor":
-        return operator.add_floor_div(
+        return add_floor_div(
            network, target, None, kwargs_new, name
        )
     elif rounding_mode == "trunc":
-        return operator.add_trunc_div(
+        return add_trunc_div(
            network, target, None, kwargs_new, name
        )
     else:

@@ -170,7 +172,7 @@ def aten_ops_floor_div(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.fmod.Scalar)

@@ -186,7 +188,7 @@ def aten_ops_fmod(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_fmod(network, target, None, kwargs_new, name)
+    return add_fmod(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.linear)

@@ -203,7 +205,7 @@ def aten_ops_linear(
         "bias": args[2],
     }
 
-    return operator.add_linear(network, target, None, kwargs_new, name)
+    return add_linear(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.max_pool3d)

@@ -252,10 +254,11 @@ def aten_ops_mul(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.matmul.Tensor)
+@tensorrt_converter(torch.ops.aten.matmul)
+@tensorrt_converter(torch.ops.aten.mm.default)
 def aten_ops_matmul(
     network: TRTNetwork,
     target: Target,

@@ -267,7 +270,7 @@ def aten_ops_matmul(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_matmul(network, target, None, kwargs_new, name)
+    return add_matmul(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)

@@ -283,7 +286,7 @@ def aten_ops_pow(
         "input": args[0],
         "exponent": args[1],
     }
-    return operator.add_pow(network, target, None, kwargs_new, name)
+    return add_pow(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.relu.default)

@@ -297,7 +300,7 @@ def aten_ops_relu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_relu(network, target, kwargs_new, name)
+    return add_relu(network, target, kwargs_new, name)
 
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
 def aten_ops_sub(

@@ -311,7 +314,7 @@ def aten_ops_sub(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.view.default)

@@ -378,7 +381,7 @@ def aten_ops_expand(
         "input": args[0],
         "sizes": args[1],
     }
-    return operator.add_expand(network, target, kwargs_new, name)
+    return add_expand(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(operator.floordiv)

@@ -393,7 +396,7 @@ def aten_ops_operator_floordiv(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(operator.mul)

@@ -408,7 +411,7 @@ def aten_ops_operator_mul(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(operator.add)

@@ -423,7 +426,7 @@ def aten_ops_operator_add(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(operator.sub)

@@ -438,7 +441,7 @@ def aten_ops_operator_sub(
         "input": args[0],
         "other": args[1],
     }
-    return operator.add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.sym_numel)

@@ -497,9 +500,10 @@ def aten_ops_slice(
         "stop" : args[3],
         "step" : args[4],
     }
-    return operator.add_slice(network, target. kwargs_new, name)
+    return add_slice(network, target. kwargs_new, name)
 
-@tensorrt_converter(torch.ops.aten.select.Tensor)
+
+@tensorrt_converter(torch.ops.aten.select)
 def aten_ops_select(
     network: TRTNetwork,
     target: Target,

@@ -512,7 +516,7 @@ def aten_ops_select(
         "dim" : args[1],
         "index" : args[2],
     }
-    return operator.add_select(network, target. kwargs_new, name)
+    return add_select(network, target. kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.leaky_relu.default)

@@ -526,7 +530,7 @@ def aten_ops_leaky_relu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_leaky_relu(network, target, kwargs_new, name)
+    return add_leaky_relu(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.elu.default)

@@ -540,7 +544,7 @@ def aten_ops_elu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_elu(network, target, kwargs_new, name)
+    return add_elu(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.selu.default)

@@ -554,7 +558,7 @@ def aten_ops_selu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.selu(network, target, kwargs_new, name)
+    return add_selu(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.gelu.default)

@@ -568,22 +572,7 @@ def aten_ops_gelu(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_gelu(network, target, kwargs_new, name)
-
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_softsign(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_softsign(network, target, kwargs_new, name)
-
+    return add_gelu(network, target, kwargs_new, name)
 
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(

@@ -596,34 +585,7 @@ def aten_ops_tanh(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_tanh(network, target, kwargs_new, name)
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_softsign(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_softsign(network, target, kwargs_new, name)
-
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_hard_sigmoid(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_hard_sigmoid(network, target, kwargs_new, name)
+    return add_tanh(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.sigmoid.default)

@@ -637,7 +599,7 @@ def aten_ops_hard_tanh(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_hard_tanh(network, target, kwargs_new, name)
+    return add_hard_tanh(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.sigmoid.default)

@@ -651,7 +613,7 @@ def aten_ops_sigmoid(
     kwargs_new = {
         "input": args[0],
     }
-    return activation.add_sigmoid(network, target, kwargs_new, name)
+    return add_sigmoid(network, target, kwargs_new, name)
 
 
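
A note on the matmul registration change above: stacking @tensorrt_converter decorators registers the same converter under each listed target, so torch.ops.aten.matmul and torch.ops.aten.mm.default now both dispatch to aten_ops_matmul. A minimal sketch of that pattern with a simplified registry (the CONVERTERS dict and decorator below are illustrative stand-ins, not the real fx registry):

CONVERTERS = {}  # simplified stand-in for the fx converter registry

def tensorrt_converter(target):
    def register(fn):
        CONVERTERS[target] = fn
        return fn
    return register

@tensorrt_converter("aten.matmul")
@tensorrt_converter("aten.mm.default")
def aten_ops_matmul(network, target, args, kwargs, name):
    # Remap positional aten args to the kwargs layout the shared helper expects.
    kwargs_new = {"input": args[0], "other": args[1]}
    return ("add_matmul", kwargs_new)  # stand-in for add_matmul(network, target, None, kwargs_new, name)

# Both op targets resolve to the same converter function.
assert CONVERTERS["aten.matmul"] is CONVERTERS["aten.mm.default"]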

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 import numpy as np
 import operator
 import warnings
+import logging
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import tensorrt as trt

@@ -687,7 +688,7 @@ def add_pow(network, target, kwargs, name):
         network,
         kwargs["input"],
         kwargs["other"],
-        trt.ElementWiseOperation.PROD,
+        trt.ElementWiseOperation.POW,
         target,
         name,
     )
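
The one-enum change in add_pow is the functional fix in this file: trt.ElementWiseOperation.PROD is element-wise multiplication, while POW is element-wise exponentiation, so the layer previously computed input * other where input ** other was intended. A quick PyTorch reference for the expected semantics (plain tensors only, no TensorRT engine involved):

import torch

base = torch.tensor([2.0, 3.0])
exponent = torch.tensor([3.0, 2.0])

print(base * exponent)            # PROD semantics: tensor([6., 6.])
print(torch.pow(base, exponent))  # POW semantics:  tensor([8., 9.])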
