@@ -207,9 +207,11 @@ def forward(self, x, w1, w2):
207
207
def replace_pattern_with_filters (
208
208
gm : GraphModule ,
209
209
pattern : Union [Callable , Graph , GraphModule ],
210
- replacement : Union [Callable , Graph , GraphModule ] ,
210
+ replacement : Union [Callable , Graph , GraphModule , None ] = None ,
211
211
match_filters : Optional [List [Callable [["InternalMatch" , Graph , Graph ], bool ]]] = None ,
212
212
ignore_literals : bool = False ,
213
+ # Placed at the end to avoid breaking backward compatibility
214
+ replacement_callback : Optional [Callable [["InternalMatch" , Graph , Graph ], Graph ]] = None ,
213
215
) -> List [ReplacedPatterns ]:
214
216
"""
215
217
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -219,17 +221,22 @@ def replace_pattern_with_filters(
219
221
(match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
220
222
whether the match satisfies the condition.
221
223
See matcher_utils.py for definition of InternalMatch.
224
+ ``replacement_callback``: A function that takes in a match and returns a
225
+ Graph to be used as the replacement. This allows you to construct a
226
+ replacement graph based on the match.
222
227
"""
223
228
224
- return _replace_pattern (gm , pattern , replacement , match_filters , ignore_literals )
229
+ return _replace_pattern (gm , pattern , replacement , match_filters , ignore_literals , replacement_callback )
225
230
226
231
227
232
def _replace_pattern (
228
233
gm : GraphModule ,
229
234
pattern : Union [Callable , Graph , GraphModule ],
230
- replacement : Union [Callable , Graph , GraphModule ] ,
235
+ replacement : Union [Callable , Graph , GraphModule , None ] = None ,
231
236
match_filters : Optional [List [Callable [["InternalMatch" , Graph , Graph ], bool ]]] = None ,
232
237
ignore_literals : bool = False ,
238
+ # Placed at the end to avoid breaking backward compatibility
239
+ replacement_callback : Optional [Callable [["InternalMatch" , Graph , Graph ], Graph ]] = None ,
233
240
) -> List [ReplacedPatterns ]:
234
241
235
242
from torch .fx .passes .utils .matcher_utils import SubgraphMatcher , InternalMatch
@@ -247,13 +254,6 @@ def _replace_pattern(
247
254
else :
248
255
pattern_graph = symbolic_trace (pattern ).graph
249
256
250
- if isinstance (replacement , GraphModule ):
251
- replacement_graph = replacement .graph
252
- elif isinstance (replacement , Graph ):
253
- replacement_graph = replacement
254
- else :
255
- replacement_graph = symbolic_trace (replacement ).graph
256
-
257
257
matcher = SubgraphMatcher (pattern_graph , match_output = False , match_placeholder = False ,
258
258
remove_overlapping_matches = True , ignore_literals = ignore_literals )
259
259
_matches : List [InternalMatch ] = matcher .match (original_graph )
@@ -265,13 +265,27 @@ def _replace_pattern(
265
265
for match_filter in match_filters )
266
266
]
267
267
268
- replacement_placeholders = [n for n in replacement_graph .nodes if n .op == "placeholder" ]
268
+ if isinstance (replacement , GraphModule ):
269
+ common_replacement_graph = replacement .graph
270
+ elif isinstance (replacement , Graph ):
271
+ common_replacement_graph = replacement
272
+ elif callable (replacement ):
273
+ common_replacement_graph = symbolic_trace (replacement ).graph
274
+ else :
275
+ assert replacement_callback is not None , "Must provide either a replacement GraphModule or a replacement callback"
276
+ common_replacement_graph = None
269
277
270
278
# As we progressively replace nodes, we'll need to keep track of how the match results should change
271
279
match_changed_node : Dict [Node , Node ] = {}
272
280
273
281
match_and_replacements = []
274
- for match in _matches :
282
+ for i , match in enumerate (_matches ):
283
+ if replacement_callback is not None :
284
+ replacement_graph = replacement_callback (match , original_graph , pattern_graph )
285
+ else :
286
+ assert common_replacement_graph is not None , "Must provide either a replacement GraphModule or a replacement callback"
287
+ replacement_graph = common_replacement_graph
288
+ replacement_placeholders = [n for n in replacement_graph .nodes if n .op == "placeholder" ]
275
289
276
290
# Build connecting between replacement graph's input and original graph input producer node
277
291
0 commit comments