Skip to content

Commit be96db0

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for outputting constants
Summary: For Seamless model we ran into a case where we have to either return an int/float or a list of int's/float's from the model. We need to add support for this in the emitter and make sure that the memory planning pass ignores these. Differential Revision: D53256808
1 parent d0050dd commit be96db0

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

exir/emit/_emitter.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,9 +1342,12 @@ def output(
13421342
self.outputs.append(args_tuple.id)
13431343
else:
13441344
for arg in args_tuple:
1345-
# Every output should already have its value emitted outputs should only be abstract
1346-
# IDs at this point.
1347-
assert isinstance(arg, _AbstractValue)
1345+
if isinstance(arg, (int, float)):
1346+
arg = self._emit_evalue(self._constant_to_evalue(arg, None))
1347+
else:
1348+
# Every other output should already have its value emitted.
1349+
# They should only be abstract IDs at this point.
1350+
assert isinstance(arg, _AbstractValue)
13481351
self.outputs.append(arg.id)
13491352

13501353
def plan(self) -> ExecutionPlan:

exir/emit/test/test_emit.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,23 @@ def forward(
187187
"T2#1#0(T1#1($),D0())",
188188
)
189189

190+
def test_int_list_output(self):
191+
class M(torch.nn.Module):
192+
def forward(self, x):
193+
return [((1, 3, 1.2), [x + x, x * x])]
194+
195+
ep = torch.export.export(M(), (torch.ones(2, 3),))
196+
res = ep(torch.ones(2, 3))
197+
self.assertEqual(res[0][0], (1, 3, 1.2))
198+
program = to_edge(ep).to_executorch().executorch_program
199+
outputs = program.execution_plan[0].outputs
200+
self.assertEqual(len(outputs), 5)
201+
self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
202+
self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
203+
self.assertEqual(
204+
program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
205+
)
206+
190207
def test_buffers_with_perfect_alignment(self) -> None:
191208
class Foo(torch.nn.Module):
192209
def forward(self, x: torch.Tensor) -> torch.Tensor:

exir/memory_planning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def get_node_tensor_specs(
453453
if not isinstance(specs, (list, tuple)):
454454
return []
455455
else:
456-
return specs
456+
return [spec for spec in specs if not isinstance(spec, (int, float))]
457457

458458

459459
@register_algo

0 commit comments

Comments
 (0)