Skip to content

Commit bbb13de

Browse files
committed
Fix classifier input dim for mnv3 after last changes
1 parent 24872c9 commit bbb13de

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type, bias=head_bias)
119119
self.act2 = act_layer(inplace=True)
120120
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
121-
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
121+
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
122122

123123
efficientnet_init_weights(self)
124124

0 commit comments

Comments
 (0)