Skip to content

Commit f9f0605

Browse files
committed
Add softmax as another toy model
Summary: This is an interesting use case to demonstrate on the ARM baremetal setup, which currently cannot lower SoftMax — one of only a few such ops from MV2. It is also useful as a test while debugging this flow. Differential Revision: D49899589 fbshipit-source-id: c0307c7452a69eb707429328ba6654b2b660ce23
1 parent 85116fd commit f9f0605

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"linear": ("toy_model", "LinearModule"),
1212
"add": ("toy_model", "AddModule"),
1313
"add_mul": ("toy_model", "AddMulModule"),
14+
"softmax": ("toy_model", "SoftmaxModule"),
1415
"dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
1516
"edsr": ("edsr", "EdsrModel"),
1617
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),

examples/models/toy_model/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from .model import AddModule, AddMulModule, LinearModule, MulModule
7+
from .model import AddModule, AddMulModule, LinearModule, MulModule, SoftmaxModule
88

99
__all__ = [
1010
AddModule,
1111
AddMulModule,
1212
LinearModule,
1313
MulModule,
14+
SoftmaxModule,
1415
]

examples/models/toy_model/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,19 @@ def get_example_inputs(self):
7575
def get_compile_spec(self):
7676
max_value = self.get_example_inputs()[0].shape[0]
7777
return [CompileSpec("max_value", bytes([max_value]))]
78+
79+
80+
class SoftmaxModule(torch.nn.Module, EagerModelBase):
    """Toy model wrapping a single Softmax op.

    Intended as a minimal test case for backends (e.g. the ARM baremetal
    flow mentioned in the commit summary) that cannot yet lower SoftMax.
    """

    def __init__(self):
        super().__init__()
        # Pass dim explicitly: torch.nn.Softmax() with no dim emits a
        # deprecation warning and picks the dimension implicitly. For the
        # 2-D example input below the implicit choice is dim 1, which is
        # the same as dim=-1, so behavior for the declared inputs is
        # unchanged.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply softmax over the last dimension of ``x``."""
        return self.softmax(x)

    def get_eager_model(self) -> torch.nn.Module:
        # The module itself is the eager model; no wrapping needed.
        return self

    def get_example_inputs(self):
        # One 2x2 tensor of ones; softmax over the last dim yields all 0.5.
        return (torch.ones(2, 2),)

0 commit comments

Comments
 (0)