Skip to content

chore: dynamic shape support for clamp/min/max/floor_div/logical_and #2977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,10 +733,10 @@ def aten_ops_where(
)


@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.clip.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor, supports_dynamic_shapes=True)
def aten_ops_clamp(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -1878,7 +1878,7 @@ def aten_ops_mul(
)


@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default, supports_dynamic_shapes=True)
def aten_ops_maximum(
ctx: ConversionContext,
target: Target,
Expand All @@ -1896,7 +1896,7 @@ def aten_ops_maximum(
)


@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default, supports_dynamic_shapes=True)
def aten_ops_minimum(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -2017,8 +2017,13 @@ def aten_ops_pow(
)


@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
@dynamo_tensorrt_converter(
torch.ops.aten.floor_divide.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.floor_divide.Scalar, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(operator.floordiv, supports_dynamic_shapes=True)
def aten_ops_floor_div(
ctx: ConversionContext,
target: Target,
Expand All @@ -2036,7 +2041,9 @@ def aten_ops_floor_div(
)


@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
@dynamo_tensorrt_converter(
torch.ops.aten.logical_and.default, supports_dynamic_shapes=True
)
def aten_ops_logical_and(
ctx: ConversionContext,
target: Target,
Expand Down
8 changes: 4 additions & 4 deletions tests/py/dynamo/conversion/test_clamp_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def forward(self, x):

input_specs = [
Input(
shape=(-1, -1, 3, 3),
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
min_shape=(1, 1, 3, 3),
opt_shape=(3, 3, 3, 3),
max_shape=(5, 5, 3, 3),
dtype=torch.float,
),
]

self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)

Expand Down
52 changes: 52 additions & 0 deletions tests/py/dynamo/conversion/test_floor_div_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,58 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
(
"2d_dim_dtype_half",
(1, 1),
(2, 2),
(4, 4),
torch.half,
torch.half,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_floor_div_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class floor_div(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.floor_divide.default(lhs_val, rhs_val)

class floor_div_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val // rhs_val

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
floor_div(), input_specs, output_dtypes=[output_type]
)
self.run_test_with_dynamic_shape(
floor_div_operator(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()
39 changes: 39 additions & 0 deletions tests/py/dynamo/conversion/test_logical_and_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,45 @@ def forward(self, lhs_val, rhs_val):
inputs,
)

@parameterized.expand(
[
(
"2d_dim_dtype_float",
(1, 1),
(2, 2),
(4, 4),
torch.float,
),
(
"3d_dim_dtype_bool",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.bool,
),
]
)
def test_logical_and_dynamic_shape(self, _, min_shape, opt_shape, max_shape, type):
class logical_and(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.logical_and.default(lhs_val, rhs_val)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(logical_and(), input_specs)


if __name__ == "__main__":
run_tests()
47 changes: 46 additions & 1 deletion tests/py/dynamo/conversion/test_maximum_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestMaximumConverter(DispatchTestCase):
def test_maximum(self, _, shape):
class Maximum(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.maximum(lhs_val, rhs_val)
return torch.ops.aten.maximum.default(lhs_val, rhs_val)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
Expand All @@ -26,6 +26,51 @@ def forward(self, lhs_val, rhs_val):
use_dynamo_tracer=True,
)

@parameterized.expand(
[
(
"2d_dim_dtype_half",
(1, 1),
(2, 2),
(4, 4),
torch.half,
torch.half,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_maximum_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class Maximum(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.maximum.default(lhs_val, rhs_val)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
Maximum(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()
47 changes: 46 additions & 1 deletion tests/py/dynamo/conversion/test_minimum_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestMinimumConverter(DispatchTestCase):
def test_minimum(self, _, shape):
class Minimum(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.minimum(lhs_val, rhs_val)
return torch.ops.aten.minimum.default(lhs_val, rhs_val)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
Expand All @@ -26,6 +26,51 @@ def forward(self, lhs_val, rhs_val):
use_dynamo_tracer=True,
)

@parameterized.expand(
[
(
"2d_dim_dtype_half",
(1, 1),
(2, 2),
(4, 4),
torch.half,
torch.half,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_minimum_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class Minimum(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.minimum.default(lhs_val, rhs_val)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
Minimum(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()
Loading
Loading