 from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.types import TRTTensor


+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # the lower triangle holds the elements whose row index is >= their column index
+    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
+    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
+    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
+    arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
+    )
+    # recover the row index of each flattened element
+    row_tensor = impl.elementwise.trunc_div(
+        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
+    )
+    # recover the column index of each flattened element
+    col_tensor = impl.elementwise.fmod(
+        ctx, target, source_ir, name + "_fmod_row", arange_tensor, col
+    )
+    cond = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
+    )
+    return impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape", cond, [row, col]
+    )
+
+
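The `tril` helper builds a lower-triangular boolean mask entirely from shape tensors: it flattens the (row, col) grid into a single arange and recovers each element's row and column index with truncated division and modulo. A minimal NumPy sketch of the same index arithmetic (illustrative only, not part of the converter):

    import numpy as np

    rows, cols = 3, 4
    flat = np.arange(rows * cols)   # 0 .. rows*cols - 1
    row_idx = flat // cols          # trunc_div recovers the row index
    col_idx = flat % cols           # fmod recovers the column index
    mask = (row_idx >= col_idx).reshape(rows, cols)
    assert (mask == np.tril(np.ones((rows, cols), dtype=bool))).all()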
 def scaled_dot_product_attention(
     ctx: ConversionContext,
     target: Union[Target, str],
@@ -22,8 +56,7 @@ def scaled_dot_product_attention(
     is_causal: bool,
     scale: Optional[float],
 ) -> TRTTensor:
-    L, S = query.shape[-2], key.shape[-2]
-
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     mm = impl.matmul.matrix_multiply(
         ctx,
         target,
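For orientation, the semantics this converter targets (per the PyTorch documentation linked in the comment above) is softmax(Q @ K^T * scale + attn_bias) @ V, where scale defaults to the inverse square root of the head dimension. A minimal PyTorch sketch of that reference behavior (illustrative; the converter expresses the same steps as TRT layers, dividing by sqrt(E) rather than multiplying by its inverse):

    import math
    import torch

    def sdpa_reference(query, key, value, is_causal=False, scale=None):
        if scale is None:
            scale = 1.0 / math.sqrt(query.shape[-1])  # default: 1 / sqrt(head dim)
        attn = query @ key.transpose(-2, -1) * scale
        if is_causal:
            L, S = query.shape[-2], key.shape[-2]
            causal = torch.tril(torch.ones(L, S, dtype=torch.bool))
            attn = attn.masked_fill(~causal, float("-inf"))
        return torch.softmax(attn, dim=-1) @ value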
@@ -34,13 +67,21 @@ def scaled_dot_product_attention(
         other_matrix_op=trt.MatrixOperation.TRANSPOSE,
     )
     if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape: take the sqrt of the head dim inside the TRT graph
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape: the sqrt can be computed at build time
+            sqrt_scaled = math.sqrt(scale)
         scaled = impl.elementwise.div(
             ctx,
             target,
             source_ir,
             name + "_scale",
             mm,
-            math.sqrt(query.shape[-1]),
+            sqrt_scaled,
         )
     else:
         scaled = impl.elementwise.mul(
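The `scale < 0` test works because TensorRT reports dynamic dimensions as -1 in the build-time shape, so a negative head dim means the value is only known at runtime. A small sketch of that dispatch (hypothetical helper, illustrative only):

    import math

    def build_time_sqrt(head_dim: int):
        # dynamic dims appear as -1 at build time; in that case the shape
        # and sqrt must be emitted as graph layers instead of folded here
        return math.sqrt(head_dim) if head_dim >= 0 else None

    assert build_time_sqrt(64) == 8.0
    assert build_time_sqrt(-1) is None  # defer to impl.shape.shape + impl.unary.sqrt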
@@ -53,10 +94,57 @@ def scaled_dot_product_attention(
     )

     if is_causal:
-        attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
-        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
-        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
-        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape: build the bias as a constant with NumPy
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # at least one of L and S is dynamic, so build the bias in-graph
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, -2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)
+
+            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)
+
+            # generate an int32 tensor that will be reshaped to (L, S)
+            arange_tensor = impl.arange.arange(
+                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
+            )
+            shape_tensor = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
+            )
+
+            # the attn_bias must be float32, so cast the tensor
+            shape_tensor = cast_trt_tensor(
+                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+            )
+
+            # initialize attn_bias as a zeros tensor of shape (L, S)
+            attn_bias = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
+            )
+
+            # generate the causal mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            inf_tensor = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
+            )
+            cond = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_cond_true", temp_mask, True
+            )
+            # fill the upper triangle of attn_bias with -inf
+            attn_bias = impl.condition.select(
+                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
+            )
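A minimal NumPy sketch of what the dynamic branch computes (illustrative; `np.full` stands in here for the converter's `shape_tensor * float("-inf")` trick, which only differs off the lower triangle, where the result is never selected):

    import numpy as np

    def causal_bias(L, S):
        # reshape a flat arange to (L, S), zero it for the bias, and select
        # -inf wherever the lower-triangular mask is False
        shape_tensor = np.arange(L * S, dtype=np.float32).reshape(L, S)
        attn_bias = shape_tensor * 0.0                     # zeros, shape (L, S)
        temp_mask = ~np.tril(np.ones((L, S), dtype=bool))  # True above the diagonal
        inf_tensor = np.full((L, S), -np.inf, dtype=np.float32)
        return np.where(temp_mask, inf_tensor, attn_bias)

    print(causal_bias(3, 3))
    # [[  0. -inf -inf]
    #  [  0.   0. -inf]
    #  [  0.   0.   0.]]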

     scaled = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias