
Commit eef92a5

ydwu4 authored and facebook-github-bot committed
Remove _deprecated_global_ns from cond
Summary:
Remove _deprecated_global_ns from cond, following pytorch/pytorch#104105. To do that, we need to change how graph_module generates python_code for the "cond" target; otherwise it would generate the target as "torch.ops.cond", which is invalid after the change. Will import this PR to fix internal tests.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

X-link: pytorch/pytorch#104380

Reviewed By: zou3519

Differential Revision: D47110919

Pulled By: ydwu4

fbshipit-source-id: eed2f7e0aa6bfc0d0a46f0064630abed872e3d75
1 parent 480e28b commit eef92a5
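For context, a minimal sketch of the call pattern this commit migrates to, mirroring the updated test models in this diff (the branch functions here are illustrative, not from the commit):

import torch

def true_branch(x):
    return x + x

def false_branch(x):
    return x * x

def forward(inp):
    # cond is now addressed under the higher_order namespace; the
    # deprecated global alias torch.ops.cond is no longer valid as a target.
    return torch.ops.higher_order.cond(
        inp.size(0) > 4, true_branch, false_branch, [inp]
    )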

7 files changed: +22 -13 lines changed


backends/test/test_backends_nested.py

Lines changed: 6 additions & 3 deletions
@@ -115,7 +115,7 @@ def process(gm):
     processed_bytes = ""
     for node in gm.graph.nodes:
         if node.op == "call_function":
-            if node.target is torch.ops.cond:
+            if node.target is torch.ops.higher_order.cond:
                 _, true_gm, _ = _get_submodule(gm, node, 1)
                 _, false_gm, _ = _get_submodule(gm, node, 2)
                 processed_bytes += f"{node.target.__name__}({process(true_gm)},{process(false_gm)});"
@@ -134,7 +134,7 @@ def process(gm):

 class CondOperatorSupport(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        return node.op == "call_function" and node.target is torch.ops.cond
+        return node.op == "call_function" and node.target is torch.ops.higher_order.cond


 @final
@@ -168,7 +168,10 @@ def partition(
         for partition in partition_list:
             for node in partition.nodes:
                 delegation_tag = f"backend1_tag{partition.id}"
-                if node.op == "call_function" and node.target is torch.ops.cond:
+                if (
+                    node.op == "call_function"
+                    and node.target is torch.ops.higher_order.cond
+                ):
                     # Tag the arguments that take in the submodules to cond
                     # pyre-ignore
                     node.args[1].meta["delegation_tag"] = delegation_tag

docs/website/docs/ir_spec/control_flow.md

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@ EXIR has a couple of special operators used to help specify control flow within
 some code, similar to jax's control flow operators. Currently these operators
 are only supported for inference.

-## torch.ops.cond
+## torch.ops.higher_order.cond

 The `cond` function represents an “if” statement in other programming languages.
 It can logically be seen as implemented as follows:
@@ -87,7 +87,7 @@ class GraphModule(torch.nn.Module):
     %gt : Sym(s0 > 4) = call_function[target=operator.gt](args = (%sym_size, 4), kwargs = {})
     %true_graph_0 = get_attr[target=true_graph_0]
     %false_graph_0 = get_attr[target=false_graph_0]
-    %cond : f32[s0, s1] = call_function[target=torch.ops.cond](args = (%gt, %true_graph_0, %false_graph_0, [%arg0]), kwargs = {})
+    %cond : f32[s0, s1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%arg0]), kwargs = {})
     return [cond]

 # true_graph_0
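The doc hunk above says `cond` "can logically be seen as implemented as follows" without showing the body; as a rough sketch of the logical semantics only, not the actual implementation:

def cond(pred, true_fn, false_fn, operands):
    # Evaluate exactly one branch on the shared operand list.
    if pred:
        return true_fn(*operands)
    return false_fn(*operands)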

exir/graph_module.py

Lines changed: 2 additions & 2 deletions
@@ -299,7 +299,7 @@ def get_control_flow_submodules(
 ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
     """
     Returns a list of submodules used for control flow operations
-    (torch.ops.cond/map) that are in the given toplevel graph (does not look
+    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
     into submodules). Specifically, the returned value is a list containing a
     tuple of (name of the submodule that's stored in the graph module, the
     submodule itself, and the fx node that uses this submodule).
@@ -309,7 +309,7 @@ def get_control_flow_submodules(
         if node.op != "call_function":
             continue

-        if node.target is torch.ops.cond:
+        if node.target is torch.ops.higher_order.cond:
             control_flow_submodules.append(_get_submodule(graph_module, node, 1))
             control_flow_submodules.append(_get_submodule(graph_module, node, 2))
         if node.target is torch.ops.map_impl:
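A hedged usage sketch for the helper documented above, assuming `get_control_flow_submodules` keeps the signature shown in the hunk and is importable from exir.graph_module:

from typing import Callable

import torch

from exir.graph_module import get_control_flow_submodules

def walk_control_flow(
    gm: torch.fx.GraphModule,
    visit: Callable[[torch.fx.GraphModule], None],
) -> None:
    # Visit this module, then recurse into each cond/map submodule,
    # since get_control_flow_submodules does not itself recurse.
    visit(gm)
    for _name, submodule, _node in get_control_flow_submodules(gm):
        walk_control_flow(submodule, visit)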

exir/pass_base.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def make_inline_interpreter(
 ) -> Type[torch.fx.Interpreter]:
     class InlineInterpreter(parent):
         def call_function(self, target, args, kwargs):
-            if target == torch.ops.cond:
+            if target == torch.ops.higher_order.cond:
                 pred, true, false, params = args
                 return InlineInterpreter(true).run(*params)
             elif target == torch.ops.map:

exir/passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
                 continue

             target = node.target
-            if target == control_flow.cond or target == torch.ops.cond:
+            if target == control_flow.cond or target == torch.ops.higher_order.cond:
                 self.call(get_submodule(node.args[1]))
                 self.call(get_submodule(node.args[2]))
                 continue

exir/tests/control_flow_models.py

Lines changed: 9 additions & 3 deletions
@@ -13,7 +13,9 @@ def true_branch(x):
         def false_branch(x):
             return x * x

-        return torch.ops.cond(inp.size(0) > 4, true_branch, false_branch, [inp])
+        return torch.ops.higher_order.cond(
+            inp.size(0) > 4, true_branch, false_branch, [inp]
+        )

     def get_random_inputs(self):
         return (torch.rand(5),)
@@ -30,7 +32,9 @@ def true_branch(x):
         def false_branch(x):
             return x * x * x

-        return torch.ops.cond(inp.size(0) > 4, true_branch, false_branch, [inp])
+        return torch.ops.higher_order.cond(
+            inp.size(0) > 4, true_branch, false_branch, [inp]
+        )

     def get_upper_bound_inputs(self):
         return (torch.rand(8),)
@@ -60,7 +64,9 @@ def true_branch(x):
         def false_branch(x):
             return x * 2

-        return torch.ops.cond(inp.size(0) > 4, true_branch, false_branch, [inp])
+        return torch.ops.higher_order.cond(
+            inp.size(0) > 4, true_branch, false_branch, [inp]
+        )

     def get_random_inputs(self):
         return (torch.eye(5) * 2,)

test/end2end/test_end2end.py

Lines changed: 1 addition & 1 deletion
@@ -749,7 +749,7 @@ class DynamicModelE2ETest(unittest.TestCase):
         run_executor=False,
     )

-    # basic test for functorch torch.ops.cond
+    # basic test for functorch torch.ops.higher_order.cond
    test_ft_cond_basic = maketest(
        FTCondBasic,
        capture_config=exir.CaptureConfig(

0 commit comments
