add the sym_not / full operator to support dynamic shape #3013

Merged: 4 commits, merged on Jul 17, 2024
35 changes: 35 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1759,6 +1759,23 @@ def aten_ops_logical_not(
)


@dynamo_tensorrt_converter(torch.sym_not, supports_dynamic_shapes=True)
def aten_ops_sym_not(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.sym_not(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.sign.default, supports_dynamic_shapes=True)
def aten_ops_sign(
ctx: ConversionContext,
@@ -3456,3 +3473,21 @@ def aten_ops_arange_start_step(
end=args[1],
step=args_bounds_check(args, 2, 1),
)


@dynamo_tensorrt_converter(torch.ops.aten.full.default, supports_dynamic_shapes=True)
def aten_ops_full(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.full.full(
ctx,
target,
SourceIR.ATEN,
name,
shape=args[0],
fill_value=args[1],
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -12,6 +12,7 @@
deconv,
elementwise,
embedding,
full,
grid,
linear,
matmul,
60 changes: 60 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/full.py
@@ -0,0 +1,60 @@
from typing import List, Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor


def full(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
shape: Union[List[int], TRTTensor],
fill_value: Union[int, float, bool],
) -> TRTTensor:
# In the static-shape scenario, shape is a list of ints
if isinstance(shape, List):
return np.full(shape, fill_value)

# In the dynamic-shape scenario, shape is a shape tensor
# use IFillLayer in LINSPACE mode to build a tensor of that runtime shape
layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype)
layer.set_input(0, shape)
layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0))
delta = get_trt_tensor(ctx, 1, name + "_delta")
input = []
for _ in range(shape.shape[0]):
input.append(delta)
delta = impl.cat.cat(ctx, target, source_ir, name + "_cat", input, dim=0)
layer.set_input(2, delta)
output = layer.get_output(0)

# zero out the LINSPACE values, then combine the result with the actual fill_value
output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0)
if isinstance(fill_value, (int, float)):
if isinstance(fill_value, float):
output = cast_trt_tensor(
ctx, output, trt.float32, name + "_casted", target, source_ir
)
output = impl.elementwise.add(
ctx, target, source_ir, name + "_add", output, fill_value
)

if isinstance(fill_value, bool):
output = cast_trt_tensor(
ctx, output, trt.bool, name + "_casted", target, source_ir
)
output = impl.elementwise.logical_or(
ctx, target, source_ir, name + "_add", output, fill_value
)

return output
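
The dynamic-shape branch above is easiest to read as three steps: IFillLayer is used only to obtain a tensor of the runtime shape, its LINSPACE values are erased by multiplying with 0, and the requested fill_value is then added (or logical-or'd) in. A minimal eager-mode sketch of the same computation, assuming a 1-D shape tensor; the helper name and dtypes below are illustrative, not taken from the PR:

import torch

def full_dynamic_sketch(shape_tensor: torch.Tensor, fill_value):
    # Stand-in for "IFillLayer output * 0": a tensor of the runtime shape, all zeros.
    shape = tuple(int(d) for d in shape_tensor)
    base = torch.zeros(shape, dtype=torch.int64)
    if isinstance(fill_value, bool):
        return base.bool() | fill_value       # mirrors the bool cast + logical_or branch
    if isinstance(fill_value, float):
        return base.float() + fill_value      # mirrors the float32 cast + add branch
    return base + fill_value                  # int fill_value: plain add

print(full_dynamic_sketch(torch.tensor([2, 3]), 0.5))  # 2x3 tensor filled with 0.5
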
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -418,6 +418,47 @@ def logical_not(
)


# Note: the output of this sym_not is slightly different from torch.sym_not
# torch.sym_not always returns a scalar value: torch.sym_not(torch.tensor([True])) ---> False
# our sym_not cannot return a scalar value; it always returns a tensor: sym_not(torch.tensor([True])) ---> torch.tensor(False)
def sym_not(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: Union[TRTTensor, bool, torch.SymBool, torch.Tensor],
) -> TRTTensor:
# TODO: it is unclear when the torch.SymBool case arises; support will be added in the future
if isinstance(input_val, torch.SymBool):
raise NotImplementedError(
"Torch-TensorRT support for sym_not operator when type is torch.SymBool is not available, Need to Implement"
)
elif isinstance(input_val, (TRTTensor, torch.Tensor)):
if input_val.dtype != trt.bool and input_val.dtype != torch.bool:
raise RuntimeError(
f"Only Boolean value of ITensor/Tensor is allowed for sym_not, got {input_val.dtype=}"
)
# torch.sym_not only accepts a Tensor holding a single Boolean value; otherwise PyTorch throws the following error
# RuntimeError: Boolean value of Tensor with more than one value is ambiguous
rank = len(input_val.shape)
if rank >= 1:
for index in range(rank):
dim = input_val.shape[index]
if dim != 1:
raise RuntimeError(
f"Boolean value of Tensor with more than one value is not allowed for sym_not, got input_val.shape[{index}]={input_val.shape[index]}"
)
input_val = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshaped", input_val, (1,)
)
elif isinstance(input_val, bool):
input_val = get_trt_tensor(ctx, input_val, name + "_casted", dtype=trt.bool)

return convert_unary(
ctx, target, source_ir, name, trt.UnaryOperation.NOT, input_val
)


def bitwise_not(
ctx: ConversionContext,
target: Target,
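
To make the note at the top of sym_not above concrete: eager torch.sym_not collapses a one-element Boolean tensor into a plain Python bool, while the converter keeps the result in tensor form. A small sketch (torch.logical_not stands in here for the converted graph's output; it is not the converter itself):

import torch

flag = torch.tensor([True])

eager_result = torch.sym_not(flag)       # plain Python bool: False
print(type(eager_result), eager_result)

tensor_result = torch.logical_not(flag)  # tensor form, like the converter: tensor([False])
print(type(tensor_result), tensor_result)
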
5 changes: 4 additions & 1 deletion tests/py/dynamo/conversion/harness.py
@@ -110,7 +110,10 @@ def run_test(
ref_outputs = [ref_outputs]
for out, ref in zip(outputs, ref_outputs):
if not isinstance(ref, torch.Tensor):
-ref = torch.tensor([ref])
+if len(out.shape) == 0:
+    ref = torch.tensor(ref)
+else:
+    ref = torch.tensor([ref])
ref = ref.cpu() # to_dtype test has cases with gpu output
torch.testing.assert_close(
out.cpu(),
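
For context on the harness change above: outputs such as the sym_not result can be 0-dimensional, while wrapping a scalar reference as torch.tensor([ref]) always yields a 1-dimensional tensor, and torch.testing.assert_close rejects the resulting shape mismatch. A small sketch of the rank-matched wrapping, with illustrative values:

import torch

out = torch.tensor(False)   # 0-dim output, as produced for a 0-dim sym_not input
ref = False                 # scalar reference, e.g. from eager torch.sym_not
ref = torch.tensor(ref) if out.dim() == 0 else torch.tensor([ref])
torch.testing.assert_close(out.cpu(), ref.cpu())  # shapes now agree
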
60 changes: 60 additions & 0 deletions tests/py/dynamo/conversion/test_full_aten.py
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestFullConverter(DispatchTestCase):
@parameterized.expand(
[
((5,), 2),
((5, 3), 0.1),
((5, 3, 2), True),
]
)
def test_full_static(self, shape, fill_value):
class full(nn.Module):
def forward(self, x):
return torch.ops.aten.full.default(shape, fill_value)

inputs = [torch.randn(1, 1)]
self.run_test(
full(),
inputs,
)

@parameterized.expand(
[
((1,), (3,), (4,), [3], 11),
((3, 5), (3, 7), (3, 10), [3, 7], False),
((1, 5), (3, 7), (4, 10), [3, 7], True),
((1, 5, 3), (3, 7, 3), (4, 10, 4), [3, 7, 3], 0.11),
]
)
def test_full_dynamic(self, min_shape, opt_shape, max_shape, data, fill_value):
class full(nn.Module):
def forward(self, shape):
return torch.ops.aten.full.default(shape, fill_value)

inputs = [
torch_tensorrt.Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=torch.int64,
torch_tensor=torch.tensor(data, dtype=torch.int64).cuda(),
is_shape_tensor=True,
)
]
self.run_test_with_dynamic_shape(
full(),
inputs,
use_example_tensors=False,
)


if __name__ == "__main__":
run_tests()
34 changes: 34 additions & 0 deletions tests/py/dynamo/conversion/test_sym_not_aten.py
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestSymNotConverter(DispatchTestCase):

@parameterized.expand(
[
(torch.tensor(True),),
(torch.tensor(False),),
(torch.tensor([True]),),
(torch.tensor([[True]]),),
(torch.tensor([[False]]),),
]
)
def test_sym_not_bool(self, data):
class sym_not(nn.Module):
def forward(self, input):
return torch.sym_not(input)

inputs = [data]

self.run_test(
sym_not(),
inputs,
)


if __name__ == "__main__":
run_tests()