Commit 45e43ca

narendasan authored and gs-olive committed

refactor: Moving elementwise and unary core to impl

Signed-off-by: Naren Dasan <[email protected]>
new file: ../converters/impl/unary/base.py

1 parent 35b0618 · commit 45e43ca

File tree

10 files changed, +882 -470 lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 379 additions & 130 deletions
Large diffs are not rendered by default.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 2 additions & 3 deletions

@@ -21,6 +21,7 @@
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -159,9 +160,7 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return acc_ops_converters.acc_ops_trunc_div(
-            network, target, None, kwargs_new, name
-        )
+        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 3 additions & 337 deletions

@@ -28,6 +28,7 @@ class SourceIR(Enum):
     ACC = auto()
     ATEN = auto()
     PRIM = auto()
+    TORCHTRT_LOWERED = auto()
     UNKNOWN = auto()
 
     def __str__(self):
@@ -39,6 +40,8 @@ def __str__(self):
             return "aten"
         elif self == SourceIR.PRIM:
             return "prim"
+        elif self == SourceIR.TORCHTRT_LOWERED:
+            return "torchtrt_lowered"
         else:
             return "unknown_ir"
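Like the existing members, the new enum value stringifies to a lowercase tag that converters embed in TensorRT layer names; a quick illustration (hypothetical usage, not part of this diff):

from torch_tensorrt.fx.converters.converter_utils import SourceIR

assert str(SourceIR.ATEN) == "aten"
assert str(SourceIR.TORCHTRT_LOWERED) == "torchtrt_lowered"  # new member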

@@ -406,175 +409,6 @@ def broadcast(
     return a, b
 
 
-def get_shape_with_dynamic_shape(
-    network: TRTNetwork,
-    shape: Union[list, tuple, torch.Tensor],
-    input_val: TRTTensor,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Prepare the real output tensor shape for dynamic shape mode tensor input.
-    How this function works:
-    Assuming input_val has actual shape [2048, 256, 512] and the expected reduce
-    operation output shape is [-1, 128, 256], this function should return
-    [2048, 128, 256] as the actual reduce operation output shape. The calculation
-    steps are:
-    1. get the actual tensor shape of input_val via an add_shape layer;
-    2. create an all-zero tensor [0, 0, 0];
-    3. run an elementwise comparison of the [0, 0, 0] and [-1, 128, 256] tensors, yielding a condition tensor [True, False, False];
-    4. use the condition tensor [True, False, False] to select between [2048, 256, 512] and [-1, 128, 256], replacing
-    all -1 dynamic shape dimensions with the actual batch_size value;
-    5. output the shape with the actual batch_size, [2048, 128, 256]
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        shape: calculated shape of the expected output tensor
-        input_val (TRTTensor): A TensorRT ITensor.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-    Returns:
-        A TensorRT ITensor that represents the actual shape of the input_val
-    """
-    # Get real shape info for input_val
-    input_shape = network.add_shape(input_val).get_output(0)
-
-    scale_layer = network.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
-    )
-    set_layer_name(scale_layer, target, f"{name}_scale")
-    scale_res = scale_layer.get_output(0)
-
-    length = input_shape.shape[0]
-    zero_layer = network.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
-    )
-    set_layer_name(zero_layer, target, f"{name}_zeros")
-
-    condition_val = add_binary_elementwise_layer(
-        network,
-        scale_res,
-        zero_layer.get_output(0),
-        trt.ElementWiseOperation.LESS,
-        target,
-        f"{name}_shape",
-    )
-    select_layer = network.add_select(condition_val, input_shape, scale_res)
-    set_layer_name(select_layer, target, f"{name}_select")
-    return select_layer.get_output(0)
-
-
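The compare-and-select trick described in the removed docstring is easy to sanity-check outside TensorRT; a minimal numpy mirror of steps 1-5 (illustrative only, not the TensorRT API):

import numpy as np

actual_shape = np.array([2048, 256, 512], dtype=np.int32)   # step 1: add_shape result
target_shape = np.array([-1, 128, 256], dtype=np.int32)     # requested output shape
zeros = np.zeros_like(target_shape)                         # step 2
condition = target_shape < zeros                            # step 3: [True, False, False]
resolved = np.where(condition, actual_shape, target_shape)  # step 4: select layer
print(resolved)                                             # step 5: [2048  128  256]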
-def add_binary_elementwise_layer(
-    network: TRTNetwork,
-    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    op_type: trt.ElementWiseOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    This function adds a TensorRT elementwise layer. We allow both operands to be
-    constant (not a trt tensor) because in implicit batch dimension mode, we could
-    introduce a constant via the .size() op. Other scenarios should be const-folded first.
-    If any operand is not a trt tensor, we make it a trt constant layer while preserving
-    its dtype. Then we broadcast these two inputs to have the same number of dimensions.
-
-    Limitation:
-        If we are using implicit batch dim mode, the operand that is not a trt
-        tensor is not allowed to have a larger rank than the trt tensor operand.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        lhs_val (TRTTensor): Left operand of the binary operation. Could
-            be a TensorRT tensor, a PyTorch tensor or a simple value.
-        rhs_val (TRTTensor): Right operand of the binary operation. Similar
-            to lhs_val.
-        op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of TensorRT Elementwise layer.
-    """
-    lhs_dtype = None
-    rhs_dtype = None
-    is_lhs_trt_tensor = False
-    is_rhs_trt_tensor = False
-
-    if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
-        is_lhs_trt_tensor = True
-    if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
-        is_rhs_trt_tensor = True
-
-    if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
-        warnings.warn(
-            f"Both operands of the binary elementwise op {name} "
-            "are constant. In this case, please consider constant folding the model first."
-        )
-        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
-
-    # If the following conditions are true:
-    # 1. the network has implicit batch dimension,
-    # 2. one operand has shape [] (real shape is [batch_size]),
-    # 3. another operand is a scalar,
-    # then the result should also have shape [] (real shape is [batch_size]).
-    #
-    # In such case, we need to convert the scalar operand to a tensor, because
-    # this way the shape will become [1], and then will be properly squeezed
-    # into [], meaning that the result will have shape [], which is what we
-    # expect.
-    #
-    # Note that the dtype here is supposed to be the same as the scalar
-    # dtype but we don't have a way to detect whether it makes sense for the
-    # scalar to be float or half. Hence we go with the lhs dtype.
-    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array(
-            [rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
-        )
-    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array(
-            [lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
-        )
-
-    # When lhs is a scalar and rhs has shape [1,], the assert below would
-    # fail because the lhs shape has fewer dimensions than the rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from the input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on an input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shapes compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, (torch.Tensor, np.ndarray)):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, (torch.Tensor, np.ndarray)):
-            rhs_val = squeeze_left(rhs_val)
-
-    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
-    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
-
-    # Check the limitation in the doc string.
-    if network.has_implicit_batch_dimension:
-        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
-            assert len(lhs_val.shape) >= len(
-                rhs_val.shape
-            ), f"{lhs_val.shape} >= {rhs_val.shape}"
-        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
-            assert len(rhs_val.shape) >= len(
-                lhs_val.shape
-            ), f"{rhs_val.shape} >= {lhs_val.shape}"
-
-    lhs_val, rhs_val = broadcast(
-        network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
-    )
-    layer = network.add_elementwise(lhs_val, rhs_val, op_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return output
-
-
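Two behaviors documented above are easy to mirror in plain Python: the constant-only fallback to a Python operator, and the squeeze_left promotion under implicit batch dimension. A short numpy sketch (illustrative; the loop mirrors the squeeze_left helper retained below):

import operator
import numpy as np

# Both operands constant: the helper bypassed TensorRT and applied the
# equivalent Python operator (see get_python_op_from_trt_elementwise_op below).
assert operator.floordiv(7.0, 2.0) == 3.0

# squeeze_left: leading size-1 dims are dropped so a [1, 1, 4] constant can
# broadcast against a scalar-shaped [] tensor in implicit batch mode.
const = np.ones((1, 1, 4))
while len(const.shape) > 0 and const.shape[0] == 1:
    const = const.squeeze(0)
print(const.shape)  # (4,)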
 def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
@@ -591,38 +425,6 @@ def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
     return const
 
 
-def add_unary_layer(
-    network: TRTNetwork,
-    input_val: TRTTensor,
-    operation_type: trt.UnaryOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Add a TensorRT Unary layer to `network`.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
-        operation_type (trt.UnaryOperation): Type of the TensorRT unary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of TensorRT Unary layer.
-    """
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{operation_type} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    layer = network.add_unary(input_val, operation_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return layer.get_output(0)
-
-
 def add_reduce_layer(
     network: TRTNetwork,
     target: Target,
@@ -727,142 +529,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names):
     return inputs
 
 
-def sign(
-    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
-) -> TRTTensor:
-    """
-    Sign is calculated as below:
-        x = input
-        sign = (exp(x) // exp(abs(x))) * 2 - 1
-    For positive numbers and 0, (exp(x) // exp(abs(x))) yields 1; for negative numbers it yields 0.
-    Multiplying by 2, the value becomes 2 (for pos and 0) and 0 (for neg).
-    Finally, subtracting 1, the value becomes 1 (for pos and 0) and -1 (for neg).
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): The input tensor.
-        target (Target): fx node target.
-        name (str): Name of the fx node with optional suffix.
-
-    Returns:
-        A TensorRT tensor representing the result of the sign operator.
-    """
-    input_exp_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
-    )
-    input_abs_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
-    )
-    input_abs_exp_output = add_unary_layer(
-        network,
-        input_abs_output,
-        trt.UnaryOperation.EXP,
-        target,
-        f"{name}_prod_abs_exp",
-    )
-    floor_div_output = add_binary_elementwise_layer(
-        network,
-        input_exp_output,
-        input_abs_exp_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_exp_floor_div",
-    )
-    double_floor_div_output = add_binary_elementwise_layer(
-        network,
-        floor_div_output,
-        2,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_floor_div*2",
-    )
-    return add_binary_elementwise_layer(
-        network,
-        double_floor_div_output,
-        1,
-        trt.ElementWiseOperation.SUB,
-        target,
-        f"{name}_sign",
-    )
-
-
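The exp/floor-div identity behind the removed sign helper can be verified numerically; a minimal numpy check (illustrative, outside TensorRT):

import numpy as np

# For x >= 0, exp(x) == exp(|x|), so the floor division yields 1;
# for x < 0, exp(x) < exp(|x|), so it yields 0.
x = np.array([3.0, 0.0, -2.5])
sign = np.floor(np.exp(x) / np.exp(np.abs(x))) * 2 - 1
print(sign)  # [ 1.  1. -1.]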
-def trunc_div(
-    input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
-) -> TRTTensor:
-    """
-    Perform trunc divide on Tensor; the result of the divide is rounded toward zero.
-    This means floor rounding for positive quotients and ceil rounding for negative
-    ones. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
-
-    Args:
-        input: dividend.
-        other: divisor.
-        network: INetworkDefinition.
-        target: node target.
-        name: namespace for the op
-
-    Returns:
-        A TensorRT tensor representing the result of trunc divide.
-    """
-    prod_output = add_binary_elementwise_layer(
-        network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod"
-    )
-    sign_output = sign(network, prod_output, target, name)
-
-    # Convert constant input into ITensor for UnaryOperation
-    if not isinstance(input, trt.tensorrt.ITensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
-    if not isinstance(other, trt.tensorrt.ITensor):
-        other = get_trt_tensor(
-            network,
-            other,
-            f"{name}_other",
-            dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
-        )
-
-    abs_input_output = add_unary_layer(
-        network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input"
-    )
-    abs_other_output = add_unary_layer(
-        network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other"
-    )
-    abs_floor_output = add_binary_elementwise_layer(
-        network,
-        abs_input_output,
-        abs_other_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_floor_div",
-    )
-    output = add_binary_elementwise_layer(
-        network,
-        abs_floor_output,
-        sign_output,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_output",
-    )
-
-    return output
-
-
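Equivalently, the removed graph computes floor(|a| / |b|) scaled by the sign of a * b, which matches truncation-toward-zero semantics; a small numpy check (illustrative only):

import numpy as np

a = np.array([5.0, -5.0, 4.2])
b = np.array([2.0, 2.0, -2.0])
result = np.floor(np.abs(a) / np.abs(b)) * np.sign(a * b)
print(result)           # [ 2. -2. -2.]
print(np.trunc(a / b))  # identical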
-def get_python_op_from_trt_elementwise_op(
-    trt_op: TRTElementWiseOp,
-) -> Callable[[Any, Any], Any]:
-    if trt_op == trt.ElementWiseOperation.SUM:
-        return operator.add
-    elif trt_op == trt.ElementWiseOperation.PROD:
-        return operator.mul
-    elif trt_op == trt.ElementWiseOperation.SUB:
-        return operator.sub
-    elif trt_op == trt.ElementWiseOperation.DIV:
-        return operator.truediv
-    elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
-        return operator.floordiv
-    else:
-        raise RuntimeError(f"{trt_op} is not supported yet!")
-
-
 def dtype_uniform(
     network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor
 ):
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .ops import *
