
Commit d7c82ab

binary operator changes in aten
1 parent e8a8e38 commit d7c82ab

3 files changed: +73 −146 lines changed


py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 13 additions & 29 deletions
@@ -144,11 +144,11 @@ def aten_ops_div(
     }
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
-        return add_div(network, target, None, kwargs_new, name)
+        return add_div(network, target, kwargs_new, name)
     elif rounding_mode == "floor":
-        return add_floor_div(network, target, None, kwargs_new, name)
+        return add_floor_div(network, target, kwargs_new, name)
     elif rounding_mode == "trunc":
-        return add_trunc_div(network, target, None, kwargs_new, name)
+        return add_trunc_div(network, target, kwargs_new, name)
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"
@@ -167,7 +167,7 @@ def aten_ops_floor_div(
         "input": args[0],
         "other": args[1],
     }
-    return add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.fmod.Scalar)
@@ -183,7 +183,7 @@ def aten_ops_fmod(
         "input": args[0],
         "other": args[1],
     }
-    return add_fmod(network, target, None, kwargs_new, name)
+    return add_fmod(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.linear)
@@ -200,7 +200,7 @@ def aten_ops_linear(
         "bias": args[2],
     }

-    return add_linear(network, target, None, kwargs_new, name)
+    return add_linear(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.max_pool3d)
@@ -249,7 +249,7 @@ def aten_ops_mul(
         "input": args[0],
         "other": args[1],
     }
-    return add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.matmul)
@@ -265,7 +265,7 @@ def aten_ops_matmul(
         "input": args[0],
         "other": args[1],
     }
-    return add_matmul(network, target, None, kwargs_new, name)
+    return add_matmul(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
@@ -310,7 +310,7 @@ def aten_ops_sub(
         "input": args[0],
         "other": args[1],
     }
-    return add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.view.default)
@@ -392,7 +392,7 @@ def aten_ops_operator_floordiv(
         "input": args[0],
         "other": args[1],
     }
-    return add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, kwargs_new, name)


 @tensorrt_converter(operator.mul)
@@ -407,7 +407,7 @@ def aten_ops_operator_mul(
         "input": args[0],
         "other": args[1],
     }
-    return add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, kwargs_new, name)


 @tensorrt_converter(operator.add)
@@ -422,7 +422,7 @@ def aten_ops_operator_add(
         "input": args[0],
         "other": args[1],
     }
-    return add_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, kwargs_new, name)


 @tensorrt_converter(operator.sub)
@@ -437,7 +437,7 @@ def aten_ops_operator_sub(
         "input": args[0],
         "other": args[1],
     }
-    return add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.sym_numel)
@@ -499,22 +499,6 @@ def aten_ops_slice(
     return add_slice(network, target.kwargs_new, name)


-@tensorrt_converter(torch.ops.aten.select)
-def aten_ops_select(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "dim": args[1],
-        "index": args[2],
-    }
-    return add_select(network, target.kwargs_new, name)
-
-
 @tensorrt_converter(torch.ops.aten.leaky_relu.default)
 def aten_ops_leaky_relu(
     network: TRTNetwork,
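
Every change in this file follows the same pattern: the aten_ops_* converters now pass exactly (network, target, kwargs_new, name) to their add_* helpers, dropping the unused None that used to occupy the third position, and the aten_ops_select converter is removed outright. The snippet below is a minimal, standalone sketch of that calling convention with trivial stand-in helpers (not the torch_tensorrt implementations), just to make the aten_ops_div dispatch concrete:

# Standalone sketch of the four-argument calling convention used above.
# The add_* functions here are stand-ins for illustration only.
def add_div(network, target, kwargs, name):
    return f"{name}: {kwargs['input']} / {kwargs['other']}"

def add_floor_div(network, target, kwargs, name):
    return f"{name}: {kwargs['input']} // {kwargs['other']} (floor)"

def add_trunc_div(network, target, kwargs, name):
    return f"{name}: {kwargs['input']} // {kwargs['other']} (trunc)"

def aten_ops_div_sketch(network, target, args, kwargs, name):
    kwargs_new = {"input": args[0], "other": args[1]}
    rounding_mode = kwargs.get("rounding_mode")
    if rounding_mode is None:
        return add_div(network, target, kwargs_new, name)
    elif rounding_mode == "floor":
        return add_floor_div(network, target, kwargs_new, name)
    elif rounding_mode == "trunc":
        return add_trunc_div(network, target, kwargs_new, name)
    else:
        raise RuntimeError(f"Target {target} does not support rounding mode {rounding_mode}")

print(aten_ops_div_sketch(None, "aten.div", (6.0, 4.0), {"rounding_mode": "trunc"}, "div_1"))
# div_1: 6.0 // 4.0 (trunc)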

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 0 additions & 116 deletions
@@ -524,122 +524,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names):
     return inputs


-def sign(
-    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
-) -> TRTTensor:
-    """
-    Sign is calculated as below:
-       x = input
-       sign = (exp(x) // exp(abs(x))) * 2 - 1
-    For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0.
-    With multiply 2, the value become 2(for pos and 0) and 0(for neg).
-    Finally minus 1, the value become 1(for pos and 0) and -1(for neg).
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): The input tensor.
-        target (Target): fx node target.
-        name (str): Name of the fx node with optional suffix.
-
-    Returns:
-        A TensorRT tensor represent the result of sign operator.
-    """
-    input_exp_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
-    )
-    input_abs_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
-    )
-    input_abs_exp_output = add_unary_layer(
-        network,
-        input_abs_output,
-        trt.UnaryOperation.EXP,
-        target,
-        f"{name}_prod_abs_exp",
-    )
-    floor_div_output = add_binary_elementwise_layer(
-        network,
-        input_exp_output,
-        input_abs_exp_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_exp_floor_div",
-    )
-    double_floor_div_output = add_binary_elementwise_layer(
-        network,
-        floor_div_output,
-        2,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_floor_div*2",
-    )
-    return add_binary_elementwise_layer(
-        network,
-        double_floor_div_output,
-        1,
-        trt.ElementWiseOperation.SUB,
-        target,
-        f"{name}_sign",
-    )
-
-
-def trunc_div(
-    input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
-) -> TRTTensor:
-    """
-    Perform trunc divide on Tensor, result of divide will be round toward zero.
-    This means for positive number, it will be floor round; for negative number,
-    it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
-
-    Args:
-        input: divisor.
-        other: dividend.
-        network: INetworkDefinition.
-        target: node target.
-        name: namespace for the op
-
-    Returns:
-        A TensorRT tensor represent the result of trunc divide.
-    """
-    prod_output = add_binary_elementwise_layer(
-        network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod"
-    )
-    sign_output = sign(network, prod_output, target, name)
-
-    # Convert constant input into ITensor for UnaryOperation
-    if not isinstance(input, trt.tensorrt.ITensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
-    if not isinstance(other, trt.tensorrt.ITensor):
-        other = get_trt_tensor(
-            network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
-        )
-
-    abs_input_output = add_unary_layer(
-        network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input"
-    )
-    abs_other_output = add_unary_layer(
-        network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other"
-    )
-    abs_floor_output = add_binary_elementwise_layer(
-        network,
-        abs_input_output,
-        abs_other_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_floor_div",
-    )
-    output = add_binary_elementwise_layer(
-        network,
-        abs_floor_output,
-        sign_output,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_output",
-    )
-
-    return output
-
-
 def dtype_uniform(
     network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor
 ):
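
The sign and trunc_div helpers removed here are not dropped; they move into operator.py (next file). trunc_div rounds the quotient toward zero by floor-dividing the absolute values and then reapplying the sign of the product input * other. Below is a plain-NumPy reference of that arithmetic, independent of TensorRT, for orientation only:

import numpy as np

# Reference (NumPy) version of the arithmetic trunc_div builds out of
# ABS / FLOOR_DIV / PROD TensorRT layers: truncate toward zero by
# floor-dividing the absolute values, then restore the sign of the quotient.
def trunc_div_reference(input, other):
    input = np.asarray(input, dtype=np.float32)
    other = np.asarray(other, dtype=np.float32)
    quotient_sign = np.where(input * other >= 0, 1.0, -1.0)
    return np.floor(np.abs(input) / np.abs(other)) * quotient_sign

print(trunc_div_reference([5.0, -5.0, 7.5], [2.0, 2.0, -3.0]))
# [ 2. -2. -2.]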

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 60 additions & 1 deletion
@@ -162,6 +162,65 @@ def add_binary_elementwise_layer(
     return output


+def sign(
+    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
+) -> TRTTensor:
+    """
+    Sign is calculated as below:
+       x = input
+       sign = (exp(x) // exp(abs(x))) * 2 - 1
+    For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0.
+    With multiply 2, the value become 2(for pos and 0) and 0(for neg).
+    Finally minus 1, the value become 1(for pos and 0) and -1(for neg).
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): The input tensor.
+        target (Target): fx node target.
+        name (str): Name of the fx node with optional suffix.
+
+    Returns:
+        A TensorRT tensor represent the result of sign operator.
+    """
+    input_exp_output = add_unary_layer(
+        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
+    )
+    input_abs_output = add_unary_layer(
+        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
+    )
+    input_abs_exp_output = add_unary_layer(
+        network,
+        input_abs_output,
+        trt.UnaryOperation.EXP,
+        target,
+        f"{name}_prod_abs_exp",
+    )
+    floor_div_output = add_binary_elementwise_layer(
+        network,
+        input_exp_output,
+        input_abs_exp_output,
+        trt.ElementWiseOperation.FLOOR_DIV,
+        target,
+        f"{name}_exp_floor_div",
+    )
+    double_floor_div_output = add_binary_elementwise_layer(
+        network,
+        floor_div_output,
+        2,
+        trt.ElementWiseOperation.PROD,
+        target,
+        f"{name}_floor_div*2",
+    )
+    return add_binary_elementwise_layer(
+        network,
+        double_floor_div_output,
+        1,
+        trt.ElementWiseOperation.SUB,
+        target,
+        f"{name}_sign",
+    )
+
+
 def trunc_div(
     input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
 ) -> TRTTensor:
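
The sign helper added above encodes sign(x) = (exp(x) // exp(|x|)) * 2 - 1 as EXP, ABS, FLOOR_DIV, PROD and SUB layers. A quick scalar check of that identity in plain Python (my own check, not library code):

import math

# exp(x) // exp(|x|) is 1 when x >= 0 (the two exponentials are equal)
# and 0 when x < 0 (exp(x) < exp(|x|)), so "* 2 - 1" maps to +1 / -1.
def sign_reference(x):
    return (math.exp(x) // math.exp(abs(x))) * 2 - 1

for x in (3.5, 0.0, -2.25):
    print(x, sign_reference(x))
# 3.5 1.0
# 0.0 1.0
# -2.25 -1.0
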
@@ -701,7 +760,7 @@ def add_pow(network, target, kwargs, name):
     return add_binary_elementwise_layer(
         network,
         kwargs["input"],
-        kwargs["other"],
+        kwargs["exponent"],
         trt.ElementWiseOperation.POW,
         target,
         name,
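
The add_pow change switches the exponent lookup from kwargs["other"] to kwargs["exponent"], which (if I read the ATen operator registry correctly) matches the argument name in the aten::pow.Tensor_Scalar schema. One way to confirm the argument name locally, assuming a working torch install:

import torch

# The second argument of this overload is named "exponent", not "other",
# which is the key the converter's kwargs dict should use.
print(torch.ops.aten.pow.Tensor_Scalar._schema)
# aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
print(torch.ops.aten.pow.Tensor_Scalar(torch.tensor([2.0, 3.0]), 3))
# tensor([ 8., 27.])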
