Skip to content

Commit b22c224

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Support None's When Calculating Emitted ExecutionPlan Outputs (#2243)
Summary: Pull Request resolved: #2243 Previously Support for outputting int/float was added in D53256808 for Seamless This extends on this by dropping None Outputs (Which can be outputted by NanoGPT) bypass-github-export-checks Reviewed By: JacobSzwejbka, Gasoonjia Differential Revision: D54328663 fbshipit-source-id: 8aa3c4f85772735c01d0fe0cfa52d6251af93dbc
1 parent 23c8172 commit b22c224

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

exir/emit/_emitter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,9 +1409,10 @@ def output(
14091409
self.outputs.append(args_tuple.id)
14101410
else:
14111411
for arg in args_tuple:
1412-
if isinstance(arg, (int, float, bool)):
1412+
if isinstance(arg, (int, float, bool, type(None))):
14131413
arg = self._emit_evalue(self._constant_to_evalue(arg, None))
1414-
elif isinstance(arg, (type(None), str)):
1414+
elif isinstance(arg, str):
1415+
# TODO(jackkhuu): T181599879 Add support for string outputs IFF compiler supports
14151416
raise InternalError(
14161417
self._emit_node_specific_error(
14171418
self.node,

exir/emit/test/test_emit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,14 @@ def forward(
215215
def test_constant_output(self):
216216
class M(torch.nn.Module):
217217
def forward(self, x):
218-
return [((1, 3, 1.2), True, [x + x, x * x])]
218+
return [((1, 3, 1.2), True, [x + x, x * x], None)]
219219

220220
ep = torch.export.export(M(), (torch.ones(2, 3),))
221221
res = ep.module()(torch.ones(2, 3))
222222
self.assertEqual(res[0][0], (1, 3, 1.2))
223223
program = to_edge(ep).to_executorch().executorch_program
224224
outputs = program.execution_plan[0].outputs
225-
self.assertEqual(len(outputs), 6)
225+
self.assertEqual(len(outputs), 7)
226226
self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
227227
self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
228228
self.assertEqual(
@@ -231,6 +231,7 @@ def forward(self, x):
231231
self.assertEqual(
232232
program.execution_plan[0].values[outputs[3]].val.bool_val, True
233233
)
234+
self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
234235

235236
def test_int_list_input(self):
236237
class M(torch.nn.Module):

0 commit comments

Comments
 (0)