Skip to content

Commit cd472bb

Browse files
dulinrileypytorchmergebot
authored andcommitted
[torch][fx] Add new replacement_callback to materialize a replacement just in time (pytorch#135553)
Summary: Sometimes we only want to generate a replacement for a matched pattern once we know some information about the nodes in the pattern. So far, we have found this the most useful to do matches based on specific shapes of tensors flowing into functions. Use a callback function similar to `match_filters`. By default this isn't used. Had to make `replacement` a None-able parameter because Callable was already used to detect a case where a graph needed to be traced. Differential Revision: D62412628 Pull Request resolved: pytorch#135553 Approved by: https://github.com/SherlockNoMad
1 parent f032135 commit cd472bb

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

test/fx/test_subgraph_rewriter.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,3 +980,38 @@ def check_replacement_nodes(self, traced, matches):
980980
return len(replacement_nodes_in_graph)
981981

982982
self.assertEqual(check_replacement_nodes(self, traced, matches), 2)
983+
984+
def test_replace_pattern_with_callback(self) -> None:
985+
class M(torch.nn.Module):
986+
def forward(self, x, y):
987+
return torch.add(x, y)
988+
989+
def pattern(x, y):
990+
return torch.add(x, y)
991+
992+
def replacement(x, y):
993+
return torch.sub(torch.mul(x, y), y)
994+
995+
traced = symbolic_trace(M())
996+
# Return the same replacement graph for all matches, but have it be a unique
997+
# object each time.
998+
matches = subgraph_rewriter.replace_pattern_with_filters(
999+
traced,
1000+
pattern,
1001+
replacement_callback=lambda *args: symbolic_trace(replacement).graph,
1002+
)
1003+
1004+
def check_replacement_nodes(self, traced, matches):
1005+
replacement_nodes_in_graph = [
1006+
node
1007+
for node in traced.graph.nodes
1008+
if node.target in {torch.sub, torch.mul}
1009+
]
1010+
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
1011+
self.assertEqual(
1012+
len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
1013+
)
1014+
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
1015+
return len(replacement_nodes_in_graph)
1016+
1017+
self.assertEqual(check_replacement_nodes(self, traced, matches), 2)

torch/fx/subgraph_rewriter.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,11 @@ def forward(self, x, w1, w2):
207207
def replace_pattern_with_filters(
208208
gm: GraphModule,
209209
pattern: Union[Callable, Graph, GraphModule],
210-
replacement: Union[Callable, Graph, GraphModule],
210+
replacement: Union[Callable, Graph, GraphModule, None] = None,
211211
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
212212
ignore_literals: bool = False,
213+
# Placed at the end to avoid breaking backward compatibility
214+
replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
213215
) -> List[ReplacedPatterns]:
214216
"""
215217
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -219,17 +221,22 @@ def replace_pattern_with_filters(
219221
(match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
220222
whether the match satisfies the condition.
221223
See matcher_utils.py for definition of InternalMatch.
224+
``replacement_callback``: A function that takes in a match and returns a
225+
Graph to be used as the replacement. This allows you to construct a
226+
replacement graph based on the match.
222227
"""
223228

224-
return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals)
229+
return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback)
225230

226231

227232
def _replace_pattern(
228233
gm: GraphModule,
229234
pattern: Union[Callable, Graph, GraphModule],
230-
replacement: Union[Callable, Graph, GraphModule],
235+
replacement: Union[Callable, Graph, GraphModule, None] = None,
231236
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
232237
ignore_literals: bool = False,
238+
# Placed at the end to avoid breaking backward compatibility
239+
replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
233240
) -> List[ReplacedPatterns]:
234241

235242
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
@@ -247,13 +254,6 @@ def _replace_pattern(
247254
else:
248255
pattern_graph = symbolic_trace(pattern).graph
249256

250-
if isinstance(replacement, GraphModule):
251-
replacement_graph = replacement.graph
252-
elif isinstance(replacement, Graph):
253-
replacement_graph = replacement
254-
else:
255-
replacement_graph = symbolic_trace(replacement).graph
256-
257257
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
258258
remove_overlapping_matches=True, ignore_literals=ignore_literals)
259259
_matches: List[InternalMatch] = matcher.match(original_graph)
@@ -265,13 +265,27 @@ def _replace_pattern(
265265
for match_filter in match_filters)
266266
]
267267

268-
replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
268+
if isinstance(replacement, GraphModule):
269+
common_replacement_graph = replacement.graph
270+
elif isinstance(replacement, Graph):
271+
common_replacement_graph = replacement
272+
elif callable(replacement):
273+
common_replacement_graph = symbolic_trace(replacement).graph
274+
else:
275+
assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback"
276+
common_replacement_graph = None
269277

270278
# As we progressively replace nodes, we'll need to keep track of how the match results should change
271279
match_changed_node: Dict[Node, Node] = {}
272280

273281
match_and_replacements = []
274-
for match in _matches:
282+
for i, match in enumerate(_matches):
283+
if replacement_callback is not None:
284+
replacement_graph = replacement_callback(match, original_graph, pattern_graph)
285+
else:
286+
assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback"
287+
replacement_graph = common_replacement_graph
288+
replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
275289

276290
# Build connecting between replacement graph's input and original graph input producer node
277291

0 commit comments

Comments
 (0)