
Commit a1d94c1

Fixing the acc tests, logical_and operator and the leaky_relu test
1 parent 2bcf5f4 commit a1d94c1

File tree: 5 files changed, +115 -94 lines


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 3 additions & 8 deletions
@@ -1262,12 +1262,7 @@ def acc_ops_logical_not(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.UnaryOperation.NOT
-    # cast to bool type
-    if input_val.dtype in (trt.float32, trt.float16, trt.int32):
-        input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool)
-    return add_unary_layer(network, input_val, operation_type, target, name)
+    return add_logical_not(network, target, kwargs, name)


 @tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True)
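The inlined body moved into `add_logical_not` in operator.py (see that file's diff below) and keeps the same semantics: non-bool inputs are cast to bool before the unary NOT. A minimal plain-PyTorch sketch of that cast-then-negate behavior, with illustrative values:

import torch

# Casting to bool maps nonzero -> True; NOT then inverts, which matches
# torch.logical_not's behavior on non-bool inputs.
x = torch.tensor([0.0, 1.5, -2.0])
assert torch.equal(torch.logical_not(x), ~x.to(torch.bool))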
@@ -2335,7 +2330,7 @@ def acc_ops_getitem(
     input_val = kwargs["input"]
     slices = kwargs["idx"]
     if not isinstance(input_val, TRTTensor):
-        return getitem(input_val, slices)  # type: ignore[arg-type]
+        return operator.getitem(input_val, slices)  # type: ignore[arg-type]

     if not isinstance(slices, tuple) and not isinstance(slices, list):
         slices = (slices,)
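`operator.getitem` here is the Python stdlib function, used when the input is a plain tensor rather than a TRT tensor; the bare `getitem` name was presumably no longer in scope after the refactor. A quick illustration:

import operator
import torch

t = torch.arange(6).reshape(2, 3)
# operator.getitem(obj, idx) is equivalent to obj[idx] for any index or slice tuple.
assert torch.equal(operator.getitem(t, (slice(None), 0)), t[:, 0])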
@@ -2803,7 +2798,7 @@ def acc_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return add_hardtanh(network, target, kwargs, name)
+    return add_hard_tanh(network, target, kwargs, name)


 @tensorrt_converter(acc_ops.interpolate)
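The rename simply tracks the helper's spelling in operator.py. For context, hardtanh itself is just a clamp; a one-line plain-PyTorch check with illustrative bounds:

import torch
import torch.nn.functional as F

x = torch.linspace(-2, 2, 5)
# hardtanh(x, min_val, max_val) == clamp(x, min_val, max_val)
assert torch.equal(F.hardtanh(x, min_val=-1.0, max_val=1.0), torch.clamp(x, -1.0, 1.0))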

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 1 addition & 3 deletions
@@ -489,9 +489,7 @@ def aten_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
+    kwargs_new = {"input": args[0], "negative_slope": args[1]}
     return add_leaky_relu(network, target, kwargs_new, name)

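The old mapping dropped the slope argument, so every aten-level leaky_relu lowered with the default slope. Recall leaky_relu(x, s) = x for x >= 0 and s * x otherwise; a plain-PyTorch sketch using the illustrative slope the updated tests adopt:

import torch
import torch.nn.functional as F

x = torch.randn(8)
slope = 0.05  # illustrative; matches the updated tests below
expected = torch.where(x >= 0, x, slope * x)
assert torch.allclose(F.leaky_relu(x, negative_slope=slope), expected)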

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 0 additions & 73 deletions
@@ -328,63 +328,6 @@ def broadcast(
     return a, b


-def get_shape_with_dynamic_shape(
-    network: TRTNetwork,
-    shape: Union[list, tuple, torch.Tensor],
-    input_val: TRTTensor,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Prepare the real output tensor shape for dynamic shape mode tensor input.
-    How this function works:
-    Assuming input_val has the actual shape [2048, 256, 512] and the expected reduce-operation
-    output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
-    reduce-operation output shape. The calculation steps are:
-    1. get the actual tensor shape of input_val via the add_shape layer;
-    2. create an all-zero tensor [0, 0, 0];
-    3. run an elementwise comparison of the [0, 0, 0] and [-1, 128, 256] tensors to get a condition tensor [True, False, False];
-    4. use the condition tensor [True, False, False] to select between [2048, 256, 512] and [-1, 128, 256], replacing
-       all -1 dynamic shape dimensions with the actual batch_size value;
-    5. output the shape with the actual batch_size, i.e. [2048, 128, 256].
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        shape: calculated shape of the expected output tensor
-        input_val (TRTTensor): A TensorRT ITensor.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-    Returns:
-        A TensorRT ITensor that represents the actual shape of input_val.
-    """
-    # Get real shape info for input_val
-    input_shape = network.add_shape(input_val).get_output(0)
-
-    scale_layer = network.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
-    )
-    set_layer_name(scale_layer, target, f"{name}_scale")
-    scale_res = scale_layer.get_output(0)
-
-    length = input_shape.shape[0]
-    zero_layer = network.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
-    )
-    set_layer_name(zero_layer, target, f"{name}_zeros")
-
-    condition_val = operator.add_binary_elementwise_layer(
-        network,
-        scale_res,
-        zero_layer.get_output(0),
-        trt.ElementWiseOperation.LESS,
-        target,
-        f"{name}_shape",
-    )
-    select_layer = network.add_select(condition_val, input_shape, scale_res)
-    set_layer_name(select_layer, target, f"{name}_select")
-    return select_layer.get_output(0)
-
-
 def squeeze_left(const: torch.Tensor):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
@@ -529,22 +472,6 @@ def dtype_uniform(
     return input, other


-def type_cast(
-    network: TRTNetwork,
-    target: Target,
-    name: str,
-    input: TRTTensor,
-    cast_type: TRTDataType,
-):
-    """
-    This function helps to cast the input type to cast_type
-    """
-    layer_i = network.add_identity(input)
-    layer_i.set_output_type(0, cast_type)
-    set_layer_name(layer_i, target, f"{name}_dtype_change")
-    return layer_i.get_output(0)
-
-
 def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
     """
     Convert a PyTorch Tensor to a Numpy Array. If the tensor is

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 108 additions & 7 deletions
@@ -21,7 +21,6 @@
 from .converter_utils import get_positive_dim
 from .converter_utils import prepend_ones
 from .converter_utils import has_dynamic_shape
-from .converter_utils import get_shape_with_dynamic_shape
 from .converter_utils import to_numpy

 from ..types import (
@@ -289,6 +288,79 @@ def trt_dtype_to_torch_dtype(trt_dtype):
     return table[trt_dtype]


+def get_shape_with_dynamic_shape(
+    network: TRTNetwork,
+    shape: Union[list, tuple, torch.Tensor],
+    input_val: TRTTensor,
+    target: Target,
+    name: str,
+) -> TRTTensor:
+    """
+    Prepare the real output tensor shape for dynamic shape mode tensor input.
+    How this function works:
+    Assuming input_val has the actual shape [2048, 256, 512] and the expected reduce-operation
+    output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
+    reduce-operation output shape. The calculation steps are:
+    1. get the actual tensor shape of input_val via the add_shape layer;
+    2. create an all-zero tensor [0, 0, 0];
+    3. run an elementwise comparison of the [0, 0, 0] and [-1, 128, 256] tensors to get a condition tensor [True, False, False];
+    4. use the condition tensor [True, False, False] to select between [2048, 256, 512] and [-1, 128, 256], replacing
+       all -1 dynamic shape dimensions with the actual batch_size value;
+    5. output the shape with the actual batch_size, i.e. [2048, 128, 256].
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        shape: calculated shape of the expected output tensor
+        input_val (TRTTensor): A TensorRT ITensor.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+    Returns:
+        A TensorRT ITensor that represents the actual shape of input_val.
+    """
+    # Get real shape info for input_val
+    input_shape = network.add_shape(input_val).get_output(0)
+
+    scale_layer = network.add_constant(
+        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
+    )
+    set_layer_name(scale_layer, target, f"{name}_scale")
+    scale_res = scale_layer.get_output(0)
+
+    length = input_shape.shape[0]
+    zero_layer = network.add_constant(
+        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
+    )
+    set_layer_name(zero_layer, target, f"{name}_zeros")
+
+    condition_val = add_binary_elementwise_layer(
+        network,
+        scale_res,
+        zero_layer.get_output(0),
+        trt.ElementWiseOperation.LESS,
+        target,
+        f"{name}_shape",
+    )
+    select_layer = network.add_select(condition_val, input_shape, scale_res)
+    set_layer_name(select_layer, target, f"{name}_select")
+    return select_layer.get_output(0)
+
+
+def type_cast(
+    network: TRTNetwork,
+    target: Target,
+    name: str,
+    input: TRTTensor,
+    cast_type: TRTDataType,
+):
+    """
+    This function helps to cast the input type to cast_type
+    """
+    layer_i = network.add_identity(input)
+    layer_i.set_output_type(0, cast_type)
+    set_layer_name(layer_i, target, f"{name}_dtype_change")
+    return layer_i.get_output(0)
+
+
 def add_tile(network, target, kwargs, name):
     input_t = kwargs["input"]
     input_val = get_trt_tensor(network, input_t, f"{name}_input")
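The docstring's steps are easy to verify outside TensorRT. A minimal plain-PyTorch sketch of the same compare-and-select, using the example shapes from the docstring (illustrative values only):

import torch

actual = torch.tensor([2048, 256, 512], dtype=torch.int32)   # step 1: real input shape
requested = torch.tensor([-1, 128, 256], dtype=torch.int32)  # -1 marks a dynamic dim
zeros = torch.zeros_like(requested)                          # step 2

condition = requested < zeros                         # step 3: the LESS comparison
resolved = torch.where(condition, actual, requested)  # step 4: the add_select equivalent
assert resolved.tolist() == [2048, 128, 256]          # step 5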
@@ -822,25 +894,54 @@ def add_minimum(network, target, kwargs, name):
     )


+def add_logical_not(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.NOT
+    # cast to bool type
+    if input_val.dtype in (trt.float32, trt.float16, trt.int32):
+        input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool)
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
 def add_logical_and(network, target, kwargs, name):
     if network.has_implicit_batch_dimension:
         raise RuntimeError(
-            "The `ne` function should be called with explicit batch dimension."
+            "The `logical_and` function should be called with explicit batch dimension."
         )

     input_t = kwargs["input"]
     other_t = kwargs["other"]
+    # we only support both inputs are bool type
+    if target == acc_ops.bitwise_and:
+
+        def check_is_bool(input_t):
+            if isinstance(input_t, TRTTensor):
+                assert (
+                    input_t.dtype == trt.bool
+                ), "We currently do not support input is non-bool"
+            elif isinstance(input_t, torch.Tensor):
+                assert (
+                    input_t.dtype == torch.bool
+                ), "We currently do not support input is non-bool"
+            else:
+                assert isinstance(
+                    input_t, bool
+                ), "We currently do not support input is non-bool"
+
+        check_is_bool(input_t)
+        check_is_bool(other_t)

     input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
     other_t = get_trt_tensor(network, other_t, f"{name}_other_t")

-    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
-    eq_t = add_binary_elementwise_layer(
-        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
+    if input_t.dtype != trt.bool:
+        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
+    if other_t.dtype != trt.bool:
+        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
     )

-    return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
-

 def add_ne(network, target, kwargs, name):
     if network.has_implicit_batch_dimension:
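Note the behavioral fix in add_logical_and: the old body built NOT(input EQUAL other), i.e. not-equal rather than AND (the stray `ne` error message suggests it was adapted from add_ne), while the new body casts non-bool inputs through type_cast and emits a true AND. A plain-PyTorch truth table contrasting the two, with illustrative inputs:

import torch

a = torch.tensor([True, True, False, False])
b = torch.tensor([True, False, True, False])

old = torch.logical_not(torch.eq(a, b))  # previous graph: NOT(EQUAL) == (a != b)
new = torch.logical_and(a, b)            # new graph: ElementWiseOperation.AND

print(old.tolist())  # [False, True, True, False]
print(new.tolist())  # [True, False, False, False]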

py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@ class TestLeakyReLUConverter(DispatchTestCase):
     def test_leaky_relu(self):
         class TestModule(nn.Module):
             def forward(self, x):
-                return nn.functional.leaky_relu(x)
+                return nn.functional.leaky_relu(x, negative_slope=0.05)

         inputs = [torch.randn(1, 10)]
         self.run_test(
@@ -18,7 +18,7 @@ def forward(self, x):
     def test_leaky_relu_with_dynamic_shape(self):
         class TestModule(nn.Module):
             def forward(self, x):
-                return nn.functional.leaky_relu(x)
+                return nn.functional.leaky_relu(x, negative_slope=0.05)

         input_specs = [
             InputTensorSpec(
@@ -34,7 +34,7 @@ def forward(self, x):
     def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
         class TestModule(nn.Module):
             def forward(self, x):
-                return nn.functional.leaky_relu(x)
+                return nn.functional.leaky_relu(x, negative_slope=0.05)

         input_specs = [
             InputTensorSpec(
