|
15 | 15 | ConverterRegistry,
|
16 | 16 | DynamoConverterImplSignature,
|
17 | 17 | )
|
18 |
| -from torch_tensorrt.fx.converters.converter_utils import ( |
19 |
| - broadcast, |
20 |
| - get_axes_for_reduce_op, |
21 |
| -) |
22 |
| -from torch_tensorrt.fx.types import TRTDataType, TRTTensor |
| 18 | + |
| 19 | +from ..types import Shape, TRTDataType, TRTLayer, TRTTensor |
23 | 20 |
|
24 | 21 | _LOGGER: logging.Logger = logging.getLogger(__name__)
|
25 | 22 |
|
@@ -177,9 +174,7 @@ def broadcast_to_same_shape(
|
177 | 174 | Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
|
178 | 175 |
|
179 | 176 | """
|
180 |
| - lhs_val, rhs_val = broadcast( |
181 |
| - ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" |
182 |
| - ) |
| 177 | + lhs_val, rhs_val = broadcast(ctx, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs") |
183 | 178 |
|
184 | 179 | lhs_val_shape = lhs_val.shape
|
185 | 180 | rhs_val_shape = rhs_val.shape
|
@@ -216,11 +211,6 @@ def broadcast_to_same_shape(
|
216 | 211 | return lhs_val, rhs_val
|
217 | 212 |
|
218 | 213 |
|
219 |
| -get_axes_for_reduce_op = functools.partial( |
220 |
| - get_axes_for_reduce_op, has_implicit_batch_dimension=False |
221 |
| -) |
222 |
| - |
223 |
| - |
224 | 214 | def extend_attr_to_tuple(
|
225 | 215 | val: Any,
|
226 | 216 | num_elem: int,
|
@@ -692,3 +682,160 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
|
692 | 682 | for i in range(len(shape) - 2, -1, -1):
|
693 | 683 | strides[i] = strides[i + 1] * shape[i + 1]
|
694 | 684 | return strides
|
| 685 | + |
| 686 | + |
| 687 | +def broadcast( |
| 688 | + ctx: ConversionContext, |
| 689 | + a: TRTTensor, |
| 690 | + b: TRTTensor, |
| 691 | + a_name: str, |
| 692 | + b_name: str, |
| 693 | + preset_diff: int = 0, |
| 694 | +) -> Tuple[TRTTensor, TRTTensor]: |
| 695 | + """ |
| 696 | + Broadcast two TensorRT tensors to the same number of dimensions by |
| 697 | + prepending 1s to the tensor with less number of dimensions. |
| 698 | +
|
| 699 | + Args: |
| 700 | + ctx (ConversionContext): A ConversionContext containing the TensorRT network |
| 701 | + a (TRTTensor): A TensorRT ITensor. |
| 702 | + b (TRTTensor): A TensorRT ITensor. |
| 703 | + a_name (str): Name of tensor a. |
| 704 | + b_name (str): Name of tensor b. |
| 705 | + preset_diff (int): The difference of number of dimensions after broadcast. |
| 706 | + A positive number means after broadcast, tensor `a` would have `preset_diff` |
| 707 | + more dimensions than `b`. This is used in matmul, since we need to broadcast |
| 708 | + tensors but not always to the same number of dimension. The reason is that |
| 709 | + matmul supports Matrix x Vector and in this case broadcasted vector should |
| 710 | + have 1 less number of dimensions than the matrix tensor. |
| 711 | +
|
| 712 | + Returns: |
| 713 | + Two TensorRT ITensors that are broadcasted to the same number of dimensions. |
| 714 | + """ |
| 715 | + a_shape = tuple(a.shape) |
| 716 | + b_shape = tuple(b.shape) |
| 717 | + |
| 718 | + diff = len(a_shape) - len(b_shape) - preset_diff |
| 719 | + if diff > 0: |
| 720 | + b = prepend_ones(ctx, b, f"{b_name}_broadcast", diff) |
| 721 | + elif diff < 0: |
| 722 | + a = prepend_ones(ctx, a, f"{a_name}_broadcast", -diff) |
| 723 | + |
| 724 | + return a, b |
| 725 | + |
| 726 | + |
| 727 | +def get_axes_for_reduce_op( |
| 728 | + dim: Union[int, Sequence[int]], |
| 729 | +) -> int: |
| 730 | + """ |
| 731 | + TensorRT reduce layer relies on the binary representation of axes to |
| 732 | + determine which dims to reduce. For example, if we want to reduce on |
| 733 | + dim 1 and 2 then axes should be 6(110). |
| 734 | +
|
| 735 | + Args: |
| 736 | + dim (Union[int, Sequence[int]]): An integer or a sequence of integers |
| 737 | + that will be used to generate axes for TensorRT. |
| 738 | +
|
| 739 | + Returns: |
| 740 | + An integer which binary form can be used as axes for TensorRT reduce |
| 741 | + layer. |
| 742 | + """ |
| 743 | + if isinstance(dim, int): |
| 744 | + dim = (dim,) |
| 745 | + |
| 746 | + axes = 0 |
| 747 | + for d in dim: |
| 748 | + axes |= 1 << d |
| 749 | + |
| 750 | + return axes |
| 751 | + |
| 752 | + |
| 753 | +def has_dynamic_shape(shape: Shape) -> bool: |
| 754 | + """ |
| 755 | + Determine if the given shape has dynamic dim. i.e. if there're -1 in shape. |
| 756 | +
|
| 757 | + Args: |
| 758 | + shape (Shape): Shape of a tensor. Essentially is a sequence of integers. |
| 759 | +
|
| 760 | + Returns: |
| 761 | + A boolean value indicates whether there's dynamic dim in the shape. |
| 762 | + """ |
| 763 | + count = 0 |
| 764 | + for s in shape: |
| 765 | + count += 1 if s == -1 else 0 |
| 766 | + return count > 0 |
| 767 | + |
| 768 | + |
| 769 | +def prepend_ones( |
| 770 | + ctx: ConversionContext, |
| 771 | + tensor: TRTTensor, |
| 772 | + name: str, |
| 773 | + num_prepend_ones: int, |
| 774 | +) -> TRTTensor: |
| 775 | + """ |
| 776 | + Prepend 1s to the shape of TensorRT ITensor `tensor`. |
| 777 | +
|
| 778 | + Args: |
| 779 | + ctx (ConversionContext): A ConversionContext containing the TensorRT network |
| 780 | + tensor (TRTTensor): A TensorRT tensor. |
| 781 | + name (str): Name of the TensorRT Shuffle layer which is used to prepend |
| 782 | + 1s. |
| 783 | + num_prepend_ones (int): Number of 1s that will be prepend. |
| 784 | +
|
| 785 | + Returns: |
| 786 | + A Tensorrt ITensor which contains the same value as `tensor` but with |
| 787 | + more 1s prepended to the beginning of `tensor` shape. |
| 788 | + """ |
| 789 | + layer = ctx.net.add_shuffle(tensor) |
| 790 | + |
| 791 | + # If there're dynamic dim in tensor's shape, we need to use shape layer to |
| 792 | + # compute the final shape. |
| 793 | + if has_dynamic_shape(tensor.shape): |
| 794 | + tensor_shape_layer = ctx.net.add_shape(tensor) |
| 795 | + tensor_shape = tensor_shape_layer.get_output(0) |
| 796 | + tensor_shape = cast_trt_tensor( |
| 797 | + ctx, tensor_shape, trt.int32, name + "shape_casted", "shape" |
| 798 | + ) |
| 799 | + tensor_shape_layer.name = f"{name}_broadcast_orig_shape" |
| 800 | + prepend_shape_layer = ctx.net.add_constant( |
| 801 | + (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32) |
| 802 | + ) |
| 803 | + prepend_shape_layer.name = f"{name}_broadcast_prepend_ones" |
| 804 | + reshape_dim_layer = ctx.net.add_concatenation( |
| 805 | + [prepend_shape_layer.get_output(0), tensor_shape] |
| 806 | + ) |
| 807 | + reshape_dim_layer.axis = 0 |
| 808 | + reshape_dim_layer.name = f"{name}_broadcast_final_shape" |
| 809 | + layer.set_input(1, reshape_dim_layer.get_output(0)) |
| 810 | + else: |
| 811 | + layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape) |
| 812 | + |
| 813 | + layer.name = name |
| 814 | + return layer.get_output(0) |
| 815 | + |
| 816 | + |
| 817 | +def set_layer_name( |
| 818 | + layer: TRTLayer, |
| 819 | + target: Union[Target, torch.nn.Module, str], |
| 820 | + name: str, |
| 821 | + source_ir: Optional[SourceIR] = None, |
| 822 | +) -> None: |
| 823 | + """ |
| 824 | + Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]" |
| 825 | +
|
| 826 | + Args: |
| 827 | + layer (TRTLayer): A TensorRT layer of which we want to set the name. |
| 828 | + target (Target): A fx node.target or submodule. For call_function node, it's the function that |
| 829 | + the node represents. |
| 830 | + name (str): Consists of fx node.name with optional suffix. |
| 831 | + source_ir: (Optional[SourceIR]): The IR producing the op. |
| 832 | + """ |
| 833 | + |
| 834 | + source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN |
| 835 | + |
| 836 | + target_name = ( |
| 837 | + f"{source_ir}_ops.{target}" |
| 838 | + if isinstance(target, str) |
| 839 | + else f"{source_ir}_ops.{target.__name__}" |
| 840 | + ) |
| 841 | + layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]" |
0 commit comments