
Commit 9cc7dda

Fixup byoanet configs to pass unit tests. Add swin_attn and swinnet26t model for testing.
1 parent e15c388 commit 9cc7dda

File tree: 3 files changed (+216 −7 lines)

timm/models/byoanet.py

Lines changed: 33 additions & 7 deletions
@@ -35,7 +35,7 @@
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
         'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@@ -45,17 +45,19 @@ def _cfg(url='', **kwargs):

 default_cfgs = {
     # GPU-Efficient (ResNet) weights
-    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256)),
+    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
     'botnet50t_224': _cfg(url='', fixed_input_size=True),
     'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),

     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(url='', input_size=(3, 256, 256)),
-    'halonet50t': _cfg(url=''),
+    'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'halonet50t': _cfg(url='', min_input_size=(3, 224, 224)),

-    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256)),
+    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
     'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
+
+    'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }


@@ -95,10 +97,10 @@ def interleave_attn(

     botnet26t=ByoaCfg(
         blocks=(
-            ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
             ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
             interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
-            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
@@ -230,6 +232,22 @@ def interleave_attn(
         self_attn_layer='lambda',
         self_attn_kwargs=dict()
     ),
+
+    swinnet26t=ByoaCfg(
+        blocks=(
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        num_features=0,
+        self_attn_layer='swin',
+        self_attn_fixed_size=True,
+        self_attn_kwargs=dict(win_size=8)
+    ),
 )


@@ -452,3 +470,11 @@ def lambda_resnet50t(pretrained=False, **kwargs):
     """ Lambda-ResNet-50T. Lambda layers in one C4 stage and all C5.
     """
     return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def swinnet26t_256(pretrained=False, **kwargs):
+    """
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('swinnet26t_256', 'swinnet26t', pretrained=pretrained, **kwargs)
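
For context, a minimal sketch of how the new registration could be exercised against a checkout that includes this commit (only names visible in the diff are used; the input shape is illustrative, chosen to match the fixed 256x256 config, and no pretrained weights exist yet since url=''):

import torch
import timm

# 'swinnet26t_256' is registered above with fixed_input_size=True and input_size=(3, 256, 256);
# pretrained=False because the config has an empty weight URL.
model = timm.create_model('swinnet26t_256', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected torch.Size([1, 1000]) for the default 1000-class head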

timm/models/layers/create_self_attn.py

Lines changed: 5 additions & 0 deletions
@@ -1,6 +1,7 @@
 from .bottleneck_attn import BottleneckAttn
 from .halo_attn import HaloAttn
 from .lambda_layer import LambdaLayer
+from .swin_attn import WindowAttention


 def get_self_attn(attn_type):
@@ -10,6 +11,10 @@ def get_self_attn(attn_type):
         return HaloAttn
     elif attn_type == 'lambda':
         return LambdaLayer
+    elif attn_type == 'swin':
+        return WindowAttention
+    else:
+        assert False, f"Unknown attn type ({attn_type})"


 def create_self_attn(attn_type, dim, stride=1, **kwargs):
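
As a quick illustration of the routing change (a sketch, assuming the package layout shown above; only the names visible in this diff are used):

from timm.models.layers.create_self_attn import get_self_attn

# 'swin' now maps to the new WindowAttention module from swin_attn.py.
attn_cls = get_self_attn('swin')
print(attn_cls.__name__)  # WindowAttention

# Unrecognized types previously fell through and returned None; they now assert.
try:
    get_self_attn('not_a_real_attn')
except AssertionError as err:
    print(err)  # Unknown attn type (not_a_real_attn)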

timm/models/layers/swin_attn.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+""" Shifted Window Attn
+
+This is a WIP experiment to apply windowed attention from the Swin Transformer
+to a stand-alone module for use as an attn block in conv nets.
+
+Based on original swin window code at https://github.com/microsoft/Swin-Transformer
+Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .drop import DropPath
+from .helpers import to_2tuple
+from .weight_init import trunc_normal_
+
+
+def window_partition(x, win_size: int):
+    """
+    Args:
+        x: (B, H, W, C)
+        win_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
+    return windows
+
+
+def window_reverse(windows, win_size: int, H: int, W: int):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        win_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / win_size / win_size))
+    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+
+    Args:
+        dim (int): Number of input channels.
+        win_size (int): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+    """
+
+    def __init__(
+            self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
+            qkv_bias=True, attn_drop=0.):
+
+        super().__init__()
+        self.dim_out = dim_out or dim
+        self.feat_size = to_2tuple(feat_size)
+        self.win_size = win_size
+        self.shift_size = shift_size or win_size // 2
+        if min(self.feat_size) <= win_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.win_size = min(self.feat_size)
+        assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
+        self.num_heads = num_heads
+        head_dim = self.dim_out // num_heads
+        self.scale = head_dim ** -0.5
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.feat_size
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            h_slices = (
+                slice(0, -self.win_size),
+                slice(-self.win_size, -self.shift_size),
+                slice(-self.shift_size, None))
+            w_slices = (
+                slice(0, -self.win_size),
+                slice(-self.win_size, -self.shift_size),
+                slice(-self.shift_size, None))
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+            mask_windows = window_partition(img_mask, self.win_size)  # num_win, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+        self.register_buffer("attn_mask", attn_mask)
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            # 2 * Wh - 1 * 2 * Ww - 1, nH
+            torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.win_size)
+        coords_w = torch.arange(self.win_size)
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.win_size - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.win_size - 1
+        relative_coords[:, :, 0] *= 2 * self.win_size - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
+        self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.softmax = nn.Softmax(dim=-1)
+        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = x.permute(0, 2, 3, 1)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_x = x
+
+        # partition windows
+        win_size_sq = self.win_size * self.win_size
+        x_windows = window_partition(shifted_x, self.win_size)  # num_win * B, window_size, window_size, C
+        x_windows = x_windows.view(-1, win_size_sq, C)  # num_win * B, window_size*window_size, C
+        BW, N, _ = x_windows.shape
+
+        qkv = self.qkv(x_windows)
+        qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh * Ww, Wh * Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+        if self.attn_mask is not None:
+            num_win = self.attn_mask.shape[0]
+            attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(BW, N, self.dim_out)
+
+        # merge windows
+        x = x.view(-1, self.win_size, self.win_size, self.dim_out)
+        shifted_x = window_reverse(x, self.win_size, H, W)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+        x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
+        x = self.pool(x)
+        return x
+
+
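
A standalone smoke test of the new layer might look like the following (a sketch, not part of the commit; the 16x16 feature size and dim=256 are illustrative values consistent with a C4 stage of the 256-input configs above):

import torch
from timm.models.layers.swin_attn import WindowAttention, window_partition, window_reverse

# window_partition / window_reverse are exact inverses when H and W divide by win_size.
z = torch.randn(2, 16, 16, 256)  # NHWC, as the helpers expect
assert torch.equal(window_reverse(window_partition(z, 8), 8, 16, 16), z)

# WindowAttention takes NCHW input and needs feat_size because the shift mask
# and window geometry are fixed at construction time.
attn = WindowAttention(dim=256, feat_size=(16, 16), win_size=8, num_heads=8)
x = torch.randn(2, 256, 16, 16)
with torch.no_grad():
    y = attn(x)
print(y.shape)  # torch.Size([2, 256, 16, 16]); stride=1 keeps the spatial size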
