Commit 7b0a532

Merge pull request #2198 from huggingface/openai_clip_resnet
Mapping OpenAI CLIP Modified ResNet weights -> ByobNet.
2 parents 5aa49d5 + 57adc1a

12 files changed: +656 -200 lines
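
This merge maps the OpenAI CLIP Modified ResNet image towers onto timm's ByobNet so they can be created like any other timm model. A hedged usage sketch; the exact model/tag name added by the PR is not shown in this excerpt and is assumed here:

import timm
import torch

# 'resnet50_clip.openai' is an assumed name for the converted CLIP ResNet-50 tower
model = timm.create_model('resnet50_clip.openai', pretrained=False, num_classes=0)
features = model(torch.randn(1, 3, 224, 224))  # pooled features from the attention-pool head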

timm/layers/attention_pool.py

Lines changed: 4 additions & 2 deletions
@@ -20,6 +20,7 @@ def __init__(
             out_features: int = None,
             embed_dim: int = None,
             num_heads: int = 8,
+            feat_size: Optional[int] = None,
             mlp_ratio: float = 4.0,
             qkv_bias: bool = True,
             qk_norm: bool = False,
@@ -36,13 +37,14 @@ def __init__(
         assert embed_dim % num_heads == 0
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
+        self.feat_size = feat_size
         self.scale = self.head_dim ** -0.5
         self.pool = pool_type
         self.fused_attn = use_fused_attn()
 
         if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+            assert feat_size is not None
+            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
         else:
             self.pos_embed = None

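The attention pooling module in this file (AttentionPoolLatent in timm; the class name is not visible in the hunk) now takes feat_size explicitly, and an 'abs' position embedding requires it at construction time. A hedged sketch of the new usage, with illustrative argument values:

from timm.layers import AttentionPoolLatent

# pos_embed='abs' now needs feat_size (number of input tokens) up front,
# since the embedding parameter is sized from it directly.
pool = AttentionPoolLatent(in_features=768, num_heads=12, feat_size=196, pos_embed='abs')
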
timm/layers/attention_pool2d.py

Lines changed: 207 additions & 60 deletions
@@ -7,12 +7,14 @@
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
+from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
 
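The new use_fused_attn import is used in the forward passes below to gate between PyTorch's fused scaled_dot_product_attention kernel and an explicit attention computation. A minimal standalone sketch of that pattern (not a timm API, names are illustrative):

import torch
import torch.nn as nn

def attention(q, k, v, scale: float, fused: bool) -> torch.Tensor:
    # q, k, v: (B, num_heads, N, head_dim)
    if fused:
        return nn.functional.scaled_dot_product_attention(q, k, v)
    attn = (q * scale) @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    return attn @ v
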
@@ -27,53 +29,122 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from
     train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
+        self.fused_attn = use_fused_attn()
 
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
 
-    def forward(self, x):
-        B, _, H, W = x.shape
-        N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
 
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
 
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
 
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
+    def forward(self, x, pre_logits: bool = False):
+        B, _, H, W = x.shape
+        N = H * W
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x
 
 
 class AttentionPool2d(nn.Module):
@@ -85,47 +156,123 @@ class AttentionPool2d(nn.Module):
 
     NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
+            pool_type: str = 'token',
+            class_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
-
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        assert pool_type in ('', 'token')
+        self.embed_dim = embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
+        self.pool_type = pool_type
         self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
 
-    def forward(self, x):
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        x = x.flatten(2).transpose(1, 2)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        x = x + pos_embed
+
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
+
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = self.drop(x)
+        if pre_logits:
+            x = self._pool(x, H, W)
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        x = self._pool(x, H, W)
+        return x
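
The reworked RotAttentionPool2d above now behaves like a classification head: it can return pooled pre-projection features, and reset() swaps the projection for a classifier. A hedged usage sketch (shapes and argument values are illustrative, not taken from the PR):

import torch
from timm.layers.attention_pool2d import RotAttentionPool2d

pool = RotAttentionPool2d(in_features=2048, head_dim=64, pool_type='token')
feats = torch.randn(2, 2048, 7, 7)       # NCHW feature map from a CNN trunk
pooled = pool(feats, pre_logits=True)    # (2, 2048) pooled features, before proj
out = pool(feats)                        # (2, 2048) projected, then 'token' pooled
pool.reset(num_classes=1000)             # head-style reset: proj becomes a classifier
logits = pool(feats)                     # (2, 1000)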

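AttentionPool2d keeps its learned absolute position embedding, but forward() now resamples it to the incoming grid instead of asserting that the input matches feat_size. A minimal sketch under assumed shapes:

import torch
from timm.layers.attention_pool2d import AttentionPool2d

pool = AttentionPool2d(in_features=2048, feat_size=7, head_dim=64)
out_native = pool(torch.randn(2, 2048, 7, 7))    # native 7x7 grid -> (2, 2048)
out_larger = pool(torch.randn(2, 2048, 12, 12))  # pos embed resampled to 12x12 -> (2, 2048)
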
timm/layers/classifier.py

Lines changed: 0 additions & 2 deletions
@@ -24,8 +24,6 @@ def _create_pool(
 ):
     flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
     if not pool_type:
-        assert num_classes == 0 or use_conv,\
-            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
         flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
     global_pool = SelectAdaptivePool2d(
         pool_type=pool_type,
