
Commit ca4b263

add the sym_not / full operator to support dynamic shape (#3013)
1 parent 0c25d92 commit ca4b263

File tree

7 files changed: 235 additions, 1 deletion


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 35 additions & 0 deletions
@@ -1759,6 +1759,23 @@ def aten_ops_logical_not(
     )


+@dynamo_tensorrt_converter(torch.sym_not, supports_dynamic_shapes=True)
+def aten_ops_sym_not(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.sym_not(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.sign.default, supports_dynamic_shapes=True)
 def aten_ops_sign(
     ctx: ConversionContext,
@@ -3456,3 +3473,21 @@ def aten_ops_arange_start_step(
         end=args[1],
         step=args_bounds_check(args, 2, 1),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.full.default, supports_dynamic_shapes=True)
+def aten_ops_full(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.full.full(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        shape=args[0],
+        fill_value=args[1],
+    )
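
For reference, a minimal eager-mode sketch (not part of the diff) of the two operations these registrations route to the new converters; the printed results are plain PyTorch semantics that the converters are expected to reproduce, with args[0] and args[1] of aten.full.default carrying the shape and the fill value:

import torch

# torch.ops.aten.full.default(size, fill_value): args[0] is the shape, args[1] the fill value
print(torch.ops.aten.full.default([2, 3], 1.5))   # 2x3 tensor filled with 1.5

# torch.sym_not: boolean negation that also works on symbolic bools during tracing
print(torch.sym_not(True))                        # False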

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
     deconv,
     elementwise,
     embedding,
+    full,
     grid,
     linear,
     matmul,

py/torch_tensorrt/dynamo/conversion/impl/full.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from typing import List, Optional, Union
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def full(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    shape: Union[List[int], TRTTensor],
+    fill_value: Union[int, float, bool],
+) -> TRTTensor:
+    # in the static shape scenario, shape is a list of ints
+    if isinstance(shape, List):
+        return np.full(shape, fill_value)
+
+    # in the dynamic shape scenario, shape is a shape tensor
+    # use IFillLayer to fill the shape tensor with LINSPACE values
+    layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype)
+    layer.set_input(0, shape)
+    layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0))
+    delta = get_trt_tensor(ctx, 1, name + "_delta")
+    input = []
+    for _ in range(shape.shape[0]):
+        input.append(delta)
+    delta = impl.cat.cat(ctx, target, source_ir, name + "_cat", input, dim=0)
+    layer.set_input(2, delta)
+    output = layer.get_output(0)
+
+    # fill the output tensor with the actual fill_value
+    output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0)
+    if isinstance(fill_value, (int, float)):
+        if isinstance(fill_value, float):
+            output = cast_trt_tensor(
+                ctx, output, trt.float32, name + "_casted", target, source_ir
+            )
+        output = impl.elementwise.add(
+            ctx, target, source_ir, name + "_add", output, fill_value
+        )
+
+    if isinstance(fill_value, bool):
+        output = cast_trt_tensor(
+            ctx, output, trt.bool, name + "_casted", target, source_ir
+        )
+        output = impl.elementwise.logical_or(
+            ctx, target, source_ir, name + "_add", output, fill_value
+        )
+
+    return output
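
The dynamic-shape branch above cannot know the output shape at build time, so it asks IFillLayer for a LINSPACE-filled tensor of the runtime shape, zeroes it out with an elementwise multiply by 0, and then applies the real fill value with an add (or a logical_or when the fill value is a bool). A minimal eager-mode sketch of that trick, using plain PyTorch as a stand-in for the TensorRT layers (the helper name full_like_trick is made up for illustration):

import torch

def full_like_trick(shape_tensor: torch.Tensor, fill_value):
    # stand-in for the IFillLayer LINSPACE fill: any tensor of the runtime shape works,
    # because its values are discarded by the multiply-by-zero step
    linspace = torch.arange(int(shape_tensor.prod())).reshape(shape_tensor.tolist())
    out = linspace * 0                      # a tensor of zeros with the runtime shape
    if isinstance(fill_value, bool):
        return out.bool() | fill_value      # logical_or path for bool fills
    if isinstance(fill_value, float):
        out = out.float()                   # cast before adding a float fill value
    return out + fill_value                 # add path for int/float fills

print(full_like_trick(torch.tensor([2, 3]), 0.5))   # 2x3 tensor of 0.5
print(full_like_trick(torch.tensor([2, 2]), True))  # 2x2 tensor of True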

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 41 additions & 0 deletions
@@ -418,6 +418,47 @@ def logical_not(
     )


+# Note: the output of this sym_not is slightly different from torch.sym_not
+# torch.sym_not always returns a scalar value: torch.sym_not(torch.tensor([True])) ---> False
+# our sym_not cannot return a scalar value; it always returns a tensor: sym_not(torch.tensor([True])) ---> torch.tensor(False)
+def sym_not(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: Union[TRTTensor, bool, torch.SymBool, torch.Tensor],
+) -> TRTTensor:
+    # TODO: it is unclear when the torch.SymBool case arises; support will be added in the future
+    if isinstance(input_val, torch.SymBool):
+        raise NotImplementedError(
+            "Torch-TensorRT support for the sym_not operator when the type is torch.SymBool is not available, need to implement"
+        )
+    elif isinstance(input_val, (TRTTensor, torch.Tensor)):
+        if input_val.dtype != trt.bool and input_val.dtype != torch.bool:
+            raise RuntimeError(
+                f"Only Boolean value of ITensor/Tensor is allowed for sym_not, got {input_val.dtype=}"
+            )
+        # torch.sym_not only allows a single Boolean value in the Tensor, otherwise PyTorch throws:
+        # RuntimeError: Boolean value of Tensor with more than one value is ambiguous
+        rank = len(input_val.shape)
+        if rank >= 1:
+            for index in range(rank):
+                dim = input_val.shape[index]
+                if dim != 1:
+                    raise RuntimeError(
+                        f"Boolean value of Tensor with more than one value is not allowed for sym_not, got input_val.shape[{index}]={input_val.shape[index]}"
+                    )
+            input_val = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshaped", input_val, (1,)
+            )
+    elif isinstance(input_val, bool):
+        input_val = get_trt_tensor(ctx, input_val, name + "_casted", dtype=trt.bool)
+
+    return convert_unary(
+        ctx, target, source_ir, name, trt.UnaryOperation.NOT, input_val
+    )
+
+
 def bitwise_not(
     ctx: ConversionContext,
     target: Target,
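
The note at the top of sym_not can be checked directly in eager mode; the converter differs only in that it always materializes the result as a one-element tensor rather than a Python bool (a small sketch, not part of the diff):

import torch

print(torch.sym_not(True))                    # False
print(torch.sym_not(torch.tensor([True])))    # False: eager collapses the 1-element tensor to a Python bool
try:
    torch.sym_not(torch.tensor([True, False]))
except RuntimeError as err:
    # the ambiguity error quoted in the comment above; the converter guards against
    # this case by rejecting any dimension larger than 1
    print(err)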

tests/py/dynamo/conversion/harness.py

Lines changed: 4 additions & 1 deletion
@@ -110,7 +110,10 @@ def run_test(
                 ref_outputs = [ref_outputs]
             for out, ref in zip(outputs, ref_outputs):
                 if not isinstance(ref, torch.Tensor):
-                    ref = torch.tensor([ref])
+                    if len(out.shape) == 0:
+                        ref = torch.tensor(ref)
+                    else:
+                        ref = torch.tensor([ref])
                 ref = ref.cpu()  # to_dtype test has cases with gpu output
                 torch.testing.assert_close(
                     out.cpu(),
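
The harness change matters because wrapping a scalar reference in a list always produces a 1-D tensor, which no longer lines up with a 0-D (scalar) converter output such as the one sym_not can now emit; a quick illustration:

import torch

print(torch.tensor(2.0).shape)    # torch.Size([])  -> 0-D, matches a scalar output
print(torch.tensor([2.0]).shape)  # torch.Size([1]) -> 1-D, matches a one-element output
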
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestFullConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((5,), 2),
+            ((5, 3), 0.1),
+            ((5, 3, 2), True),
+        ]
+    )
+    def test_full_static(self, shape, fill_value):
+        class full(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.full.default(shape, fill_value)
+
+        inputs = [torch.randn(1, 1)]
+        self.run_test(
+            full(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((1,), (3,), (4,), [3], 11),
+            ((3, 5), (3, 7), (3, 10), [3, 7], False),
+            ((1, 5), (3, 7), (4, 10), [3, 7], True),
+            ((1, 5, 3), (3, 7, 3), (4, 10, 4), [3, 7, 3], 0.11),
+        ]
+    )
+    def test_full_dynamic(self, min_shape, opt_shape, max_shape, data, fill_value):
+        class full(nn.Module):
+            def forward(self, shape):
+                return torch.ops.aten.full.default(shape, fill_value)
+
+        inputs = [
+            torch_tensorrt.Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.int64,
+                torch_tensor=torch.tensor(data, dtype=torch.int64).cuda(),
+                is_shape_tensor=True,
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            full(),
+            inputs,
+            use_example_tensors=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSymNotConverter(DispatchTestCase):
+
+    @parameterized.expand(
+        [
+            (torch.tensor(True),),
+            (torch.tensor(False),),
+            (torch.tensor([True]),),
+            (torch.tensor([[True]]),),
+            (torch.tensor([[False]]),),
+        ]
+    )
+    def test_sym_not_bool(self, data):
+        class sym_not(nn.Module):
+            def forward(self, input):
+                return torch.sym_not(input)
+
+        inputs = [data]
+
+        self.run_test(
+            sym_not(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
