Skip to content

Commit cdbfa9c

Browse files
zou3519facebook-github-bot
authored andcommitted
Update how Dynamo decides to graph break on an OpOverloadPacket (#112200)
Summary: Previously, under config.only_allow_pt2_compliant_ops, Dynamo graph breaks when it see an OpOverloadPacket where any overloads are not PT2 compliant. This is potentially brittle: if someone (unlikely) adds a new overload for a custom operator, then this would cause a previously non-graph-breaking call to the OpOverloadPacket to graph break. In this PR: - When Dynamo is about to write a call to an operator to the FX graph, we check if it is PT2 compliant. - For OpOverload, we check to see if the tag is on it - For OpOverloadPacket, we do overload resolution and check to see if the tag is on the OpOverload that it resolves to. X-link: pytorch/pytorch#112200 Approved by: https://github.com/bdhirsh Reviewed By: ZainRizvi Differential Revision: D50873052 Pulled By: zou3519 fbshipit-source-id: c95b9797185f4bfc9edd2eb2e48f73b8dd56154f
1 parent 9221c79 commit cdbfa9c

File tree

1 file changed

+19
-11
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+19
-11
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,21 @@ def extract_fake_example_value(node, required=True):
13871387
return None
13881388

13891389

1390+
def ensure_graph_fake(e, tx):
1391+
assert maybe_get_fake_mode(e) is tx.fake_mode
1392+
return e
1393+
1394+
1395+
def get_fake_values_from_nodes(tx, nodes):
1396+
def visit(n: torch.fx.Node):
1397+
return n.meta["example_value"]
1398+
1399+
args_kwargs = torch.fx.node.map_arg(nodes, visit)
1400+
return tree_map_only(
1401+
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), args_kwargs
1402+
)
1403+
1404+
13901405
def get_fake_value(node, tx, allow_non_graph_fake=False):
13911406
"""
13921407
Run the computation represented by `node` using fake tensors and return the result.
@@ -1410,16 +1425,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
14101425
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
14111426
return node.meta["example_value"]
14121427

1413-
def ensure_graph_fake(e):
1414-
assert maybe_get_fake_mode(e) is tx.fake_mode
1415-
return e
1416-
1417-
def visit(n: torch.fx.Node):
1418-
return n.meta["example_value"]
1419-
1420-
args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
1421-
args = tree_map_only(torch.Tensor, ensure_graph_fake, args)
1422-
kwargs = tree_map_only(torch.Tensor, ensure_graph_fake, kwargs)
1428+
args, kwargs = get_fake_values_from_nodes(tx, (node.args, node.kwargs))
14231429

14241430
nnmodule = None
14251431
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
@@ -1483,7 +1489,9 @@ def visit(n: torch.fx.Node):
14831489
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
14841490

14851491
if not allow_non_graph_fake:
1486-
_ = tree_map_only(torch.Tensor, ensure_graph_fake, ret_val)
1492+
_ = tree_map_only(
1493+
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
1494+
)
14871495
return ret_val
14881496

14891497

0 commit comments

Comments
 (0)