
Commit 088900d

feat: support aten.clamp.Tensor and update aten.clamp.default dynamo converters (#2522)
1 parent 128dd65 commit 088900d

File tree

5 files changed: +68 −107 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 19 deletions

@@ -499,25 +499,6 @@ def aten_ops_softplus(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
-def aten_ops_clip(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.clip(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        alpha=args_bounds_check(args, 1),
-        beta=args_bounds_check(args, 2),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
 def aten_ops_hard_sigmoid(
     ctx: ConversionContext,
@@ -695,6 +676,9 @@ def aten_ops_where(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
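
With this change the scalar and tensor overloads of clamp and clip all route to the single aten_ops_clamp converter. As a rough eager-mode illustration (plain PyTorch, no TensorRT involved) that the overloads compute the same values:

import torch

x = torch.randn(3, 4)

# Scalar bounds dispatch to aten.clamp.default / aten.clip.default ...
y_scalar = torch.ops.aten.clamp.default(x, -0.5, 0.5)
# ... while tensor bounds dispatch to aten.clamp.Tensor / aten.clip.Tensor.
y_tensor = torch.ops.aten.clamp.Tensor(
    x, torch.full_like(x, -0.5), torch.full_like(x, 0.5)
)
assert torch.equal(y_scalar, y_tensor)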

py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py

Lines changed: 0 additions & 30 deletions

@@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]
     )
 
 
-def clip(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_val: TRTTensor,
-    alpha: float,
-    beta: float,
-) -> TRTTensor:
-    operation_type = trt.ActivationType.CLIP
-
-    def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]:
-        def clip_fn(x: float) -> float:
-            return max(alpha, min(beta, x))
-
-        return clip_fn(dyn_range[0]), clip_fn(dyn_range[1])
-
-    return convert_activation(
-        ctx,
-        target,
-        source_ir,
-        name,
-        operation_type,
-        input_val,
-        alpha=alpha,
-        beta=beta,
-        dyn_range_fn=clip_dyn_range_fn,
-    )
-
-
 def hard_sigmoid(
     ctx: ConversionContext,
     target: Target,
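
The removed activation-based clip duplicated what the elementwise path already provides; its dynamic-range helper was literally max(alpha, min(beta, x)). A tiny standalone sketch of that identity (clip_fn as in the deleted code, shown here only for illustration):

def clip_fn(x: float, alpha: float, beta: float) -> float:
    # Clamping a scalar to [alpha, beta] is a max against the lower bound
    # followed by a min against the upper bound.
    return max(alpha, min(beta, x))

assert clip_fn(2.0, -1.0, 1.0) == 1.0
assert clip_fn(-2.0, -1.0, 1.0) == -1.0
assert clip_fn(0.25, -1.0, 1.0) == 0.25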

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 9 additions & 53 deletions

@@ -1,6 +1,5 @@
 from typing import Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -186,63 +184,21 @@ def clamp(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    min_val: Optional[float] = None,
-    max_val: Optional[float] = None,
+    min_val: Optional[Union[int, float, TRTTensor]] = None,
+    max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def _add_layer(
-        ctx: ConversionContext,
-        input: TRTTensor,
-        val: float,
-        op: trt.ElementWiseOperation,
-        name: str,
-    ) -> (
-        trt.ILayer
-    ):  # TODO: Simplify and merge implementations, should just be max and min stacked
-        if not len(input.shape):
-            # clamping scalar
-            acc_ops_clamp_trt = get_trt_tensor(
-                ctx,
-                squeeze_left(
-                    np.array(
-                        [val],
-                        dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-                    )
-                ),
-                f"{name}_clamp_{val}",
-            )
-        else:
-            acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-            acc_ops_clamp_tensor = np.full(
-                acc_ops_clamp_shape,
-                val,
-                dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-            )
-            acc_ops_clamp_trt = ctx.net.add_constant(
-                acc_ops_clamp_shape, acc_ops_clamp_tensor
-            ).get_output(0)
-        layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
-        return layer
-
+    clamped_val = input_val
     if min_val is not None:
-        clamp_min_layer = _add_layer(
-            ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
+        clamped_val = impl.elementwise.max(
+            ctx, target, source_ir, f"{name}_max", clamped_val, min_val
         )
-        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
-        input_val = clamp_min_layer.get_output(0)
+
     if max_val is not None:
-        clamp_max_layer = _add_layer(
-            ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
+        clamped_val = impl.elementwise.min(
+            ctx, target, source_ir, f"{name}_min", clamped_val, max_val
        )
-        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
-        input_val = clamp_max_layer.get_output(0)
 
-    return input_val
+    return clamped_val
 
 
 def add(
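
The rewritten clamp now reduces to stacked elementwise max/min operations, which is also what allows the bounds to be ITensors rather than only Python scalars. A rough eager-mode analogue of that structure (clamp_via_max_min is an illustrative helper, not part of the library):

import torch

def clamp_via_max_min(x, min_val=None, max_val=None):
    # Mirror of the converter's new structure: an elementwise max against the
    # lower bound, then an elementwise min against the upper bound. Either
    # bound may be a Python scalar or a broadcastable tensor.
    out = x
    if min_val is not None:
        out = torch.maximum(out, torch.as_tensor(min_val, dtype=x.dtype))
    if max_val is not None:
        out = torch.minimum(out, torch.as_tensor(max_val, dtype=x.dtype))
    return out

x = torch.randn(3, 4)
assert torch.equal(clamp_via_max_min(x, -0.5, 0.5), torch.clamp(x, -0.5, 0.5))
assert torch.equal(clamp_via_max_min(x, min_val=0.0), torch.clamp(x, min=0.0))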

tests/py/dynamo/conversion/test_clamp_aten.py

Lines changed: 25 additions & 1 deletion

@@ -49,7 +49,7 @@ def forward(self, x):
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
+                y = torch.ops.aten.mean.dim(x, None, True)
                 return torch.ops.aten.clamp.default(y, min, max)
 
         input_specs = [
@@ -63,6 +63,30 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
         self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
+            param("min", min=0.5 * torch.randn(3, 4)),
+            param("max", max=0.5 * torch.randn(3, 4)),
+            param(
+                "minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
+            ),
+            param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
+        ]
+    )
+    def test_clamp_tensor(
+        self,
+        test_name,
+        min=None,
+        max=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.clamp.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4)]
+        self.run_test(TestModule(), inputs)
+
 
 if __name__ == "__main__":
     run_tests()
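
The new test_clamp_tensor cases exercise tensor-valued bounds, including a min that exceeds max. The eager-mode behavior these tests compare the converter against can be sketched as follows (plain PyTorch, for illustration only):

import torch

x = torch.randn(3, 4)
lo = torch.randn(3, 4)
hi = torch.randn(3, 4)

# Eager-mode reference the converter has to match, including the
# "minBiggerThanMax" parameterization: wherever lo > hi the result is hi,
# because the max against lo is applied before the min against hi.
out = torch.ops.aten.clamp.Tensor(x, lo, hi)
assert torch.equal(out, torch.minimum(torch.maximum(x, lo), hi))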

tests/py/dynamo/conversion/test_clip_aten.py

Lines changed: 31 additions & 4 deletions

@@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase):
     def test_clip(self, test_name, min=None, max=None):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         inputs = [torch.randn(3, 4)]
         self.run_test(TestModule(), inputs)
 
+    @parameterized.expand(
+        [
+            param(
+                "defaultInt32",
+                min=torch.tensor(-1, dtype=torch.int32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+            param(
+                "defaultFloat32",
+                min=torch.tensor(0.5, dtype=torch.float32),
+                max=torch.tensor(1.0, dtype=torch.float32),
+            ),
+            param(
+                "minBiggerThanMax",
+                min=torch.tensor(1.0, dtype=torch.float32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_clip(self, test_name, min=None, max=None):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, min, max):
+                return torch.ops.aten.clip.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4), min, max]
+        self.run_test(TestModule(), inputs)
+
     @parameterized.expand(
         [
             param("default", min=-1, max=0),
@@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions(
     ):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
-                return torch.ops.aten.clamp.default(y, min, max)
+                y = torch.ops.aten.mean.dim(x, None, True)
+                return torch.ops.aten.clip.default(y, min, max)
 
         input_specs = [
             Input(
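
The clip tests mirror the clamp ones because clip is an alias of clamp, which is why a single converter can back all four registered schemas. A minimal eager-mode check of that equivalence (illustrative only):

import torch

x = torch.randn(3, 4)
lo = torch.tensor(-0.5)
hi = torch.tensor(0.5)

# clip dispatches to clamp, so the Tensor overloads of both ops must agree.
assert torch.equal(
    torch.ops.aten.clip.Tensor(x, lo, hi),
    torch.ops.aten.clamp.Tensor(x, lo, hi),
)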
