Skip to content

Commit 3ceeedc

Browse files
authored
Merge pull request #53 from rwightman/condconvs_and_features
Major model merge (EfficientNet-CondConv, EfficientNet-AdvProp, TF MobileNetV3, HRNet, more)
2 parents db04677 + 7b3c235 commit 3ceeedc

19 files changed

+3980
-2512
lines changed

clean_checkpoint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import os
44
import hashlib
5+
import shutil
56
from collections import OrderedDict
67

78
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
@@ -31,10 +32,9 @@ def main():
3132
if state_dict_key in checkpoint:
3233
state_dict = checkpoint[state_dict_key]
3334
else:
34-
print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
35-
exit(1)
35+
state_dict = checkpoint
3636
else:
37-
state_dict = checkpoint
37+
assert False
3838
for k, v in state_dict.items():
3939
name = k[7:] if k.startswith('module') else k
4040
new_state_dict[name] = v
@@ -43,7 +43,11 @@ def main():
4343
torch.save(new_state_dict, args.output)
4444
with open(args.output, 'rb') as f:
4545
sha_hash = hashlib.sha256(f.read()).hexdigest()
46-
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
46+
47+
checkpoint_base = os.path.splitext(args.checkpoint)[0]
48+
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
49+
shutil.move(args.output, final_filename)
50+
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
4751
else:
4852
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
4953

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
torch>=1.1.0
2-
torchvision>=0.3.0
1+
torch>=1.2.0
2+
torchvision>=0.4.0
33
pyyaml

sotabench.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
7878
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
7979
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
8080
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
81-
_entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244',
81+
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
8282
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
8383
'paper as closely as possible.'),
8484
_entry('resnet18', 'ResNet-18', '1812.01187'),
@@ -114,6 +114,30 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
114114
model_desc='Ported from official Google AI Tensorflow weights'),
115115
_entry('tf_efficientnet_b7', 'EfficientNet-B7 (RandAugment)', '1905.11946', batch_size=BATCH_SIZE//8,
116116
model_desc='Ported from official Google AI Tensorflow weights'),
117+
_entry('tf_efficientnet_b0_ap', 'EfficientNet-B0 (AdvProp)', '1911.09665',
118+
model_desc='Ported from official Google AI Tensorflow weights'),
119+
_entry('tf_efficientnet_b1_ap', 'EfficientNet-B1 (AdvProp)', '1911.09665',
120+
model_desc='Ported from official Google AI Tensorflow weights'),
121+
_entry('tf_efficientnet_b2_ap', 'EfficientNet-B2 (AdvProp)', '1911.09665',
122+
model_desc='Ported from official Google AI Tensorflow weights'),
123+
_entry('tf_efficientnet_b3_ap', 'EfficientNet-B3 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2,
124+
model_desc='Ported from official Google AI Tensorflow weights'),
125+
_entry('tf_efficientnet_b4_ap', 'EfficientNet-B4 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2,
126+
model_desc='Ported from official Google AI Tensorflow weights'),
127+
_entry('tf_efficientnet_b5_ap', 'EfficientNet-B5 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 4,
128+
model_desc='Ported from official Google AI Tensorflow weights'),
129+
_entry('tf_efficientnet_b6_ap', 'EfficientNet-B6 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
130+
model_desc='Ported from official Google AI Tensorflow weights'),
131+
_entry('tf_efficientnet_b7_ap', 'EfficientNet-B7 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
132+
model_desc='Ported from official Google AI Tensorflow weights'),
133+
_entry('tf_efficientnet_b8_ap', 'EfficientNet-B8 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
134+
model_desc='Ported from official Google AI Tensorflow weights'),
135+
_entry('tf_efficientnet_cc_b0_4e', 'EfficientNet-CondConv-B0 4 experts', '1904.04971',
136+
model_desc='Ported from official Google AI Tensorflow weights'),
137+
_entry('tf_efficientnet_cc_b0_8e', 'EfficientNet-CondConv-B0 8 experts', '1904.04971',
138+
model_desc='Ported from official Google AI Tensorflow weights'),
139+
_entry('tf_efficientnet_cc_b1_8e', 'EfficientNet-CondConv-B1 8 experts', '1904.04971',
140+
model_desc='Ported from official Google AI Tensorflow weights'),
117141
_entry('tf_efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
118142
model_desc='Ported from official Google AI Tensorflow weights'),
119143
_entry('tf_efficientnet_em', 'EfficientNet-EdgeTPU-M', '1905.11946',
@@ -124,6 +148,18 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
124148
_entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
125149
_entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
126150
_entry('tf_mixnet_s', 'MixNet-S', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
151+
_entry('tf_mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
152+
model_desc='Ported from official Google AI Tensorflow weights'),
153+
_entry('tf_mobilenetv3_large_075', 'MobileNet V3-Large 0.75', '1905.02244',
154+
model_desc='Ported from official Google AI Tensorflow weights'),
155+
_entry('tf_mobilenetv3_large_minimal_100', 'MobileNet V3-Large Minimal 1.0', '1905.02244',
156+
model_desc='Ported from official Google AI Tensorflow weights'),
157+
_entry('tf_mobilenetv3_small_100', 'MobileNet V3-Small 1.0', '1905.02244',
158+
model_desc='Ported from official Google AI Tensorflow weights'),
159+
_entry('tf_mobilenetv3_small_075', 'MobileNet V3-Small 0.75', '1905.02244',
160+
model_desc='Ported from official Google AI Tensorflow weights'),
161+
_entry('tf_mobilenetv3_small_minimal_100', 'MobileNet V3-Small Minimal 1.0', '1905.02244',
162+
model_desc='Ported from official Google AI Tensorflow weights'),
127163

128164
## Cadene ported weights (to remove if Cadene adds sotabench)
129165
_entry('inception_resnet_v2', 'Inception ResNet V2', '1602.07261'),

timm/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from .xception import *
88
from .nasnet import *
99
from .pnasnet import *
10-
from .gen_efficientnet import *
10+
from .efficientnet import *
11+
from .mobilenetv3 import *
1112
from .inception_v3 import *
1213
from .gluon_resnet import *
1314
from .gluon_xception import *
1415
from .res2net import *
1516
from .dla import *
17+
from .hrnet import *
1618

1719
from .registry import *
1820
from .factory import create_model

timm/models/activations.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import torch
2+
from torch import nn as nn
3+
from torch.nn import functional as F
4+
5+
6+
# Selects which swish/mish implementation this module exposes: the
# TorchScript + custom-autograd versions (True) or plain eager versions (False).
_USE_MEM_EFFICIENT_ISH = True
if _USE_MEM_EFFICIENT_ISH:
    # This version reduces memory overhead of Swish during training by
    # recomputing torch.sigmoid(x) in backward instead of saving it.
    @torch.jit.script
    def swish_jit_fwd(x):
        return x.mul(torch.sigmoid(x))


    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    @torch.jit.script
    def swish_jit_bwd(x, grad_output):
        x_sigmoid = torch.sigmoid(x)
        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


    class SwishJitAutoFn(torch.autograd.Function):
        """ torch.jit.script optimised Swish
        Inspired by conversation between Jeremy Howard & Adam Paszke
        https://twitter.com/jeremyphoward/status/1188251041835315200
        """

        @staticmethod
        def forward(ctx, x):
            # Only the input is saved; sigmoid(x) is recomputed in backward.
            ctx.save_for_backward(x)
            return swish_jit_fwd(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            return swish_jit_bwd(x, grad_output)


    def swish(x, _inplace=False):
        """Swish (x * sigmoid(x)) computed through SwishJitAutoFn.

        ``_inplace`` is accepted so the call shape matches the other
        activation functions in this module; it is ignored here and a new
        tensor is always returned.
        """
        return SwishJitAutoFn.apply(x)


    @torch.jit.script
    def mish_jit_fwd(x):
        return x.mul(torch.tanh(F.softplus(x)))


    # d/dx [x * tanh(softplus(x))]
    #   = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
    @torch.jit.script
    def mish_jit_bwd(x, grad_output):
        x_sigmoid = torch.sigmoid(x)
        x_tanh_sp = F.softplus(x).tanh()
        return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


    class MishJitAutoFn(torch.autograd.Function):
        """torch.jit.script Mish with a hand-written backward.

        Saves only the input and recomputes the intermediate terms in backward.
        """

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return mish_jit_fwd(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            return mish_jit_bwd(x, grad_output)


    def mish(x, _inplace=False):
        """Mish (x * tanh(softplus(x))) computed through MishJitAutoFn.

        ``_inplace`` is accepted for call-shape parity and ignored.
        """
        return MishJitAutoFn.apply(x)

else:
    def swish(x, inplace=False):
        """Swish - Described in: https://arxiv.org/abs/1710.05941
        """
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


    def mish(x, _inplace=False):
        """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
        """
        return x.mul(F.softplus(x).tanh())
79+
80+
81+
class Swish(nn.Module):
    """Module wrapper around the module-level ``swish`` function.

    Args:
        inplace: stored and forwarded as the second positional argument to
            ``swish`` on every call.
    """

    def __init__(self, inplace=False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # Passed positionally: the name of swish's second parameter differs
        # between the two implementations selected at import time.
        return swish(x, self.inplace)
88+
89+
90+
class Mish(nn.Module):
    """Module wrapper around the module-level ``mish`` function.

    Args:
        inplace: stored and forwarded as the second positional argument to
            ``mish`` on every call.
    """

    def __init__(self, inplace=False):
        super(Mish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return mish(x, self.inplace)
97+
98+
99+
def sigmoid(x, inplace=False):
    """Apply the logistic sigmoid to ``x``.

    Args:
        x: input tensor.
        inplace: when true, ``x`` is overwritten via ``sigmoid_()`` and
            returned; otherwise a new tensor is returned.
    """
    return x.sigmoid_() if inplace else x.sigmoid()
101+
102+
103+
# PyTorch has this, but not with a consistent inplace argument interface
104+
class Sigmoid(nn.Module):
    """Sigmoid activation module taking an ``inplace`` constructor flag.

    Args:
        inplace: when true, ``forward`` overwrites its input via
            ``sigmoid_()``; otherwise it returns a new tensor.
    """

    def __init__(self, inplace=False):
        super(Sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.sigmoid_() if self.inplace else x.sigmoid()
111+
112+
113+
def tanh(x, inplace=False):
    """Apply the hyperbolic tangent to ``x``.

    Args:
        x: input tensor.
        inplace: when true, ``x`` is overwritten via ``tanh_()`` and
            returned; otherwise a new tensor is returned.
    """
    return x.tanh_() if inplace else x.tanh()
115+
116+
117+
# PyTorch has this, but not with a consistent inplace argument interface
118+
class Tanh(nn.Module):
    """Tanh activation module taking an ``inplace`` constructor flag.

    Args:
        inplace: when true, ``forward`` overwrites its input via
            ``tanh_()``; otherwise it returns a new tensor.
    """

    def __init__(self, inplace=False):
        super(Tanh, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.tanh_() if self.inplace else x.tanh()
125+
126+
127+
def hard_swish(x, inplace=False):
    """Hard swish: ``x * relu6(x + 3) / 6``.

    Args:
        x: input tensor.
        inplace: when true, ``x`` is multiplied in place and returned;
            otherwise a new tensor is returned.
    """
    # The in-place division acts on the fresh tensor produced by relu6,
    # not on the caller's ``x``.
    inner = F.relu6(x + 3.).div_(6.)
    return x.mul_(inner) if inplace else x.mul(inner)
130+
131+
132+
class HardSwish(nn.Module):
    """Module wrapper around the module-level ``hard_swish`` function.

    Args:
        inplace: stored and forwarded to ``hard_swish`` on every call.
    """

    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, self.inplace)
139+
140+
141+
def hard_sigmoid(x, inplace=False):
    """Hard sigmoid: ``relu6(x + 3) / 6``.

    Args:
        x: input tensor.
        inplace: when true, ``x`` is shifted, clamped to [0, 6] and scaled
            in place, then returned; otherwise a new tensor is returned.
    """
    if inplace:
        return x.add_(3.).clamp_(0., 6.).div_(6.)
    else:
        return F.relu6(x + 3.) / 6.
146+
147+
148+
class HardSigmoid(nn.Module):
    """Module wrapper around the module-level ``hard_sigmoid`` function.

    Args:
        inplace: stored and forwarded to ``hard_sigmoid`` on every call.
    """

    def __init__(self, inplace=False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, self.inplace)
155+

timm/models/conv2d_helpers.py

Lines changed: 0 additions & 120 deletions
This file was deleted.

0 commit comments

Comments
 (0)