
Commit 959f3e5

Merge branch 'main' into min_cpp_build
2 parents: 22faab8 + b76024d

File tree: 14 files changed (+1227, -5 lines)


docker/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
 RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /"
 RUN apt-get update

-RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.*
+RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-plugin-dev=${TENSORRT_VERSION}.*

 # Setup Bazel via Bazelisk
 RUN wget -q https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 -O /usr/bin/bazel &&\

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 138 additions & 1 deletion
@@ -1136,6 +1136,23 @@ def aten_ops_exp(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.expm1.default)
+def aten_ops_expm1(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.expm1(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.log.default)
 def aten_ops_log(
     ctx: ConversionContext,
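
Note (not part of the diff): the new aten_ops_expm1 converter above is exercised whenever the Dynamo frontend lowers torch.expm1 to torch.ops.aten.expm1.default. A minimal parity-check sketch, assuming a CUDA build of Torch-TensorRT; the toy module, tolerances, and min_block_size=1 (to force the single op into TensorRT) are illustrative choices, not part of this commit:

import torch
import torch_tensorrt

class Expm1Module(torch.nn.Module):  # hypothetical toy model
    def forward(self, x):
        return torch.expm1(x)

model = Expm1Module().eval().cuda()
inputs = [torch.randn(2, 3, device="cuda")]
# Compile through the Dynamo path so aten.expm1.default reaches the new converter
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
assert torch.allclose(model(*inputs), trt_model(*inputs), atol=1e-5, rtol=1e-5)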
@@ -1391,6 +1408,30 @@ def aten_ops_atanh(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.atan2.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        1: (TRTTensor,),
+    }
+)
+def aten_ops_atan2(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.atan2(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
 def aten_ops_ceil(
     ctx: ConversionContext,
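
Note (illustrative, not part of the diff): the quadrant correction that separates aten.atan2 from a plain atan(x/y) is easy to see in eager PyTorch; the converter above delegates exactly that logic to impl.elementwise.atan2, shown in the ops.py diff further down.

import math
import torch

x = torch.tensor([1.0, -1.0])   # dividend ("input")
y = torch.tensor([-1.0, -1.0])  # divisor ("other"), both points in the left half-plane
print(torch.atan2(x, y))        # tensor([ 2.3562, -2.3562])  -> +/- 3π/4
print(torch.atan(x / y))        # tensor([-0.7854,  0.7854])  -> +/- π/4, missing the ±π correction
assert math.isclose(torch.atan2(x, y)[0].item(), 3 * math.pi / 4, rel_tol=1e-6)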
@@ -1493,6 +1534,23 @@ def aten_ops_isinf(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
+def aten_ops_isnan(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.isnan(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
 def aten_ops_add(
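
Note (illustrative, not part of the diff): torch.isnan is True only for NaN values; infinities are handled by the neighbouring isinf converter.

import torch

x = torch.tensor([1.0, float("nan"), float("inf"), -float("inf")])
print(torch.isnan(x))  # tensor([False,  True, False, False])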
@@ -2185,7 +2243,12 @@ def aten_ops_avg_pool(


 @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
-def aten_ops_adaptive_avg_pool(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_pool1d(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2202,6 +2265,32 @@ def aten_ops_adaptive_avg_pool(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_poolNd(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.adaptive_avg_poolNd(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        output_size=args[1],
+    )
+
+
 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
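
Note (illustrative, not part of the diff): a useful reference point for the adaptive pooling converters above is that, when the output size evenly divides the input's spatial size, adaptive average pooling coincides with a fixed-kernel average pool; this is the kind of reduction impl.pool.adaptive_avg_poolNd has to express.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
adaptive = F.adaptive_avg_pool2d(x, output_size=4)
# For 8 -> 4, adaptive pooling reduces to avg_pool2d with kernel_size = stride = 8 // 4 = 2
fixed = F.avg_pool2d(x, kernel_size=2, stride=2)
assert torch.allclose(adaptive, fixed)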
@@ -2319,6 +2408,29 @@ def aten_ops_pixel_shuffle(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_pixel_unshuffle(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.pixel_unshuffle(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @enforce_tensor_types({0: (TRTTensor,)})
 @dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
 def aten_ops_argmax(
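
Note (illustrative, not part of the diff): pixel_unshuffle is the exact inverse of pixel_shuffle, a pure rearrangement of elements, which the converter above maps to impl.shuffle.pixel_unshuffle.

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 6, 6)
r = 2  # downscale factor
y = F.pixel_unshuffle(x, downscale_factor=r)  # (1, 16, 3, 3): channels * r**2, spatial / r
assert y.shape == (1, 16, 3, 3)
assert torch.equal(F.pixel_shuffle(y, upscale_factor=r), x)  # exact round trip, no arithmetic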
@@ -2782,3 +2894,28 @@ def aten_ops_roll(
         args[1],
         args_bounds_check(args, 2, []),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        2: (TRTTensor,),
+    }
+)
+def aten_ops_index_select(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.index_select(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
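
Note (illustrative, not part of the diff): the enforce_tensor_types entry for argument 2 above means the index tensor is expected to arrive as a TRTTensor; the eager semantics the converter has to reproduce are simply:

import torch

x = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([2, 0])
# Gather whole rows (dim=0) in the order given by idx
print(torch.index_select(x, 0, idx))
# tensor([[ 8.,  9., 10., 11.],
#         [ 0.,  1.,  2.,  3.]])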

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 177 additions & 1 deletion
@@ -9,13 +9,15 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_int_int_div_trt_tensor,
     cast_int_or_float_to_bool,
+    cast_trt_tensor,
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
-from torch_tensorrt.dynamo.conversion.impl.unary import sign
+from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
+from torch_tensorrt.fx.converters.converter_utils import broadcast
 from torch_tensorrt.fx.types import TRTTensor

 import tensorrt as trt
@@ -214,6 +216,180 @@ def remainder(
     return fmod2_value


+def atan2(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    other: TRTTensor,
+) -> TRTTensor:
+    """
+    Perform atan2 operation on Tensor, calculating the arctangent of the quotient of input tensors.
+    atan2(x,y) = atan(x/y) if y > 0,
+               = atan(x/y) + π if x ≥ 0 and y < 0,
+               = atan(x/y) - π if x < 0 and y < 0,
+               = π/2 if x > 0 and y = 0,
+               = -π/2 if x < 0 and y = 0,
+               = 0 if x = 0 and y = 0
+
+    Args:
+        ctx: ConversionContext.
+        target: node target
+        source_ir (SourceIR): Source IR calling the function.
+        name: namespace for the op
+        input: Tensor or constant representing the dividend.
+        other: Tensor or constant representing the divisor.
+
+    Returns:
+        A TensorRT tensor representing the result of the atan2 operation.
+    """
+    pi_value = 3.141592653589793
+    pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")
+
+    if isinstance(input, TRTTensor):
+        input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
+    if isinstance(other, TRTTensor):
+        other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
+
+    input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other")
+
+    # Calculate x_zero, y_zero (whether inputs are zero)
+    x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0)
+    y_zero = eq(ctx, target, source_ir, f"{name}_y_zero", other, 0)
+
+    # Get sign of inputs
+    x_positive = gt(ctx, target, source_ir, f"{name}_x_positive", input, 0)
+    x_zero_positive = ge(ctx, target, source_ir, f"{name}_x_zero_positive", input, 0)
+    x_negative = lt(ctx, target, source_ir, f"{name}_x_negative", input, 0)
+    y_positive = gt(ctx, target, source_ir, f"{name}_y_positive", other, 0)
+    y_negative = lt(ctx, target, source_ir, f"{name}_y_negative", other, 0)
+
+    # Calculate atan(x/y)
+    input_div_other = div(
+        ctx, target, source_ir, f"{name}_input_div_other", input, other
+    )
+    atan_val = atan(ctx, target, source_ir, f"{name}_atan", input_div_other)
+
+    # atan(x/y)+π if x≥0 and y<0,
+    atan_add_pi = add(
+        ctx, target, source_ir, f"{name}_atan_add_pi", atan_val, pi_tensor
+    )
+
+    # atan(x/y)-π if x<0 and y<0,
+    atan_sub_pi = sub(
+        ctx, target, source_ir, f"{name}_atan_sub_pi", atan_val, pi_tensor
+    )
+
+    # atan(x/y)+π if x≥0 and y<0,
+    atan_corrected = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_atan_corrected",
+        atan_add_pi,
+        atan_val,
+        logical_and(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_x_zero_positive_and_y_negative",
+            x_zero_positive,
+            y_negative,
+        ),
+    )
+
+    # atan(x/y)-π if x<0 and y<0,
+    atan_corrected_2 = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_atan_corrected_2",
+        atan_sub_pi,
+        atan_corrected,
+        logical_and(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_x_negative_and_y_negative",
+            x_negative,
+            y_negative,
+        ),
+    )
+
+    # atan(x/y) if y>0
+    atan_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_atan_output",
+        atan_val,
+        atan_corrected_2,
+        y_positive,
+    )
+
+    # on x or y-axis
+    pi_over_2_tensor = get_trt_tensor(
+        ctx,
+        (pi_value / 2) * np.ones(input.shape, dtype=np.float32),
+        f"{name}_pi_over_2_tensor",
+        dtype=trt.float32,
+    )
+    minus_pi_over_2_tensor = get_trt_tensor(
+        ctx,
+        (-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
+        f"{name}_minus_pi_over_2_tensor",
+        dtype=trt.float32,
+    )
+    zero_tensor = get_trt_tensor(
+        ctx,
+        np.zeros(input.shape, dtype=np.float32),
+        f"{name}_zero_tensor",
+        dtype=trt.float32,
+    )
+
+    # π/2 if x>0 and y=0,
+    pi_over_2_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pi_over_2_output",
+        pi_over_2_tensor,
+        atan_output,
+        logical_and(
+            ctx, target, source_ir, f"{name}_x_zero_and_y_positive", x_positive, y_zero
+        ),
+    )
+
+    # -π/2 if x<0 and y=0,
+    minus_pi_over_2_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_minus_pi_over_2_output",
+        minus_pi_over_2_tensor,
+        pi_over_2_output,
+        logical_and(
+            ctx, target, source_ir, f"{name}_x_zero_and_y_negative", x_negative, y_zero
+        ),
+    )
+
+    # 0 if x=0 and y=0,
+    zero_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_zero_output",
+        zero_tensor,
+        minus_pi_over_2_output,
+        logical_and(
+            ctx, target, source_ir, f"{name}_x_zero_and_y_zero", y_zero, x_zero
+        ),
+    )
+
+    return zero_output
+
+
 def clamp(
     ctx: ConversionContext,
     target: Target,
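
Note (illustrative, not part of the diff): the piecewise definition implemented above (with input as the dividend x and other as the divisor y, matching the docstring) can be sanity-checked against torch.atan2 on one point from every branch:

import math
import torch

def atan2_piecewise(x, y):
    # Mirrors the branch structure in the docstring of the atan2 implementation above
    if y > 0:
        return math.atan(x / y)
    if y < 0:
        return math.atan(x / y) + math.pi if x >= 0 else math.atan(x / y) - math.pi
    if x > 0:
        return math.pi / 2
    if x < 0:
        return -math.pi / 2
    return 0.0

for x, y in [(1.0, 2.0), (1.0, -2.0), (-1.0, -2.0), (3.0, 0.0), (-3.0, 0.0), (0.0, 0.0)]:
    ref = torch.atan2(torch.tensor(x), torch.tensor(y)).item()
    assert math.isclose(atan2_piecewise(x, y), ref, abs_tol=1e-5)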
