
Commit e34e857

tarun292 authored and facebook-github-bot committed
Add support for outputting constants (#1774)
Summary: For the Seamless model we ran into a case where we have to return either an int/float or a list of ints/floats from the model. We need to add support for this in the emitter and make sure that the memory planning pass ignores these values. Differential Revision: D53256808
1 parent cf87e79 commit e34e857
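
To make the intent concrete, here is a minimal sketch of the kind of model this change unblocks, using the same torch.export.export / to_edge / to_executorch calls the new tests exercise. The module, input shapes, and the executorch.exir import path are illustrative assumptions and not part of this commit:

import torch
from executorch.exir import to_edge  # import path assumed for this sketch

class ReturnsConstants(torch.nn.Module):
    # forward() mixes a tensor output with plain Python constants,
    # the pattern the Seamless model needed.
    def forward(self, x):
        return x + x, 7, 2.5, True

ep = torch.export.export(ReturnsConstants(), (torch.ones(2, 3),))
# After this change the emitter turns 7, 2.5, and True into scalar EValues
# instead of asserting, and the memory planning pass skips them.
program = to_edge(ep).to_executorch().executorch_program
print(program.execution_plan[0].outputs)  # indices into the plan's values table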

File tree

exir/emit/_emitter.py
exir/emit/test/test_emit.py
exir/memory_planning.py

3 files changed, +51 -4 lines changed

exir/emit/_emitter.py

Lines changed: 13 additions & 3 deletions

@@ -1342,9 +1342,19 @@ def output(
             self.outputs.append(args_tuple.id)
         else:
             for arg in args_tuple:
-                # Every output should already have its value emitted outputs should only be abstract
-                # IDs at this point.
-                assert isinstance(arg, _AbstractValue)
+                if isinstance(arg, (int, float, bool)):
+                    arg = self._emit_evalue(self._constant_to_evalue(arg, None))
+                elif isinstance(arg, (type(None), str)):
+                    raise InternalError(
+                        self._emit_node_specific_error(
+                            self.node,
+                            f"Returning {arg} is not yet supported in the emitter.",
+                        )
+                    )
+                else:
+                    # Every other output should already have its value emitted.
+                    # They should only be abstract IDs at this point.
+                    assert isinstance(arg, _AbstractValue)
                 self.outputs.append(arg.id)
 
     def plan(self) -> ExecutionPlan:
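
For readers unfamiliar with the emitter, the following self-contained sketch shows the idea behind the new branch: a constant Python output is mapped to the matching scalar EValue variant. The dataclasses are stand-ins for the schema types, and constant_to_scalar_evalue is not the emitter's real _constant_to_evalue helper; the one non-obvious detail it illustrates is that bool must be checked before int, since bool is a subclass of int in Python.

# Illustrative only: stand-in classes, not the real ExecuTorch schema types.
from dataclasses import dataclass
from typing import Union

@dataclass
class Bool:
    bool_val: bool

@dataclass
class Int:
    int_val: int

@dataclass
class Double:
    double_val: float

def constant_to_scalar_evalue(arg: Union[bool, int, float]):
    # bool is checked first: isinstance(True, int) is True in Python.
    if isinstance(arg, bool):
        return Bool(arg)
    if isinstance(arg, int):
        return Int(arg)
    if isinstance(arg, float):
        return Double(arg)
    raise NotImplementedError(f"Returning {arg!r} is not yet supported.")

print(constant_to_scalar_evalue(True))  # Bool(bool_val=True)
print(constant_to_scalar_evalue(3))     # Int(int_val=3)
print(constant_to_scalar_evalue(1.2))   # Double(double_val=1.2)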

exir/emit/test/test_emit.py

Lines changed: 33 additions & 0 deletions

@@ -187,6 +187,39 @@ def forward(
             "T2#1#0(T1#1($),D0())",
         )
 
+    def test_constant_output(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return [((1, 3, 1.2), True, [x + x, x * x])]
+
+        ep = torch.export.export(M(), (torch.ones(2, 3),))
+        res = ep(torch.ones(2, 3))
+        self.assertEqual(res[0][0], (1, 3, 1.2))
+        program = to_edge(ep).to_executorch().executorch_program
+        outputs = program.execution_plan[0].outputs
+        self.assertEqual(len(outputs), 6)
+        self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
+        self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
+        self.assertEqual(
+            program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
+        )
+        self.assertEqual(
+            program.execution_plan[0].values[outputs[3]].val.bool_val, True
+        )
+
+    def test_int_list_input(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                return x + y, x + x, x + y + z
+
+        ep = torch.export.export(M(), (torch.ones(2, 3), 2, True))
+        ep(torch.ones(2, 3), 2, True)
+        program = to_edge(ep).to_executorch().executorch_program
+        inputs = program.execution_plan[0].inputs
+        self.assertEqual(len(inputs), 3)
+        self.assertEqual(program.execution_plan[0].values[inputs[1]].val.int_val, 2)
+        self.assertEqual(program.execution_plan[0].values[inputs[2]].val.bool_val, True)
+
     def test_buffers_with_perfect_alignment(self) -> None:
         class Foo(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:

exir/memory_planning.py

Lines changed: 5 additions & 1 deletion

@@ -455,7 +455,11 @@ def get_node_tensor_specs(
     if not isinstance(specs, (list, tuple)):
         return []
     else:
-        return specs
+        return [
+            spec
+            for spec in specs
+            if not isinstance(spec, (int, float, bool, str, type(None)))
+        ]
 
 
 @register_algo
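
The intent of the new list comprehension, as a standalone sketch: memory planning should only ever see tensor specs, so any plain Python constants that now appear among a node's flattened outputs are dropped before the planner tries to allocate storage for them. FakeTensorSpec below is a hypothetical stand-in; only the isinstance filter mirrors the diff.

# Hypothetical stand-in for exir's TensorSpec, used only to demonstrate the filter.
class FakeTensorSpec:
    pass

def keep_tensor_specs(specs):
    # Mirrors the new behavior: scalar constants are skipped so the memory
    # planner never allocates buffers for them.
    return [
        spec
        for spec in specs
        if not isinstance(spec, (int, float, bool, str, type(None)))
    ]

mixed = [FakeTensorSpec(), 1, 3, 1.2, True, None, FakeTensorSpec()]
assert len(keep_tensor_specs(mixed)) == 2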
