Skip to content

Commit 7f395fd

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Export emformer RNNT encode, predict, join (#173)
Summary: Pull Request resolved: #173 An example to check the export path for Emformer-RNN-T encode, predict, join methods. Reviewed By: kimishpatel Differential Revision: D48327041 fbshipit-source-id: 97bb3ee3feadc01ac4f96b5e5702adb867aa2877
1 parent 2590257 commit 7f395fd

File tree

6 files changed

+158
-3
lines changed

6 files changed

+158
-3
lines changed

.ci/scripts/gather_test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
DEFAULT_RUNNER = "linux.2xlarge"
2020
RUNNERS = {
2121
# This one runs OOM on smaller runner, the root cause is unclear (T163016365)
22-
"w2l": "linux.12xlarge"
22+
"w2l": "linux.12xlarge",
23+
# This one causes timeout on smaller runner, the root cause is unclear (T161064121)
24+
"emformer_join": "linux.12xlarge",
2325
}
2426

2527

examples/export/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
2020

21-
# Explicitly force the activation of the IR validator
21+
# TODO(T163721729): Enable IR check after decomposing div.Tensor_mode
2222
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
23-
_check_ir_validity=True,
23+
_check_ir_validity=False,
2424
)
2525

2626

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python_library(
1010
"//caffe2:torch",
1111
"//executorch/examples/models:model_base", # @manual
1212
"//executorch/examples/models/deeplab_v3:dl3_model", # @manual
13+
"//executorch/examples/models/emformer_rnnt:emformer_rnnt_model", # @manual
1314
"//executorch/examples/models/inception_v3:ic3_model", # @manual
1415
"//executorch/examples/models/inception_v4:ic4_model", # @manual
1516
"//executorch/examples/models/mobilebert:mobilebert_model", # @manual

examples/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
"add": ("toy_model", "AddModule"),
1313
"add_mul": ("toy_model", "AddMulModule"),
1414
"dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
15+
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
16+
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
17+
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
1518
"mobilebert": ("mobilebert", "MobileBertModelExample"),
1619
"mv2": ("mobilenet_v2", "MV2Model"),
1720
"mv3": ("mobilenet_v3", "MV3Model"),
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import (
8+
EmformerRnntJoinerModel,
9+
EmformerRnntPredictorModel,
10+
EmformerRnntTranscriberModel,
11+
)
12+
13+
__all__ = [
14+
EmformerRnntTranscriberModel,
15+
EmformerRnntPredictorModel,
16+
EmformerRnntJoinerModel,
17+
]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import logging
9+
10+
import torch
11+
import torchaudio
12+
13+
from ..model_base import EagerModelBase
14+
15+
16+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
17+
logging.basicConfig(format=FORMAT)
18+
19+
20+
__all__ = [
21+
"EmformerRnntTranscriberModel",
22+
"EmformerRnntPredictorModel",
23+
"EmformerRnntJoinerModel",
24+
]
25+
26+
27+
class EmformerRnntTranscriberExample(torch.nn.Module):
28+
"""
29+
This is a wrapper for validating transcriber for the Emformer RNN-T architecture.
30+
It does not reflect the actual usage such as beam search, but rather an example for the export workflow.
31+
"""
32+
33+
def __init__(self) -> None:
34+
super().__init__()
35+
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
36+
decoder = bundle.get_decoder()
37+
m = decoder.model
38+
self.rnnt = m
39+
40+
def forward(self, transcribe_inputs):
41+
return self.rnnt.transcribe(*transcribe_inputs)
42+
43+
44+
class EmformerRnntTranscriberModel(EagerModelBase):
45+
def __init__(self):
46+
pass
47+
48+
def get_eager_model(self) -> torch.nn.Module:
49+
logging.info("Loading emformer rnnt transcriber")
50+
m = EmformerRnntTranscriberExample()
51+
logging.info("Loaded emformer rnnt transcriber")
52+
return m
53+
54+
def get_example_inputs(self):
55+
transcribe_inputs = (
56+
torch.randn(1, 128, 80),
57+
torch.tensor([128]),
58+
)
59+
return (transcribe_inputs,)
60+
61+
62+
class EmformerRnntPredictorExample(torch.nn.Module):
63+
"""
64+
This is a wrapper for validating predictor for the Emformer RNN-T architecture.
65+
It does not reflect the actual usage such as beam search, but rather an example for the export workflow.
66+
"""
67+
68+
def __init__(self) -> None:
69+
super().__init__()
70+
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
71+
decoder = bundle.get_decoder()
72+
m = decoder.model
73+
self.rnnt = m
74+
75+
def forward(self, predict_inputs):
76+
return self.rnnt.predict(*predict_inputs)
77+
78+
79+
class EmformerRnntPredictorModel(EagerModelBase):
80+
def __init__(self):
81+
pass
82+
83+
def get_eager_model(self) -> torch.nn.Module:
84+
logging.info("Loading emformer rnnt predictor")
85+
m = EmformerRnntPredictorExample()
86+
logging.info("Loaded emformer rnnt predictor")
87+
return m
88+
89+
def get_example_inputs(self):
90+
predict_inputs = (
91+
torch.zeros([1, 128], dtype=int),
92+
torch.tensor([128], dtype=int),
93+
None,
94+
)
95+
return (predict_inputs,)
96+
97+
98+
class EmformerRnntJoinerExample(torch.nn.Module):
99+
"""
100+
This is a wrapper for validating joiner for the Emformer RNN-T architecture.
101+
It does not reflect the actual usage such as beam search, but rather an example for the export workflow.
102+
"""
103+
104+
def __init__(self) -> None:
105+
super().__init__()
106+
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
107+
decoder = bundle.get_decoder()
108+
m = decoder.model
109+
self.rnnt = m
110+
111+
def forward(self, predict_inputs):
112+
return self.rnnt.join(*predict_inputs)
113+
114+
115+
class EmformerRnntJoinerModel(EagerModelBase):
116+
def __init__(self):
117+
pass
118+
119+
def get_eager_model(self) -> torch.nn.Module:
120+
logging.info("Loading emformer rnnt joiner")
121+
m = EmformerRnntJoinerExample()
122+
logging.info("Loaded emformer rnnt joiner")
123+
return m
124+
125+
def get_example_inputs(self):
126+
join_inputs = (
127+
torch.rand([1, 128, 1024]),
128+
torch.tensor([128]),
129+
torch.rand([1, 128, 1024]),
130+
torch.tensor([128]),
131+
)
132+
return (join_inputs,)

0 commit comments

Comments
 (0)