Skip to content

Commit baf07cf

Browse files
ydwu4 authored and facebook-github-bot committed
Replace exir cond with torch cond
Summary: This diff removes exir.control_flow.cond and replaces its existing usage with torch.cond. Reviewed By: angelayi Differential Revision: D47924374 fbshipit-source-id: d296479ee1f708cb423a27feb55e258632238902
1 parent 750756b commit baf07cf

File tree

3 files changed

+9
-75
lines changed

3 files changed

+9
-75
lines changed

exir/control_flow.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -125,67 +125,6 @@ def _make_submodule(
125125
return gm
126126

127127

128-
def cond(
129-
pred: bool,
130-
true_fn: Callable[..., Tuple[torch.Tensor]],
131-
false_fn: Callable[..., Tuple[torch.Tensor]],
132-
inputs: pytree.PyTree,
133-
) -> Union[List[torch.Tensor], Value]:
134-
"""
135-
A higher order function returning result based on passed predicate
136-
value and conditionally execute one of true_fn and false_fn.
137-
138-
Detects whether a tracer is present in the context, and if so will
139-
trace_through both true_fn and false_fn with local inputs provided
140-
by tracing_context dictionary from the current tracer. When
141-
returning, wraps two traced graphs into a cond() call and construct
142-
a call_function node in the tracer's graph.
143-
144-
Checks and enforces that the returning value(s) from both
145-
branches has the same Tensor type. For now enforces that both
146-
branches have the same number of tensor inputs.
147-
"""
148-
flattened_inputs, _ = pytree.tree_flatten(inputs)
149-
150-
if not all([isinstance(i, torch.Tensor) for i in flattened_inputs]):
151-
raise ExportError(
152-
ExportErrorType.INVALID_INPUT_TYPE,
153-
f"control_flow.cond() expects all inputs values to be tensors, actual inputs: {inputs}",
154-
)
155-
156-
with using_tracer(None):
157-
outputs = true_fn(*inputs) if pred else false_fn(*inputs)
158-
159-
flattened_outputs, _ = pytree.tree_flatten(outputs)
160-
161-
if not all([isinstance(r, torch.Tensor) for r in flattened_outputs]):
162-
raise ExportError(
163-
ExportErrorType.INVALID_OUTPUT_TYPE,
164-
f"control_flow.cond() only supports tensors as output, actual output: {outputs}",
165-
)
166-
167-
tracer = DispatchTracer.get()
168-
169-
if tracer is None:
170-
return outputs
171-
172-
# Once global tracer is present, we need to assume all tensors are
173-
# PythonTensor wrapped with FunctionalTensorWrapper.
174-
175-
gm_true = _make_submodule(true_fn, example_returns=flattened_outputs)
176-
gm_false = _make_submodule(false_fn, example_returns=flattened_outputs)
177-
proxies = tuple([unwrap_proxy(i) for i in flattened_inputs])
178-
179-
proxy = tracer.create_proxy(
180-
"call_function",
181-
cond,
182-
(unwrap_proxy(pred), gm_true, gm_false, proxies),
183-
{},
184-
)
185-
186-
return tree_return(outputs, proxy, update_with_proxy)
187-
188-
189128
def while_loop(
190129
cond_fn: Callable[..., torch.Tensor],
191130
body_fn: Callable[..., Tuple[torch.Tensor]],

exir/passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
236236
# pyre-ignore
237237
to_out_var_skiplist: Set[Callable[[Any], Any]] = {
238238
_operator.getitem,
239-
control_flow.cond,
239+
torch.ops.higher_order.cond,
240240
control_flow.while_loop,
241241
# memory.alloc will be added after the to_out_variant pass so usually
242242
# we won't see it in the input graph to the to_out_variant pass, unless
@@ -321,7 +321,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
321321
continue
322322

323323
target = node.target
324-
if target == control_flow.cond or target == torch.ops.higher_order.cond:
324+
if target == torch.ops.higher_order.cond:
325325
self.call(get_submodule(node.args[1]))
326326
self.call(get_submodule(node.args[2]))
327327
continue

test/end2end/test_end2end.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from executorch.exir.tests.dynamic_shape_models import BatchNormModel
5757

5858
from executorch.exir.tests.transformer import Transformer
59+
from functorch.experimental.control_flow import cond
5960

6061
kernel_mode = None # either aten mode or lean mode
6162
try:
@@ -245,15 +246,13 @@ def addloop(x, n):
245246
out = out + x
246247
return out
247248

248-
@control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
249249
def true_branch(c, x):
250250
return addloop(x, 3)
251251

252-
@control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
253252
def false_branch(c, x):
254253
return addloop(x, 4)
255254

256-
y = control_flow.cond(c, true_branch, false_branch, (c, x))
255+
y = cond(c, true_branch, false_branch, (c, x))
257256
return y * y
258257

259258
def get_random_inputs(self):
@@ -273,18 +272,14 @@ def addloop(x, n):
273272
out = out + x
274273
return out
275274

276-
@control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
277275
def true_branch(c, x):
278276
return addloop(x, 3)
279277

280-
@control_flow.tracing_context(inputs=(torch.randn(1), torch.randn(10)))
281278
def false_branch(c, x):
282279
return addloop(x, 4)
283280

284-
# pyre-fixme[6]: Incompatible parameter type
285-
y = control_flow.cond(c, true_branch, false_branch, (c, x))
281+
y = cond(c, true_branch, false_branch, (c, x))
286282

287-
# pyre-fixme[58]: Unsupported operand type for binary operator '*'
288283
return y * y
289284

290285
def get_random_inputs(self):
@@ -319,7 +314,7 @@ def true_branch(cnt):
319314
def false_branch(cnt):
320315
return torch.zeros([1], dtype=torch.long)
321316

322-
accum = accum + control_flow.cond(
317+
accum = accum + cond(
323318
torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
324319
)
325320
# 'cnt - 1' does not work yet since the runtime does not expect
@@ -372,9 +367,9 @@ def loop_body(accum, cnt):
372367
def false_branch(accum, cnt):
373368
return accum, cnt
374369

375-
return control_flow.cond(
376-
torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt)
377-
)[0]
370+
return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
371+
0
372+
]
378373

379374
def get_random_inputs(self):
380375
return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))

0 commit comments

Comments (0)