Skip to content

Commit 1d0d0f8

Browse files
tarun292 authored and facebook-github-bot committed
Add support for outputting constants (#1774)
Summary: For the Seamless model we ran into a case where we have to return either an int/float or a list of ints/floats from the model. We need to add support for this in the emitter and make sure that the memory planning pass ignores these. Reviewed By: angelayi Differential Revision: D53256808
1 parent 75aa0b4 commit 1d0d0f8

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

exir/emit/_emitter.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,9 +1369,19 @@ def output(
13691369
self.outputs.append(args_tuple.id)
13701370
else:
13711371
for arg in args_tuple:
1372-
# Every output should already have its value emitted outputs should only be abstract
1373-
# IDs at this point.
1374-
assert isinstance(arg, _AbstractValue)
1372+
if isinstance(arg, (int, float, bool)):
1373+
arg = self._emit_evalue(self._constant_to_evalue(arg, None))
1374+
elif isinstance(arg, (type(None), str)):
1375+
raise InternalError(
1376+
self._emit_node_specific_error(
1377+
self.node,
1378+
f"Returning {arg} is not yet supported in the emitter.",
1379+
)
1380+
)
1381+
else:
1382+
# Every other output should already have its value emitted.
1383+
# They should only be abstract IDs at this point.
1384+
assert isinstance(arg, _AbstractValue)
13751385
self.outputs.append(arg.id)
13761386

13771387
def plan(self) -> ExecutionPlan:

exir/emit/test/test_emit.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,39 @@ def forward(
209209
"T2#1#0(T1#1($),D0())",
210210
)
211211

212+
def test_constant_output(self) -> None:
    """Constant scalars returned by a model must be emitted as EValues.

    The model returns a nested structure mixing constant ints, a float,
    and a bool with two tensors; after pytree flattening the plan has
    6 outputs in order: 1, 3, 1.2, True, tensor, tensor.
    """

    class M(torch.nn.Module):
        def forward(self, x):
            return [((1, 3, 1.2), True, [x + x, x * x])]

    ep = torch.export.export(M(), (torch.ones(2, 3),))
    # Sanity-check eager behavior before lowering.
    res = ep(torch.ones(2, 3))
    self.assertEqual(res[0][0], (1, 3, 1.2))
    program = to_edge(ep).to_executorch().executorch_program
    outputs = program.execution_plan[0].outputs
    self.assertEqual(len(outputs), 6)
    self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
    self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
    self.assertEqual(
        program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
    )
    # Fix: the bool constant is the 4th flattened output. The original test
    # re-read outputs[2] (the 1.2 double EValue) for bool_val, so the bool
    # output was never actually verified.
    self.assertEqual(
        program.execution_plan[0].values[outputs[3]].val.bool_val, True
    )
232+
def test_int_list_input(self) -> None:
    """Scalar (int/bool) model inputs are emitted as constant EValues.

    Exports a model taking a tensor, an int, and a bool, lowers it to an
    ExecuTorch program, and checks the two scalar inputs appear in the
    plan's value table with the expected constant values.
    """

    class M(torch.nn.Module):
        def forward(self, x, y, z):
            return x + y, x + x, x + y + z

    ep = torch.export.export(M(), (torch.ones(2, 3), 2, True))
    ep(torch.ones(2, 3), 2, True)
    program = to_edge(ep).to_executorch().executorch_program
    plan = program.execution_plan[0]
    inputs = plan.inputs
    self.assertEqual(len(inputs), 3)
    self.assertEqual(plan.values[inputs[1]].val.int_val, 2)
    self.assertEqual(plan.values[inputs[2]].val.bool_val, True)
244+
212245
def test_buffers_with_perfect_alignment(self) -> None:
213246
class Foo(torch.nn.Module):
214247
def forward(self, x: torch.Tensor) -> torch.Tensor:

exir/memory_planning.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ def get_node_tensor_specs(
495495
if not isinstance(specs, (list, tuple)):
496496
return []
497497
else:
498-
return specs
498+
return [
499+
spec
500+
for spec in specs
501+
if not isinstance(spec, (int, float, bool, str, type(None)))
502+
]
499503

500504

501505
@register_algo

0 commit comments

Comments (0)