Skip to content

Commit 5ee0676

Browse files
committed
Fix classifier input dim for mnv3 after last changes
1 parent a5a2ad2 commit 5ee0676

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(
134134
self.norm_head = nn.Identity()
135135
self.act2 = act_layer(inplace=True)
136136
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
137-
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
137+
self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
138138

139139
efficientnet_init_weights(self)
140140

0 commit comments

Comments
 (0)