Skip to content

Commit fb3a0f4

Browse files
authored
Merge pull request #65 from mehtadushy/selecsls
Incorporate SelecSLS Models
2 parents 0554b79 + 2404361 commit fb3a0f4

File tree

2 files changed

+369
-0
lines changed

2 files changed

+369
-0
lines changed

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .xception import *
88
from .nasnet import *
99
from .pnasnet import *
10+
from .selecsls import *
1011
from .efficientnet import *
1112
from .mobilenetv3 import *
1213
from .inception_v3 import *

timm/models/selecsls.py

Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
"""PyTorch SelecSLS Net example for ImageNet Classification
2+
License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode)
3+
Author: Dushyant Mehta (@mehtadushy)
4+
5+
SelecSLS (core) Network Architecture as proposed in "XNect: Real-time Multi-person 3D
6+
Human Pose Estimation with a Single RGB Camera, Mehta et al."
7+
https://arxiv.org/abs/1907.00837
8+
9+
Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
10+
and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
11+
"""
12+
import math
13+
14+
import torch
15+
import torch.nn as nn
16+
import torch.nn.functional as F
17+
18+
from .registry import register_model
19+
from .helpers import load_pretrained
20+
from .adaptive_avgmax_pool import SelectAdaptivePool2d
21+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
22+
23+
24+
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
25+
26+
27+
def _cfg(url='', **kwargs):
28+
return {
29+
'url': url,
30+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
31+
'crop_pct': 0.875, 'interpolation': 'bilinear',
32+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
33+
'first_conv': 'stem', 'classifier': 'fc',
34+
**kwargs
35+
}
36+
37+
38+
default_cfgs = {
39+
'selecsls42': _cfg(
40+
url='',
41+
interpolation='bicubic'),
42+
'selecsls42_B': _cfg(
43+
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B.pth',
44+
interpolation='bicubic'),
45+
'selecsls60': _cfg(
46+
url='',
47+
interpolation='bicubic'),
48+
'selecsls60_B': _cfg(
49+
url='http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B.pth',
50+
interpolation='bicubic'),
51+
'selecsls84': _cfg(
52+
url='',
53+
interpolation='bicubic'),
54+
}
55+
56+
57+
def conv_bn(inp, oup, stride):
58+
return nn.Sequential(
59+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
60+
nn.BatchNorm2d(oup),
61+
nn.ReLU(inplace=True)
62+
)
63+
64+
65+
def conv_1x1_bn(inp, oup):
66+
return nn.Sequential(
67+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
68+
nn.BatchNorm2d(oup),
69+
nn.ReLU(inplace=True)
70+
)
71+
72+
class SelecSLSBlock(nn.Module):
73+
def __init__(self, inp, skip, k, oup, isFirst, stride):
74+
super(SelecSLSBlock, self).__init__()
75+
self.stride = stride
76+
self.isFirst = isFirst
77+
assert stride in [1, 2]
78+
79+
#Process input with 4 conv blocks with the same number of input and output channels
80+
self.conv1 = nn.Sequential(
81+
nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1),
82+
nn.BatchNorm2d(k),
83+
nn.ReLU(inplace=True)
84+
)
85+
self.conv2 = nn.Sequential(
86+
nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
87+
nn.BatchNorm2d(k),
88+
nn.ReLU(inplace=True)
89+
)
90+
self.conv3 = nn.Sequential(
91+
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
92+
nn.BatchNorm2d(k//2),
93+
nn.ReLU(inplace=True)
94+
)
95+
self.conv4 = nn.Sequential(
96+
nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
97+
nn.BatchNorm2d(k),
98+
nn.ReLU(inplace=True)
99+
)
100+
self.conv5 = nn.Sequential(
101+
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
102+
nn.BatchNorm2d(k//2),
103+
nn.ReLU(inplace=True)
104+
)
105+
self.conv6 = nn.Sequential(
106+
nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1),
107+
nn.BatchNorm2d(oup),
108+
nn.ReLU(inplace=True)
109+
)
110+
111+
def forward(self, x):
112+
assert isinstance(x,list)
113+
assert len(x) in [1,2]
114+
115+
d1 = self.conv1(x[0])
116+
d2 = self.conv3(self.conv2(d1))
117+
d3 = self.conv5(self.conv4(d2))
118+
if self.isFirst:
119+
out = self.conv6(torch.cat([d1, d2, d3], 1))
120+
return [out, out]
121+
else:
122+
return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]]
123+
124+
class SelecSLS(nn.Module):
125+
"""SelecSLS42 / SelecSLS60 / SelecSLS84
126+
127+
Parameters
128+
----------
129+
cfg : network config
130+
String indicating the network config
131+
num_classes : int, default 1000
132+
Number of classification classes.
133+
in_chans : int, default 3
134+
Number of input (color) channels.
135+
drop_rate : float, default 0.
136+
Dropout probability before classifier, for training
137+
global_pool : str, default 'avg'
138+
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
139+
"""
140+
def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
141+
drop_rate=0.0, global_pool='avg'):
142+
self.num_classes = num_classes
143+
self.drop_rate = drop_rate
144+
super(SelecSLS, self).__init__()
145+
146+
self.stem = conv_bn(in_chans, 32, 2)
147+
#Core Network
148+
self.features = []
149+
if cfg=='selecsls42':
150+
self.block = SelecSLSBlock
151+
#Define configuration of the network after the initial neck
152+
self.selecSLS_config = [
153+
#inp,skip, k, oup, isFirst, stride
154+
[ 32, 0, 64, 64, True, 2],
155+
[ 64, 64, 64, 128, False, 1],
156+
[128, 0, 144, 144, True, 2],
157+
[144, 144, 144, 288, False, 1],
158+
[288, 0, 304, 304, True, 2],
159+
[304, 304, 304, 480, False, 1],
160+
]
161+
#Head can be replaced with alternative configurations depending on the problem
162+
self.head = nn.Sequential(
163+
conv_bn(480, 960, 2),
164+
conv_bn(960, 1024, 1),
165+
conv_bn(1024, 1024, 2),
166+
conv_1x1_bn(1024, 1280),
167+
)
168+
self.num_features = 1280
169+
elif cfg=='selecsls42_B':
170+
self.block = SelecSLSBlock
171+
#Define configuration of the network after the initial neck
172+
self.selecSLS_config = [
173+
#inp,skip, k, oup, isFirst, stride
174+
[ 32, 0, 64, 64, True, 2],
175+
[ 64, 64, 64, 128, False, 1],
176+
[128, 0, 144, 144, True, 2],
177+
[144, 144, 144, 288, False, 1],
178+
[288, 0, 304, 304, True, 2],
179+
[304, 304, 304, 480, False, 1],
180+
]
181+
#Head can be replaced with alternative configurations depending on the problem
182+
self.head = nn.Sequential(
183+
conv_bn(480, 960, 2),
184+
conv_bn(960, 1024, 1),
185+
conv_bn(1024, 1280, 2),
186+
conv_1x1_bn(1280, 1024),
187+
)
188+
self.num_features = 1024
189+
elif cfg=='selecsls60':
190+
self.block = SelecSLSBlock
191+
#Define configuration of the network after the initial neck
192+
self.selecSLS_config = [
193+
#inp,skip, k, oup, isFirst, stride
194+
[ 32, 0, 64, 64, True, 2],
195+
[ 64, 64, 64, 128, False, 1],
196+
[128, 0, 128, 128, True, 2],
197+
[128, 128, 128, 128, False, 1],
198+
[128, 128, 128, 288, False, 1],
199+
[288, 0, 288, 288, True, 2],
200+
[288, 288, 288, 288, False, 1],
201+
[288, 288, 288, 288, False, 1],
202+
[288, 288, 288, 416, False, 1],
203+
]
204+
#Head can be replaced with alternative configurations depending on the problem
205+
self.head = nn.Sequential(
206+
conv_bn(416, 756, 2),
207+
conv_bn(756, 1024, 1),
208+
conv_bn(1024, 1024, 2),
209+
conv_1x1_bn(1024, 1280),
210+
)
211+
self.num_features = 1280
212+
elif cfg=='selecsls60_B':
213+
self.block = SelecSLSBlock
214+
#Define configuration of the network after the initial neck
215+
self.selecSLS_config = [
216+
#inp,skip, k, oup, isFirst, stride
217+
[ 32, 0, 64, 64, True, 2],
218+
[ 64, 64, 64, 128, False, 1],
219+
[128, 0, 128, 128, True, 2],
220+
[128, 128, 128, 128, False, 1],
221+
[128, 128, 128, 288, False, 1],
222+
[288, 0, 288, 288, True, 2],
223+
[288, 288, 288, 288, False, 1],
224+
[288, 288, 288, 288, False, 1],
225+
[288, 288, 288, 416, False, 1],
226+
]
227+
#Head can be replaced with alternative configurations depending on the problem
228+
self.head = nn.Sequential(
229+
conv_bn(416, 756, 2),
230+
conv_bn(756, 1024, 1),
231+
conv_bn(1024, 1280, 2),
232+
conv_1x1_bn(1280, 1024),
233+
)
234+
self.num_features = 1024
235+
elif cfg=='selecsls84':
236+
self.block = SelecSLSBlock
237+
#Define configuration of the network after the initial neck
238+
self.selecSLS_config = [
239+
#inp,skip, k, oup, isFirst, stride
240+
[ 32, 0, 64, 64, True, 2],
241+
[ 64, 64, 64, 144, False, 1],
242+
[144, 0, 144, 144, True, 2],
243+
[144, 144, 144, 144, False, 1],
244+
[144, 144, 144, 144, False, 1],
245+
[144, 144, 144, 144, False, 1],
246+
[144, 144, 144, 304, False, 1],
247+
[304, 0, 304, 304, True, 2],
248+
[304, 304, 304, 304, False, 1],
249+
[304, 304, 304, 304, False, 1],
250+
[304, 304, 304, 304, False, 1],
251+
[304, 304, 304, 304, False, 1],
252+
[304, 304, 304, 512, False, 1],
253+
]
254+
#Head can be replaced with alternative configurations depending on the problem
255+
self.head = nn.Sequential(
256+
conv_bn(512, 960, 2),
257+
conv_bn(960, 1024, 1),
258+
conv_bn(1024, 1024, 2),
259+
conv_1x1_bn(1024, 1280),
260+
)
261+
self.num_features = 1280
262+
else:
263+
raise ValueError('Invalid net configuration '+cfg+' !!!')
264+
265+
for inp, skip, k, oup, isFirst, stride in self.selecSLS_config:
266+
self.features.append(self.block(inp, skip, k, oup, isFirst, stride))
267+
self.features = nn.Sequential(*self.features)
268+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
269+
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
270+
271+
for n, m in self.named_modules():
272+
if isinstance(m, nn.Conv2d):
273+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
274+
elif isinstance(m, nn.BatchNorm2d):
275+
nn.init.constant_(m.weight, 1.)
276+
nn.init.constant_(m.bias, 0.)
277+
278+
def get_classifier(self):
279+
return self.fc
280+
281+
def reset_classifier(self, num_classes, global_pool='avg'):
282+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
283+
self.num_classes = num_classes
284+
del self.fc
285+
if num_classes:
286+
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
287+
else:
288+
self.fc = None
289+
290+
def forward_features(self, x, pool=True):
291+
x = self.stem(x)
292+
x = self.features([x])
293+
x = self.head(x[0])
294+
295+
if pool:
296+
x = self.global_pool(x)
297+
x = x.view(x.size(0), -1)
298+
return x
299+
300+
def forward(self, x):
301+
x = self.forward_features(x)
302+
if self.drop_rate > 0.:
303+
x = F.dropout(x, p=self.drop_rate, training=self.training)
304+
x = self.fc(x)
305+
return x
306+
307+
308+
@register_model
309+
def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
310+
"""Constructs a SelecSLS42 model.
311+
"""
312+
default_cfg = default_cfgs['selecsls42']
313+
model = SelecSLS(
314+
cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs)
315+
model.default_cfg = default_cfg
316+
if pretrained:
317+
load_pretrained(model, default_cfg, num_classes, in_chans)
318+
return model
319+
320+
@register_model
321+
def selecsls42_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
322+
"""Constructs a SelecSLS42_B model.
323+
"""
324+
default_cfg = default_cfgs['selecsls42_B']
325+
model = SelecSLS(
326+
cfg='selecsls42_B', num_classes=1000, in_chans=3,**kwargs)
327+
model.default_cfg = default_cfg
328+
if pretrained:
329+
load_pretrained(model, default_cfg, num_classes, in_chans)
330+
return model
331+
332+
@register_model
333+
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
334+
"""Constructs a SelecSLS60 model.
335+
"""
336+
default_cfg = default_cfgs['selecsls60']
337+
model = SelecSLS(
338+
cfg='selecsls60', num_classes=1000, in_chans=3,**kwargs)
339+
model.default_cfg = default_cfg
340+
if pretrained:
341+
load_pretrained(model, default_cfg, num_classes, in_chans)
342+
return model
343+
344+
345+
@register_model
346+
def selecsls60_B(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
347+
"""Constructs a SelecSLS60_B model.
348+
"""
349+
default_cfg = default_cfgs['selecsls60_B']
350+
model = SelecSLS(
351+
cfg='selecsls60_B', num_classes=1000, in_chans=3,**kwargs)
352+
model.default_cfg = default_cfg
353+
if pretrained:
354+
load_pretrained(model, default_cfg, num_classes, in_chans)
355+
return model
356+
357+
@register_model
358+
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
359+
"""Constructs a SelecSLS84 model.
360+
"""
361+
default_cfg = default_cfgs['selecsls84']
362+
model = SelecSLS(
363+
cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs)
364+
model.default_cfg = default_cfg
365+
if pretrained:
366+
load_pretrained(model, default_cfg, num_classes, in_chans)
367+
return model
368+

0 commit comments

Comments
 (0)