
Commit 3a43fd2

add dynamic shape support for scaled_dot_product_attention, logical_or/xor (#2975)
1 parent abed8f0 commit 3a43fd2

File tree: 9 files changed (+542 / −73 lines)

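For context, this change lets the dynamo converters handle scaled_dot_product_attention (and the arange op it relies on) when the sequence length is dynamic. Below is a minimal usage sketch, assuming a Torch-TensorRT build with the dynamo frontend; the module name and the exact Input/compile arguments are illustrative, not part of this commit:

```python
# Hypothetical usage sketch (not from this commit): compile a causal attention
# module with a dynamic sequence-length dimension, which exercises the
# dynamic-shape path added for scaled_dot_product_attention in this change.
import torch
import torch_tensorrt


class CausalAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # is_causal=True triggers the tril/arange-based mask construction
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        )


model = CausalAttention().eval().cuda()
# Sequence length (dim 2) is dynamic between 16 and 256 tokens
dyn_input = torch_tensorrt.Input(
    min_shape=(1, 8, 16, 64),
    opt_shape=(1, 8, 128, 64),
    max_shape=(1, 8, 256, 64),
    dtype=torch.float32,
)
trt_model = torch_tensorrt.compile(
    model, ir="dynamo", inputs=[dyn_input, dyn_input, dyn_input]
)
```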

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 2 deletions
@@ -315,8 +315,8 @@ def aten_ops_embedding_bag(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor, supports_dynamic_shapes=True)
 def aten_ops_fmod(
     ctx: ConversionContext,
     target: Target,
@@ -3435,3 +3435,24 @@ def aten_ops_prelu(
         args[0],
         args[1],
     )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.arange.start_step, supports_dynamic_shapes=True
+)
+def aten_ops_arange_start_step(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.arange.arange(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        start=args[0],
+        end=args[1],
+        step=args_bounds_check(args, 2, 1),
+    )
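The registration above only dispatches to impl.arange.arange; the shape math lives in the new impl module further down. As a rough illustration (a hypothetical module, not part of this commit), a graph like the following would hit this converter with a tensor-valued end under dynamic shapes:

```python
# Hypothetical example (not from this commit): torch.arange whose end depends
# on a possibly-dynamic input dimension lowers to aten.arange.start_step, so
# the converter registered above receives a non-constant `end`.
import torch


class PositionIds(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden); seq_len may be a dynamic dimension
        return torch.arange(0, x.shape[1], 1, device=x.device)
```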

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from . import (
     activation,
     addmm,
+    arange,
     attention,
     cast,
     cat,

py/torch_tensorrt/dynamo/conversion/impl/arange.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def arange(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    start: Union[int, TRTTensor],
+    end: Union[int, TRTTensor],
+    step: Union[int, TRTTensor],
+) -> TRTTensor:
+
+    if any(isinstance(tensor, TRTTensor) for tensor in (start, end, step)):
+        start_rank_0 = get_trt_tensor(ctx, start, name + "_start_rank_0", min_rank=0)
+        start_rank_1 = get_trt_tensor(ctx, start, name + "_start_rank_1", min_rank=1)
+        end = get_trt_tensor(ctx, end, name + "_end", min_rank=1)
+        step = get_trt_tensor(ctx, step, name + "_step", min_rank=1)
+        # Calculate shape = (end-start) / step
+        shape = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_sub",
+            end,
+            start_rank_1,
+        )
+        shape = impl.elementwise.trunc_div(
+            ctx,
+            target,
+            source_ir,
+            name + "_shape",
+            shape,
+            step,
+        )
+        shape = cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted")
+        fill_layer = ctx.net.add_fill(
+            shape.shape, trt.FillOperation.LINSPACE, shape.dtype
+        )
+        fill_layer.set_input(0, shape)
+        # Set start index
+        fill_layer.set_input(1, start_rank_0)
+        # Set delta/step
+        fill_layer.set_input(2, step)
+        return fill_layer.get_output(0)
+    return np.arange(start, end, step)
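For reference, the dynamic branch above emits a 1-D LINSPACE fill whose length is trunc((end − start) / step) and whose i-th element is start + i·step, while the fully static branch simply falls back to np.arange. A small sketch of that arithmetic (illustration only, not Torch-TensorRT code):

```python
# Reference semantics of the LINSPACE fill built above: the length is the
# truncated quotient (end - start) / step, and element i is start + i * step.
def arange_reference(start: int, end: int, step: int) -> list:
    length = int((end - start) / step)  # shape = trunc_div(sub(end, start), step)
    return [start + i * step for i in range(length)]


assert arange_reference(2, 11, 3) == [2, 5, 8]
```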

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 96 additions & 8 deletions
@@ -7,10 +7,44 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.types import TRTTensor


+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # the lower triangle contains the elements whose row index is >= the column index
+    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
+    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
+    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
+    arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
+    )
+    # get the rows
+    row_tensor = impl.elementwise.trunc_div(
+        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
+    )
+    # get the cols
+    col_tensor = impl.elementwise.fmod(
+        ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
+    )
+    cond = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
+    )
+    return impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape", cond, [row, col]
+    )
+
+
 def scaled_dot_product_attention(
     ctx: ConversionContext,
     target: Union[Target, str],
@@ -22,8 +56,7 @@ def scaled_dot_product_attention(
     is_causal: bool,
     scale: Optional[float],
 ) -> TRTTensor:
-    L, S = query.shape[-2], key.shape[-2]
-
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     mm = impl.matmul.matrix_multiply(
         ctx,
         target,
@@ -34,13 +67,21 @@
         other_matrix_op=trt.MatrixOperation.TRANSPOSE,
     )
     if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
         scaled = impl.elementwise.div(
             ctx,
             target,
             source_ir,
             name + "_scale",
             mm,
-            math.sqrt(query.shape[-1]),
+            sqrt_scaled,
         )
     else:
         scaled = impl.elementwise.mul(
@@ -53,10 +94,57 @@
         )

     if is_causal:
-        attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
-        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
-        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
-        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # L and/or S is a dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, -2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)
+
+            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)
+
+            # generate an int32 tensor of shape (L, S)
+            arange_tensor = impl.arange.arange(
+                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
+            )
+            shape_tensor = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
+            )
+
+            # attn_bias should be float32, so cast it
+            shape_tensor = cast_trt_tensor(
+                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+            )
+
+            # initialize attn_bias as a zeros tensor
+            attn_bias = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
+            )
+
+            # generate the mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            inf_tensor = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
+            )
+            cond = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
+            )
+            # mask out the upper-triangular part of attn_bias with -inf
+            attn_bias = impl.condition.select(
+                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
+            )

         scaled = impl.elementwise.add(
             ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
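In short, the dynamic-shape causal path builds the (L, S) bias entirely in the network, without host-side numpy: a flat arange of length L*S is split into row and column indices (trunc_div and fmod by the column count inside tril), the lower triangle keeps row >= col, and every other position is selected to -inf. A NumPy sketch of the same construction (illustration only, mirroring the converter logic rather than calling it):

```python
# NumPy illustration of the dynamic causal-mask construction used above:
# a flat arange of length L*S is split into row/column indices, the lower
# triangle (row >= col) is kept, and everything else is filled with -inf.
import numpy as np


def causal_attn_bias(L: int, S: int) -> np.ndarray:
    idx = np.arange(L * S)                       # impl.arange.arange
    row, col = idx // S, idx % S                 # trunc_div / fmod by the column count
    tril = (row >= col).reshape(L, S)            # tril() helper
    bias = np.zeros((L, S), dtype=np.float32)    # attn_bias initialized to zeros
    return np.where(~tril, float("-inf"), bias)  # select(inf_tensor, attn_bias, cond)


print(causal_attn_bias(3, 4))
```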

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 2 additions & 54 deletions
@@ -1,22 +1,17 @@
+# mypy: disallow-untyped-decorators=False
+
 import logging
 import operator
 from typing import Dict, Sequence, Tuple, Union

 import numpy as np
-import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
-from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     ConverterRegistry,
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_trt_tensor,
-    get_trt_tensor,
-)
-from torch_tensorrt.dynamo.conversion.impl.elementwise import sub, trunc_div
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -50,53 +45,6 @@ def generic_evaluator(
     return target(*args)


-@dynamo_tensorrt_converter(
-    torch.ops.aten.arange.start_step, supports_dynamic_shapes=True
-)
-def aten_ops_arange_start_step(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # Case where inputs to arange are dynamic
-    if any(isinstance(tensor, TRTTensor) for tensor in args):
-        start_rank_0 = get_trt_tensor(ctx, args[0], name + "_start_rank_0", min_rank=0)
-        start_rank_1 = get_trt_tensor(ctx, args[0], name + "_start_rank_1", min_rank=1)
-        end = get_trt_tensor(ctx, args[1], name + "_end", min_rank=1)
-        step = args[2] if len(args) > 2 else 1
-        step = get_trt_tensor(ctx, step, name + "_step", min_rank=1)
-        # Calculate shape = (end-start) / step
-        shape = sub(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name + "_sub",
-            end,
-            start_rank_1,
-        )
-        shape = trunc_div(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name + "_shape",
-            shape,
-            step,
-        )
-        shape = cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted")
-        fill_layer = ctx.net.add_fill(
-            shape.shape, trt.FillOperation.LINSPACE, shape.dtype
-        )
-        fill_layer.set_input(0, shape)
-        # Set start index
-        fill_layer.set_input(1, start_rank_0)
-        # Set delta/step
-        fill_layer.set_input(2, step)
-        return fill_layer.get_output(0)
-    return np.arange(*args)
-
-
 def rand_validator(rand_node: Node) -> bool:
     dtype = rand_node.kwargs.get("dtype", None)
     layout = rand_node.kwargs.get("layout", None)
