Skip to content

Commit 204759c

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Dropping Trailing None's When Calculating Emitted ExecutionPlan Outputs (#2243)
Summary: Previously Support for outputting int/float was added in D53256808 for Seamless This extends on this by dropping None Outputs (Which can be outputted by NanoGPT) Differential Revision: D54328663
1 parent 748b09f commit 204759c

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

exir/emit/_emitter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,9 +1367,9 @@ def output(
13671367
self.outputs.append(args_tuple.id)
13681368
else:
13691369
for arg in args_tuple:
1370-
if isinstance(arg, (int, float, bool)):
1370+
if isinstance(arg, (int, float, bool, type(None))):
13711371
arg = self._emit_evalue(self._constant_to_evalue(arg, None))
1372-
elif isinstance(arg, (type(None), str)):
1372+
elif isinstance(arg, str):
13731373
raise InternalError(
13741374
self._emit_node_specific_error(
13751375
self.node,

exir/emit/test/test_emit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,14 @@ def forward(
213213
def test_constant_output(self):
214214
class M(torch.nn.Module):
215215
def forward(self, x):
216-
return [((1, 3, 1.2), True, [x + x, x * x])]
216+
return [((1, 3, 1.2), True, [x + x, x * x], None)]
217217

218218
ep = torch.export.export(M(), (torch.ones(2, 3),))
219219
res = ep.module()(torch.ones(2, 3))
220220
self.assertEqual(res[0][0], (1, 3, 1.2))
221221
program = to_edge(ep).to_executorch().executorch_program
222222
outputs = program.execution_plan[0].outputs
223-
self.assertEqual(len(outputs), 6)
223+
self.assertEqual(len(outputs), 7)
224224
self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
225225
self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
226226
self.assertEqual(
@@ -229,6 +229,7 @@ def forward(self, x):
229229
self.assertEqual(
230230
program.execution_plan[0].values[outputs[3]].val.bool_val, True
231231
)
232+
self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
232233

233234
def test_int_list_input(self):
234235
class M(torch.nn.Module):

0 commit comments

Comments
 (0)