Commit 8d992e7

Add softmax as another toy model
Summary: This is an interesting use case to demonstrate on the ARM baremetal setup, which currently can't lower Softmax. It is one of only a few such ops from MV2, and it is also useful as a test while debugging this flow.

Differential Revision: D49899589

fbshipit-source-id: c0307c7452a69eb707429328ba6654b2b660ce23
1 parent 85116fd commit 8d992e7

File tree

3 files changed: +19 -1 lines changed
examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
     "linear": ("toy_model", "LinearModule"),
     "add": ("toy_model", "AddModule"),
     "add_mul": ("toy_model", "AddMulModule"),
+    "softmax" : ("toy_model", "SoftmaxModule"),
     "dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
     "edsr": ("edsr", "EdsrModel"),
     "emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),

examples/models/toy_model/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .model import AddModule, AddMulModule, LinearModule, MulModule
+from .model import AddModule, AddMulModule, LinearModule, MulModule, SoftmaxModule
 
 __all__ = [
     AddModule,
     AddMulModule,
     LinearModule,
     MulModule,
+    SoftmaxModule,
 ]

examples/models/toy_model/model.py

Lines changed: 16 additions & 0 deletions
@@ -75,3 +75,19 @@ def get_example_inputs(self):
     def get_compile_spec(self):
         max_value = self.get_example_inputs()[0].shape[0]
         return [CompileSpec("max_value", bytes([max_value]))]
+
+
+class SoftmaxModule(torch.nn.Module, EagerModelBase):
+    def __init__(self):
+        super().__init__()
+        self.softmax = torch.nn.Softmax()
+
+    def forward(self, x):
+        z = self.softmax(x)
+        return z
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return self
+
+    def get_example_inputs(self):
+        return (torch.ones(2, 2),)
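
As a quick sanity check of the new toy model, a minimal sketch is shown below. It assumes the executorch repository root is importable so that the examples.models.toy_model package resolves; none of this is part of the diff itself.

# Minimal sketch: run SoftmaxModule eagerly on its own example inputs.
# Assumes the executorch repo root is on PYTHONPATH (import path is an assumption).
import torch

from executorch.examples.models.toy_model import SoftmaxModule

model = SoftmaxModule()
(x,) = model.get_example_inputs()     # a (2, 2) tensor of ones
out = model.get_eager_model()(x)      # softmax over dim 1 (chosen implicitly)
assert torch.allclose(out.sum(dim=-1), torch.ones(2))  # each row sums to 1.0

Note that torch.nn.Softmax() constructed without an explicit dim= argument falls back to an implicit dimension choice (dim 1 for a 2-D input) and emits a UserWarning on recent PyTorch releases.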
