|
21 | 21 | from .converter_utils import get_positive_dim
|
22 | 22 | from .converter_utils import prepend_ones
|
23 | 23 | from .converter_utils import has_dynamic_shape
|
24 |
| -from .converter_utils import get_shape_with_dynamic_shape |
25 | 24 | from .converter_utils import to_numpy
|
26 | 25 |
|
27 | 26 | from ..types import (
|
@@ -289,6 +288,79 @@ def trt_dtype_to_torch_dtype(trt_dtype):
|
289 | 288 | return table[trt_dtype]
|
290 | 289 |
|
291 | 290 |
|
def get_shape_with_dynamic_shape(
    network: TRTNetwork,
    shape: Union[list, tuple, torch.Tensor],
    input_val: TRTTensor,
    target: Target,
    name: str,
) -> TRTTensor:
    """
    Resolve the concrete output shape for a tensor in dynamic-shape mode.

    Example: if ``input_val`` has runtime shape [2048, 256, 512] and the
    expected output shape is [-1, 128, 256], this returns a shape tensor
    equal to [2048, 128, 256]. The steps are:
    1. read the runtime shape of ``input_val`` via an add_shape layer;
    2. build an all-zero tensor of the same rank, e.g. [0, 0, 0];
    3. elementwise-compare ``shape`` < 0, yielding a mask such as
       [True, False, False] that marks the dynamic (-1) dimensions;
    4. select, per element, between the runtime shape and ``shape`` using
       that mask, substituting the actual size for every -1 entry.

    Args:
        network (TRTNetwork): TensorRT network object.
        shape: calculated shape of the expected output tensor.
        input_val (TRTTensor): A TensorRT ITensor.
        target (Target): Target of fx node.
        name (str): The name we want to assign to the created TensorRT layer.
    Returns:
        A TensorRT ITensor representing the fully-resolved shape of input_val.
    """
    # Runtime shape of input_val as a 1-D int32 shape tensor.
    actual_shape = network.add_shape(input_val).get_output(0)

    # Constant holding the requested shape (may contain -1 entries).
    scale_layer = network.add_constant(
        actual_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
    )
    set_layer_name(scale_layer, target, f"{name}_scale")
    scale_res = scale_layer.get_output(0)

    # All-zero constant of the same rank, used as the comparison operand.
    rank = actual_shape.shape[0]
    zero_layer = network.add_constant(
        actual_shape.shape, to_numpy(torch.zeros((rank), dtype=torch.int32))
    )
    set_layer_name(zero_layer, target, f"{name}_zeros")

    # mask[i] is True exactly where shape[i] < 0, i.e. the dynamic dims.
    condition_val = add_binary_elementwise_layer(
        network,
        scale_res,
        zero_layer.get_output(0),
        trt.ElementWiseOperation.LESS,
        target,
        f"{name}_shape",
    )
    # Pick the runtime size for dynamic dims, the requested size otherwise.
    select_layer = network.add_select(condition_val, actual_shape, scale_res)
    set_layer_name(select_layer, target, f"{name}_select")
    return select_layer.get_output(0)
| 346 | + |
| 347 | + |
def type_cast(
    network: TRTNetwork,
    target: Target,
    name: str,
    input: TRTTensor,
    cast_type: TRTDataType,
):
    """
    Cast ``input`` to ``cast_type`` by routing it through an identity layer
    whose output type is forced to the requested dtype.
    """
    identity = network.add_identity(input)
    identity.set_output_type(0, cast_type)
    set_layer_name(identity, target, f"{name}_dtype_change")
    return identity.get_output(0)
| 362 | + |
| 363 | + |
292 | 364 | def add_tile(network, target, kwargs, name):
|
293 | 365 | input_t = kwargs["input"]
|
294 | 366 | input_val = get_trt_tensor(network, input_t, f"{name}_input")
|
@@ -822,25 +894,54 @@ def add_minimum(network, target, kwargs, name):
|
822 | 894 | )
|
823 | 895 |
|
824 | 896 |
|
def add_logical_not(network, target, kwargs, name):
    """
    Add a logical-NOT unary layer for ``kwargs["input"]``.

    trt.UnaryOperation.NOT requires a bool tensor, so float32/float16/int32
    inputs are first cast to bool via type_cast.
    """
    operation_type = trt.UnaryOperation.NOT
    input_val = kwargs["input"]
    if input_val.dtype in (trt.float32, trt.float16, trt.int32):
        # NOT only accepts bool tensors; cast first.
        input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool)
    return add_unary_layer(network, input_val, operation_type, target, name)
| 904 | + |
| 905 | + |
def add_logical_and(network, target, kwargs, name):
    """
    Add an elementwise logical-AND layer combining ``kwargs["input"]`` and
    ``kwargs["other"]``.

    For acc_ops.bitwise_and only bool operands are supported, so both inputs
    are validated up front. Any operand that is not already a bool TRT tensor
    is cast to bool before the AND, since trt.ElementWiseOperation.AND
    requires bool operands.

    Raises:
        RuntimeError: if the network uses an implicit batch dimension.
    """
    if network.has_implicit_batch_dimension:
        raise RuntimeError(
            "The `logical_and` function should be called with explicit batch dimension."
        )

    input_t = kwargs["input"]
    other_t = kwargs["other"]
    # we only support both inputs are bool type
    if target == acc_ops.bitwise_and:

        def check_is_bool(input_t):
            # Accepts a TRT tensor, a torch tensor, or a plain Python value;
            # asserts the operand is boolean in each representation.
            if isinstance(input_t, TRTTensor):
                assert (
                    input_t.dtype == trt.bool
                ), "We currently do not support input is non-bool"
            elif isinstance(input_t, torch.Tensor):
                assert (
                    input_t.dtype == torch.bool
                ), "We currently do not support input is non-bool"
            else:
                # BUG FIX: original code was `isinstance(input_t.bool)`, which
                # raises TypeError (isinstance needs two arguments) instead of
                # checking anything. A plain scalar must itself be a bool.
                assert isinstance(
                    input_t, bool
                ), "We currently do not support input is non-bool"

        check_is_bool(input_t)
        check_is_bool(other_t)

    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

    if input_t.dtype != trt.bool:
        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
    if other_t.dtype != trt.bool:
        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
    return add_binary_elementwise_layer(
        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
    )
844 | 945 |
|
845 | 946 | def add_ne(network, target, kwargs, name):
|
846 | 947 | if network.has_implicit_batch_dimension:
|
|
0 commit comments