Hacked together by / Copyright 2021 Ross Wightman
"""
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple

import torch
import torch.nn as nn

+from .config import use_fused_attn
from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_

@@ -27,53 +29,122 @@ class RotAttentionPool2d(nn.Module):
    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
+    fused_attn: torch.jit.Final[bool]
+
    def __init__(
            self,
            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
            qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
    ):
        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
        self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
        self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
+        self.fused_attn = use_fused_attn()

-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None

-    def forward(self, x):
-        B, _, H, W = x.shape
-        N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)

-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)

-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim

-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x

-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
+    def forward(self, x, pre_logits: bool = False):
+        B, _, H, W = x.shape
+        N = H * W
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)

-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)

-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
        x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x


class AttentionPool2d(nn.Module):
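A quick usage sketch of the updated RotAttentionPool2d (not part of the diff; the timm.layers import path and the sizes below are assumptions). It pools an NCHW feature map into either a projected pooled token or, with pre_logits=True, the un-projected pooled features:

import torch
from timm.layers import RotAttentionPool2d  # assumed export location

pool = RotAttentionPool2d(
    in_features=2048,      # channels of the incoming conv feature map
    out_features=512,      # projection width
    ref_feat_size=7,       # reference grid for the rotary embedding
    head_dim=64,           # 2048 / 64 -> 32 heads
    pool_type='token',
)
feats = torch.randn(2, 2048, 7, 7)      # NCHW features
emb = pool(feats)                       # (2, 512) pooled + projected
pre = pool(feats, pre_logits=True)      # (2, 2048) pooled, before proj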
@@ -85,47 +156,123 @@ class AttentionPool2d(nn.Module):

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """
+    fused_attn: torch.jit.Final[bool]
+
    def __init__(
            self,
            in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
            qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
    ):
        super().__init__()
-
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
        self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type
        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)

-    def forward(self, x):
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
        B, _, H, W = x.shape
        N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        x = x + pos_embed
+
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
        x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x
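A sketch of the head-style behaviour added to AttentionPool2d (again assuming the timm.layers export; the sizes are illustrative). The absolute position embedding is now resampled to the incoming grid, so inputs no longer have to match feat_size, and reset() re-targets the projection like other timm classifier heads:

import torch
from timm.layers import AttentionPool2d  # assumed export location

head = AttentionPool2d(in_features=768, feat_size=7, head_dim=64, pool_type='token')
x = torch.randn(2, 768, 12, 12)          # grid differs from feat_size
print(head(x).shape)                     # torch.Size([2, 768]) - pooled token after proj

head.reset(num_classes=1000)             # proj becomes a 768 -> 1000 classifier
print(head(x).shape)                     # torch.Size([2, 1000])

head.reset(num_classes=0, pool_type='')  # identity proj, keep the spatial map
print(head(x).shape)                     # torch.Size([2, 768, 12, 12])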