
Commit 96aa07c

yf225 authored and xuhancn committed
[Inductor][PatternMatcher] Always prevent match across mutations (pytorch#130584)
Preventing match across mutations should always be the safe thing to do. This is especially important for Traceable FSDP2, because in that case we do have mutation ops (`.set_` and `.resize_(0)`) in the middle of the graph for both the joint graph and the post-grad graph, so the pattern-matcher passes must work well with middle-of-graph mutation ops.

Q: Why can't we move these mutation ops to the end of the graph to make pass writing easier?
A: We attempted that in pytorch#129852, but the custom FX passes (in `torch/_functorch/_aot_autograd/fx_passes.py`) for the re-functionalization are complicated to maintain, and the changes to the partitioner (in `torch/_functorch/partitioners.py`) also feel hacky. Hence we preserve these mutation ops in the middle of the graph to avoid that complexity.

Test commands:
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_uint4x2_mixed_mm`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_serialized_patterns_up_to_date`

Pull Request resolved: pytorch#130584
Approved by: https://github.com/jansel
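For callers of the pattern matcher, the visible change is the constructor: `prevent_match_across_mutations` is gone, and the mutation-region check in `PatternMatcherPass.apply` is applied unconditionally. A minimal sketch of what call sites look like after this commit (the pass name below is just an illustrative example):

```python
from torch._inductor.pattern_matcher import PatternMatcherPass

# Mutation-region checking now always happens; there is nothing to opt into.
my_pass = PatternMatcherPass(pass_name="my_illustrative_pass")

# The old spelling would now fail, because the keyword argument was removed:
#   PatternMatcherPass(prevent_match_across_mutations=True)  # TypeError
```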
1 parent 22ab428 commit 96aa07c

5 files changed (+78, -37 lines)

test/inductor/test_pattern_matcher.py

Lines changed: 60 additions & 6 deletions
@@ -16,6 +16,7 @@
     Arg,
     CallFunction,
     gen_pattern,
+    is_mutation_op,
     KeywordArg,
     Match,
     PatternMatcherPass,
@@ -1000,9 +1001,7 @@ def foo(x, y):

     def test_match_with_mutation(self):
         counter = 0
-        test_pass = PatternMatcherPass(
-            prevent_match_across_mutations=True, pass_name="test"
-        )
+        test_pass = PatternMatcherPass(pass_name="test")

         @register_graph_pattern(
             CallFunction(
@@ -1159,7 +1158,7 @@ def remap_fake_tensor(x):

     def test_match_equivalent_function_invocations1(self):
         counter = 0
-        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+        test_pass = PatternMatcherPass()

         args = [
             torch.randn(20, device="cuda"),
@@ -1215,7 +1214,7 @@ def repl(inp, x1, x2):

     def test_match_equivalent_function_invocations2(self):
         counter = 0
-        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+        test_pass = PatternMatcherPass()

         args = [
             torch.randn(20, device="cuda"),
@@ -1260,7 +1259,7 @@ def repl(inp, x1, x2):

     def test_match_equivalent_function_invocations3(self):
         counter = 0
-        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+        test_pass = PatternMatcherPass()

         args = [
             torch.randn(20, device="cuda"),
@@ -1371,6 +1370,61 @@ def div_softmax(x, inv_scale):
         self.common(mul_softmax, (scale, x), 0, 0)
         self.common(div_softmax, (x, scale), 0, 0)

+    def test_mutation_op_matching(self):
+        def check(type, func_name, args, kwargs, expect=True):
+            assert type in ["call_function", "call_method"]
+            graph = torch.fx.Graph()
+            getattr(graph, type)(func_name, args, kwargs)
+            res = is_mutation_op(next(iter(graph.nodes)))
+            if expect:
+                self.assertTrue(res)
+            else:
+                self.assertFalse(res)
+
+        t = torch.randn(1)
+        check("call_function", torch._C._set_grad_enabled, (False,), {})
+        check("call_method", "copy_", (t, t), {})
+        check("call_method", "relu_", (t,), {})
+        check("call_function", torch.manual_seed, (0,), {})
+        check("call_function", torch.ops.aten.set_.source_Tensor, (t, t), {})
+        check(
+            "call_function",
+            torch.amp.autocast_mode._enter_autocast,
+            ("cuda", None, True, None),
+            {},
+        )
+        check("call_function", torch.amp.autocast_mode._exit_autocast, (None,), {})
+        check(
+            "call_function",
+            torch.ops._c10d_functional.all_gather_into_tensor_out,
+            (t, 2, "0"),
+            {"out": t},
+        )
+        check("call_function", torch.ops.inductor.resize_storage_bytes_, (t, 0), {})
+        check(
+            "call_function",
+            torch.ops.inductor.resize_storage_bytes_.default,
+            (t, 0),
+            {},
+        )
+        check(
+            "call_function",
+            torch.ops.fsdp.split_with_sizes_copy,
+            (t, [64, 128, 8, 8]),
+            {"dim": 1, "out": [t, t, t, t]},
+        )
+        check("call_function", torch.ops.fsdp.set_, (t, t), {})
+        check(
+            "call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False
+        )
+        check(
+            "call_function",
+            torch.ops._c10d_functional.all_gather_into_tensor,
+            (t, 2, "0"),
+            {},
+            expect=False,
+        )
+

 if __name__ == "__main__":
     if IS_LINUX and HAS_CUDA:

torch/_inductor/fx_passes/b2b_gemm.py

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ def b2b_gemm_grid(M, P, meta):


 B2B_GEMM_PASS = PatternMatcherPass(
-    prevent_match_across_mutations=True,
     pass_name="b2b_gemm_pass",
 )


torch/_inductor/fx_passes/pre_grad.py

Lines changed: 9 additions & 16 deletions
@@ -33,35 +33,28 @@
 log = logging.getLogger(__name__)

 efficient_conv_bn_eval_pass = PatternMatcherPass(
-    prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
+    pass_name="efficient_conv_bn_eval_pass"
 )

 fuse_split_linear_add_pass = PatternMatcherPass(
-    prevent_match_across_mutations=True,
     pass_name="fuse_split_linear_add_pass",
 )
 fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
-    prevent_match_across_mutations=True,
     pass_name="fuse_chunk_squeeze_cat_pass",
 )
 remove_reshape_pass = PatternMatcherPass(
-    prevent_match_across_mutations=True,
     pass_name="remove_reshape_pass",
 )

 # based on predispatch aten IR
-normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-split_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-unbind_stack_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_getitem_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_stack_tahn_unbind_pass_aten = PatternMatcherPass(
-    prevent_match_across_mutations=True
-)
-mutate_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
-remove_split_with_size_one_pass_aten = PatternMatcherPass(
-    prevent_match_across_mutations=True
-)
+normalization_pass_aten = PatternMatcherPass()
+merge_splits_pass_aten = PatternMatcherPass()
+split_cat_pass_aten = PatternMatcherPass()
+unbind_stack_pass_aten = PatternMatcherPass()
+merge_getitem_cat_pass_aten = PatternMatcherPass()
+merge_stack_tahn_unbind_pass_aten = PatternMatcherPass()
+mutate_cat_pass_aten = PatternMatcherPass()
+remove_split_with_size_one_pass_aten = PatternMatcherPass()


 def save_inductor_dict(pass_to_compare=None):

torch/_inductor/fx_passes/split_cat.py

Lines changed: 0 additions & 2 deletions
@@ -67,7 +67,6 @@
     if pass_name in PRE_GRAD_FUSIONS:
         continue
     PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
-        prevent_match_across_mutations=True,
         pass_name=pass_name,
     )

@@ -77,7 +76,6 @@
     if pass_name in POST_GRAD_FUSIONS:
         continue
     POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
-        prevent_match_across_mutations=True,
        pass_name=pass_name,
     )
8381

torch/_inductor/pattern_matcher.py

Lines changed: 9 additions & 12 deletions
@@ -1600,8 +1600,9 @@ def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
     return node is next(iter(graph.nodes))


-# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc
-_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)")
+# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc
+# doesn't match: __rshift__, etc
+_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")


 def is_mutation_op(node: torch.fx.Node) -> bool:
@@ -1642,14 +1643,12 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
 class PatternMatcherPass:
     def __init__(
         self,
-        prevent_match_across_mutations: bool = False,
         pass_name: Optional[str] = None,
     ) -> None:
         super().__init__()
         self.patterns: DefaultDict[
             Tuple[str, torch.fx.node.Target], List[PatternEntry]
         ] = defaultdict(list)
-        self.prevent_match_across_mutations = prevent_match_across_mutations
         self.pass_name = pass_name

     def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
@@ -1667,12 +1666,11 @@ def apply(self, gm: torch.fx.GraphModule) -> int:
             raise RuntimeError(
                 f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
             )
-        if self.prevent_match_across_mutations:
-            if should_compute_mutation_region_ids(graph):
-                compute_mutation_region_ids(graph)
-            get_mutation_region_id_partial = functools.partial(
-                get_mutation_region_id, graph
-            )
+        if should_compute_mutation_region_ids(graph):
+            compute_mutation_region_ids(graph)
+        get_mutation_region_id_partial = functools.partial(
+            get_mutation_region_id, graph
+        )
         count = 0
         nodes = []
         has_call_module = False
@@ -1705,8 +1703,7 @@ def apply(self, gm: torch.fx.GraphModule) -> int:
                 m = entry.pattern.match(node)
                 # pattern match crosses mutation barrier - discard
                 if (
-                    self.prevent_match_across_mutations
-                    and is_match(m)
+                    is_match(m)
                     and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
                 ):
                     continue
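Besides making the check unconditional, the other substantive change here is the regex: wrapping the old pattern in a negative lookbehind `(?<!_)` and lookahead `(?!_)` means an underscore that is part of a double underscore no longer counts, so dunder operators such as `__rshift__` stop being classified as mutation ops while genuine in-place and state-mutating names still are. A minimal, self-contained sketch of how the new pattern behaves on a few illustrative name strings (the literals below are examples only; the real `is_mutation_op` derives the string it checks from the FX node's target):

```python
import re

# Same pattern as the new _mutation_op_re above.  The (?<!_) / (?!_) guards
# reject matches where the underscore belongs to a double underscore, so
# dunder names are not treated as mutations.
mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")

examples = [
    ("copy_", True),                    # trailing "_": in-place tensor op
    ("relu_", True),
    ("set_.source_Tensor", True),       # "_." and "set" both hit
    ("_set_grad_enabled", True),        # global-state mutation
    ("manual_seed", True),              # RNG-state mutation
    ("__rshift__", False),              # dunder operator, not a mutation
    ("all_gather_into_tensor", False),  # functional collective, no mutation
]

for name, expected in examples:
    assert bool(mutation_op_re.search(name)) == expected, name
```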
