feat: dynamic shape support for pow/mod/eq operator #2982

Merged 3 commits on Jul 9, 2024

11 changes: 9 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2002,6 +2002,7 @@ def aten_ops_div(
@dynamo_tensorrt_converter(
torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True)
Collaborator

From the PR (#2918), it looks like dynamic-shape testing and registration for operator.mul were added because PyTorch internally uses operator.mul when torch.nn.Linear layers are involved. That might be why operator.mul was registered with the converter.

If I'm misunderstanding, could you please explain? (cc @peri044)

That said, why does this PR need to register operator.pow, operator.eq, and operator.mod with the converter?

Collaborator


It seems operator.* overrides the original Python ops, so running any such Python op creates a TRT layer to handle it. However, these ops are not listed in the schema. I'm also curious whether converters for them are needed.

Collaborator (Author)


Thanks @chohk88 / @zewenli98.
If operator.pow is not registered and ** is used in a module, TRTInterpreter raises an exception:

E torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function _operator.pow not currently supported!

I saw this kind of usage in the openchat model, so I think we need to register these operators to support it:
https://github.com/imoneoi/openchat/blob/master/ochat/models/unpadded_llama.py#L105
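
To make the failure mode concrete, below is a minimal, hedged sketch. PowModule is a made-up module (not the openchat code), and classic torch.fx symbolic tracing is used only because it is the shortest way to show ** being recorded as a call_function node targeting _operator.pow; the dynamo/export path used by Torch-TensorRT can likewise end up with operator.* targets, for example from arithmetic on symbolic shapes.

import operator

import torch.fx as fx
import torch.nn as nn


class PowModule(nn.Module):
    def forward(self, x, y):
        return x ** y  # Python-level pow, not an explicit aten call


gm = fx.symbolic_trace(PowModule())
for node in gm.graph.nodes:
    print(node.op, node.target)

# The call_function node's target prints as <built-in function pow>, i.e.
# operator.pow; without a converter registered for that exact target,
# conversion fails with the UnsupportedOperatorException quoted above.
print(any(n.op == "call_function" and n.target is operator.pow for n in gm.graph.nodes))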

Collaborator


Thanks for the example. In that case, it does seem necessary.
cc: @narendasan @dheerajperi

Collaborator


Yes, the openchat model uses operator.pow, and I think there is at least one other model that does as well. We need these registered as converters, similar to the other operator.* variants we already have.
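
To illustrate the dispatch side of that requirement, here is a simplified, self-contained sketch of a registry keyed by call_function target (the CONVERTERS dict, register helper, and convert_pow body are invented stand-ins, not the real Torch-TensorRT registry behind @dynamo_tensorrt_converter): a node whose target is _operator.pow can only be converted if that exact target has an entry, even when the aten.pow overloads are already registered.

import operator
from typing import Any, Callable, Dict

import torch

# Toy registry: maps a call_function target to its converter.
CONVERTERS: Dict[Any, Callable[..., str]] = {}


def register(target: Any) -> Callable[[Callable[..., str]], Callable[..., str]]:
    def wrap(fn: Callable[..., str]) -> Callable[..., str]:
        CONVERTERS[target] = fn
        return fn

    return wrap


@register(torch.ops.aten.pow.Tensor_Tensor)
@register(torch.ops.aten.pow.Tensor_Scalar)
@register(operator.pow)  # the extra entry this PR adds for Python-level **
def convert_pow(*args: Any) -> str:
    return "elementwise POW layer"  # placeholder for the real TRT layer


for target in (torch.ops.aten.pow.Tensor_Tensor, operator.pow):
    converter = CONVERTERS.get(target)
    print(target, "->", "unsupported" if converter is None else converter())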

Collaborator


Thanks for the example. LGTM!

def aten_ops_pow(
ctx: ConversionContext,
target: Target,
@@ -2278,6 +2279,7 @@ def aten_ops_bitwise_not(

@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(operator.eq, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -3149,8 +3151,13 @@ def aten_ops_copy(
)


@dynamo_tensorrt_converter(torch.ops.aten.remainder.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.remainder.Tensor)
@dynamo_tensorrt_converter(
torch.ops.aten.remainder.Scalar, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.remainder.Tensor, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
48 changes: 48 additions & 0 deletions tests/py/dynamo/conversion/test_eq_aten.py
@@ -146,6 +146,54 @@ def forward(self, lhs_val):
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_eq_operator_dynamic_shape(self, min_shape, opt_shape, max_shape):
class eq_tensor_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val == rhs_val

class eq_tensor_scalar_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val == torch.tensor(1)

class eq_scalar_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val == 1.0

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
eq_tensor_operator(),
input_specs,
)
self.run_test_with_dynamic_shape(
eq_tensor_scalar_operator(),
input_specs,
)
self.run_test_with_dynamic_shape(
eq_scalar_operator(),
input_specs,
)


if __name__ == "__main__":
run_tests()
59 changes: 59 additions & 0 deletions tests/py/dynamo/conversion/test_pow_aten.py
@@ -59,6 +59,65 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
(
"2d_dim_dtype_half",
(1, 1),
(2, 2),
(4, 4),
torch.half,
torch.half,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_pow_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class pow(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.pow.Tensor_Tensor(lhs_val, rhs_val)

class pow_scalar(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.pow.Tensor_Scalar(lhs_val, 2.0)

class pow_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val**rhs_val

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
pow(), input_specs, output_dtypes=[output_type]
)
self.run_test_with_dynamic_shape(
pow_scalar(), input_specs, output_dtypes=[output_type]
)
self.run_test_with_dynamic_shape(
pow_operator(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()
59 changes: 59 additions & 0 deletions tests/py/dynamo/conversion/test_remainder_aten.py
@@ -55,6 +55,65 @@ def forward(self, lhs_val, rhs_val):
inputs,
)

@parameterized.expand(
[
(
"2d_dim_dtype_half",
(1, 1),
(2, 2),
(4, 4),
torch.half,
torch.half,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_remainder_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class remainder(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.remainder.Tensor(lhs_val, rhs_val)

class remainder_scalar(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.remainder.Scalar(lhs_val, 2)

class mod_operator(nn.Module):
def forward(self, lhs_val, rhs_val):
return lhs_val % rhs_val

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
remainder(), input_specs, output_dtypes=[output_type]
)
self.run_test_with_dynamic_shape(
remainder_scalar(), input_specs, output_dtypes=[output_type]
)
self.run_test_with_dynamic_shape(
mod_operator(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()