Skip to content

Commit b8c398a

Browse files
committed
(//core): Align with prim::Enter in module fallback
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 4ee9dbc commit b8c398a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

core/lowering/passes/module_fallback.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,17 @@ void NotateModuleForFallback(
4343
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
4444
<< " (" << cls_name << ")]");
4545
auto uses = n->output(0)->uses();
46+
int k = 0;
4647
for (const auto u : uses) {
48+
auto compilation_context_node = g->createNone();
49+
auto compilation_context = compilation_context_node->outputs()[0];
50+
compilation_context->setDebugName("compilation_context_" + std::to_string(k++));
4751
auto user = u.user;
48-
auto delim_start_n = g->create(torch::jit::prim::Enter, 0);
52+
auto delim_start_n = g->create(torch::jit::prim::Enter, {compilation_context});
4953
delim_start_n->s_(c10::Symbol::attr("compilation_edge"), "start");
50-
auto delim_end_n = g->create(torch::jit::prim::Exit, 0);
54+
auto delim_end_n = g->create(torch::jit::prim::Exit, {compilation_context});
5155
delim_end_n->s_(c10::Symbol::attr("compilation_edge"), "end");
56+
compilation_context_node->insertBefore(user);
5257
delim_start_n->insertBefore(user);
5358
delim_end_n->insertAfter(user);
5459
}

0 commit comments

Comments
 (0)