Commit b6ed1c5

chore: Adapt CIA ops decomposition handling in upsample converters to torch 2.6 (#3227)
1 parent 40193dc commit b6ed1c5
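
For context, a minimal usage sketch (not part of the commit) of the path this change targets: in torch 2.6 the upsample ops are CompositeImplicitAutograd (CIA) ops, and F.interpolate is expected to appear in the exported graph as the aten.upsample_*.vec overloads, which the converters touched below now handle directly. The model and shapes here are illustrative only.

```python
# Hedged sketch: a tiny model that exercises the upsample converters changed in
# this commit. Assumes a CUDA device and a torch_tensorrt install; shapes are
# arbitrary.
import torch
import torch.nn.functional as F
import torch_tensorrt


class Upsample2x(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expected to lower to torch.ops.aten.upsample_bilinear2d.vec under export.
        return F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)


model = Upsample2x().eval().cuda()
inputs = [torch.randn(1, 3, 32, 32, device="cuda")]

# ir="dynamo" routes through the dynamo converters registered in this commit.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # expected: torch.Size([1, 3, 64, 64])
```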

5 files changed: +117 −371 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 277 deletions
@@ -3110,232 +3110,21 @@ def aten_ops_pad(
     )
 
 
-for op in (
-    torch.ops.aten.upsample_nearest1d,
-    torch.ops.aten.upsample_nearest2d,
-    torch.ops.aten.upsample_nearest3d,
-    torch.ops.aten.upsample_linear1d,
-    torch.ops.aten.upsample_bilinear2d,
-    torch.ops.aten.upsample_trilinear3d,
-    torch.ops.aten.upsample_bicubic2d,
-):
-    for key in (
-        torch._C.DispatchKey.Autograd,
-        torch._C.DispatchKey.CompositeImplicitAutograd,
-    ):
-        if key in op.default.py_kernels:
-            del op.default.py_kernels[key]
-        if key in op.vec.py_kernels:
-            del op.vec.py_kernels[key]
-
-
-def upsample_compute_output_size(
-    input_size: torch.Size,
-    output_size: Optional[Sequence[int]],
-    scale_factors: Optional[Sequence[float]],
-) -> Optional[Sequence[int]]:
-    spatial_dimensions = len(input_size) - 2
-
-    if output_size is None and scale_factors is None:
-        raise AssertionError(
-            "Must specify exactly one of output_size and scale_factors"
-        )
-
-    if output_size is not None:
-        torch._check(
-            scale_factors is None,
-            lambda: "Must specify exactly one of output_size and scale_factors",
-        )
-        torch._check(len(output_size) == spatial_dimensions)
-
-    if scale_factors is not None:
-        torch._check(
-            output_size is None,
-            lambda: "Must specify exactly one of output_size and scale_factors",
-        )
-        torch._check(len(scale_factors) == spatial_dimensions)
-        output_size = []
-        for i, s in enumerate(scale_factors):
-            output_size.append(int(input_size[i + 2] * s))
-
-    return output_size
-
-
-@torch.ops.aten.upsample_nearest1d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_nearest1d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_nearest1d.default(input, osize, *scale_factors)
-    return torch.ops.aten.upsample_nearest1d.default(input, osize)
-
-
-@torch.ops.aten.upsample_nearest2d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_nearest2d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_nearest2d.default(input, osize, *scale_factors)
-    return torch.ops.aten.upsample_nearest2d.default(input, osize)
-
-
-@torch.ops.aten.upsample_nearest3d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_nearest3d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_nearest3d.default(input, osize, *scale_factors)
-    return torch.ops.aten.upsample_nearest3d.default(input, osize)
-
-
-@torch.ops.aten.upsample_linear1d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_linear1d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    align_corners: bool,
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_linear1d.default(
-            input, osize, align_corners, *scale_factors
-        )
-    return torch.ops.aten.upsample_linear1d.default(input, osize, align_corners)
-
-
-@torch.ops.aten.upsample_bilinear2d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_bilinear2d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    align_corners: bool,
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_bilinear2d.default(
-            input, osize, align_corners, *scale_factors
-        )
-    return torch.ops.aten.upsample_bilinear2d.default(input, osize, align_corners)
-
-
-@torch.ops.aten.upsample_trilinear3d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_trilinear3d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    align_corners: bool,
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_trilinear3d.default(
-            input, osize, align_corners, *scale_factors
-        )
-    return torch.ops.aten.upsample_trilinear3d.default(input, osize, align_corners)
-
-
-@torch.ops.aten.upsample_bicubic2d.vec.py_impl(
-    torch._C.DispatchKey.CompositeImplicitAutograd
-)
-def upsample_bicubic2d_vec(
-    input: torch.Tensor,
-    output_size: Optional[Sequence[int]],
-    align_corners: bool,
-    scale_factors: Optional[Sequence[float]],
-) -> torch.Tensor:
-    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
-    if scale_factors is not None:
-        return torch.ops.aten.upsample_bicubic2d.default(
-            input, osize, align_corners, *scale_factors
-        )
-    return torch.ops.aten.upsample_bicubic2d.default(input, osize, align_corners)
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_nearest1d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_nearest1d.vec, supports_dynamic_shapes=True
 )
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_upsample_nearest1d(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.upsample.upsample(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 3 else [args[2]],
-        mode="nearest",
-        align_corners=False,
-    )
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_nearest2d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_nearest2d.vec, supports_dynamic_shapes=True
 )
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_upsample_nearest2d(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.upsample.upsample(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 4 else [args[2], args[3]],
-        mode="nearest",
-        align_corners=False,
-    )
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_nearest3d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_nearest3d.vec, supports_dynamic_shapes=True
 )
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
 )
-def aten_ops_upsample_nearest3d(
+def aten_ops_upsample_nearest(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -3348,78 +3137,28 @@ def aten_ops_upsample_nearest3d(
         SourceIR.ATEN,
         name,
         args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 5 else [args[2], args[3], args[4]],
+        size=args_bounds_check(args, 1),
+        scale_factor=args_bounds_check(args, 2),
         mode="nearest",
         align_corners=False,
     )
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_linear1d.default, supports_dynamic_shapes=True
-)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
+    torch.ops.aten.upsample_linear1d.vec, supports_dynamic_shapes=True
 )
-def aten_ops_upsample_linear1d(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.upsample.upsample(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 4 else [args[3]],
-        mode="linear",
-        align_corners=args[2],
-    )
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_bilinear2d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_bilinear2d.vec, supports_dynamic_shapes=True
 )
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_upsample_bilinear2d(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.upsample.upsample(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 5 else [args[3], args[4]],
-        mode="bilinear",
-        align_corners=args[2],
-    )
-
-
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_trilinear3d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_trilinear3d.vec, supports_dynamic_shapes=True
 )
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
 )
-def aten_ops_upsample_trilinear3d(
+def aten_ops_upsample_linear(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -3432,15 +3171,15 @@ def aten_ops_upsample_trilinear3d(
         SourceIR.ATEN,
         name,
         args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 6 else [args[3], args[4], args[5]],
-        mode="trilinear",
+        size=args_bounds_check(args, 1),
+        scale_factor=args_bounds_check(args, 3),
+        mode="linear",
         align_corners=args[2],
     )
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.upsample_bicubic2d.default, supports_dynamic_shapes=True
+    torch.ops.aten.upsample_bicubic2d.vec, supports_dynamic_shapes=True
 )
 @enforce_tensor_types(
     {
@@ -3460,8 +3199,8 @@ def aten_ops_upsample_bicubic2d(
         SourceIR.ATEN,
         name,
         args[0],
-        size=args[1],
-        scale_factor=None if len(args) < 5 else [args[3], args[4]],
+        size=args_bounds_check(args, 1),
+        scale_factor=args_bounds_check(args, 3),
         mode="bicubic",
         align_corners=args[2],
     )
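
A hedged note on the argument handling above: the .vec overloads carry an optional output_size and an optional scale_factors list (plus align_corners for the linear/cubic modes), so the converters read both through args_bounds_check instead of positional length checks. The helper below is a simplified stand-in (the name args_bounds_check_sketch and the example argument tuple are hypothetical; the real helper lives in torch_tensorrt's converter utilities), shown only to illustrate the behavior the new code relies on.

```python
# Illustrative stand-in for args_bounds_check; schema of the .vec overload
# paraphrased, e.g.:
#   upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size,
#                           bool align_corners, float[]? scale_factors)
from typing import Any, Optional, Tuple


def args_bounds_check_sketch(
    args: Tuple[Any, ...], idx: int, replacement: Optional[Any] = None
) -> Any:
    """Return args[idx] if present, otherwise a default (None here)."""
    return args[idx] if len(args) > idx else replacement


# Example .vec call with no explicit output size, only scale factors.
args = ("input_trt_tensor", None, False, [2.0, 2.0])
size = args_bounds_check_sketch(args, 1)          # None -> resolved from scales
scale_factor = args_bounds_check_sketch(args, 3)  # [2.0, 2.0]
align_corners = args[2]                           # False
print(size, scale_factor, align_corners)
```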

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 2 additions & 2 deletions
@@ -18,14 +18,14 @@ def upsample(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    size: Sequence[int],
+    size: Optional[Sequence[int]],
     scale_factor: Optional[Sequence[float]],
     mode: str,
     align_corners: bool,
 ) -> TRTTensor:
     layer = ctx.net.add_resize(input)
 
-    if scale_factor is not None and all(s is not None for s in scale_factor):
+    if scale_factor is not None:
         layer.scales = [1.0, 1.0] + list(scale_factor)
     else:
         shape = list(input.shape)[:2] + list(size)
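
The simplification above relies on the .vec calling convention: scale_factor is either a complete list of floats or None, never a partially filled list. A plain-Python sketch of how the resize layer is parameterized in each branch (the function below is illustrative, not part of the module's API):

```python
# Illustrative only: mirrors the branch in impl/upsample.py without needing a
# TensorRT network. Batch and channel dims are never resized, hence the 1.0s.
from typing import Optional, Sequence


def resize_params(
    input_shape: Sequence[int],
    size: Optional[Sequence[int]],
    scale_factor: Optional[Sequence[float]],
) -> dict:
    if scale_factor is not None:
        # Per-dimension scales: leave N and C untouched, scale the spatial dims.
        return {"scales": [1.0, 1.0] + list(scale_factor)}
    # Explicit output shape: keep the input's N and C, append the requested size.
    return {"shape": list(input_shape)[:2] + list(size)}


print(resize_params([1, 3, 32, 32], None, [2.0, 2.0]))  # {'scales': [1.0, 1.0, 2.0, 2.0]}
print(resize_params([1, 3, 32, 32], [64, 64], None))    # {'shape': [1, 3, 64, 64]}
```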

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 7 additions & 0 deletions
@@ -164,6 +164,13 @@
 }
 torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._softmax.default,
+    aten.upsample_nearest1d.vec,
+    aten.upsample_nearest2d.vec,
+    aten.upsample_nearest3d.vec,
+    aten.upsample_linear1d.vec,
+    aten.upsample_bilinear2d.vec,
+    aten.upsample_trilinear3d.vec,
+    aten.upsample_bicubic2d.vec,
 }
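
Listing the .vec overloads in torch_disabled_decompositions keeps them out of the decomposition table used during lowering, so they reach the converters registered above instead of being expanded into smaller ops. A rough illustration of that filtering idea, assuming PyTorch's core ATen decomposition table as the starting point (this is not the project's actual table-building code, and the disabled set here is truncated):

```python
# Illustrative only: filter disabled ops out of a decomposition table so the
# .vec upsample nodes survive in the graph handed to conversion.
from typing import Callable, Dict

import torch
from torch._decomp import core_aten_decompositions
from torch._ops import OpOverload

aten = torch.ops.aten

disabled = {
    aten.upsample_bilinear2d.vec,
    aten.upsample_nearest2d.vec,
}

decomp_table: Dict[OpOverload, Callable] = {
    op: fn
    for op, fn in core_aten_decompositions().items()
    if op not in disabled
}

# The disabled ops are guaranteed absent from the table used for lowering.
assert aten.upsample_bilinear2d.vec not in decomp_table
```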