     format_target_name,
 )
 from torch.export import ExportedProgram
+from torch.fx.passes.utils.source_matcher_utils import (
+    get_source_partitions,
+    SourcePartition,
+)
 
 
 class GEMMConfig(XNNPartitionerConfig):
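The newly imported `get_source_partitions` groups the nodes that a module or functional call decomposed into back under their original source. A minimal sketch of its return shape, assuming an already-exported program `ep` (variable names here are illustrative, not part of this diff):

```python
# get_source_partitions maps each wanted source (e.g. torch.nn.Linear) to a
# list of SourcePartition dataclasses found in the graph. Each SourcePartition
# records the decomposed nodes plus the partition's input/output boundary nodes.
partitions = get_source_partitions(ep.graph, [torch.nn.functional.linear, torch.nn.Linear])
for source, partition_list in partitions.items():
    for p in partition_list:  # p is a SourcePartition
        print(source, p.nodes, p.input_nodes, p.output_nodes)
```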
@@ -52,20 +56,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             # short circuit if we don't pass common constraints
             return False
 
-        precision = self._detect_precision(node)
-        if precision not in self.enabled_precision_types:
-            # detected precision but it is either disabled or not supported
-            return False
-
-        is_valid, _ = self.get_deps(node, ep, precision)
+        is_valid, _ = self.get_deps(node, ep)
         return is_valid
 
     def get_node_and_deps(
         self, node: torch.fx.Node, ep: ExportedProgram
     ) -> List[torch.fx.Node]:
         partition = [node]
-        precision = self._detect_precision(node)
-        _, deps = self.get_deps(node, ep, precision)
+        _, deps = self.get_deps(node, ep)
         partition.extend(deps)
 
         return partition
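Both `check_constraints` and `get_node_and_deps` now funnel through `get_deps`, so a subclass can change partitioning behavior by overriding that one method, which is exactly what `AddmmConfig.get_deps` does later in this diff. A condensed, illustrative view of the two call sites (bodies elided to the calls that matter):

```python
class GEMMConfig(XNNPartitionerConfig):
    def check_constraints(self, node, ep) -> bool:
        # common-constraint checks elided
        is_valid, _ = self.get_deps(node, ep)  # single validation path
        return is_valid

    def get_node_and_deps(self, node, ep):
        _, deps = self.get_deps(node, ep)  # same path for gathering deps
        return [node, *deps]
```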
@@ -86,13 +84,20 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
         return ConfigPrecisionType.STATIC_QUANT
 
     def get_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+        self,
+        node: torch.fx.Node,
+        ep: ExportedProgram,
     ) -> Tuple[bool, List[torch.fx.Node]]:
         """
         Gets all dependencies for this gemm partition. Returns a tuple of
         a bool indicating if the deps are valid and a list of all the
         dep nodes
         """
+        precision = self._detect_precision(node)
+        if precision not in self.supported_precision_types():
+            # detected precision but it is either disabled or not supported
+            return (False, [])
+
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
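`get_deps` now front-loads the precision gate that used to live in `check_constraints`. For context, `_detect_precision` (only its tail is visible in this hunk) classifies a node by its surrounding quantization pattern; the following is an assumed sketch of that kind of check, with `is_dequant` and `is_dynamic_qdq` as hypothetical helpers standing in for the backend's actual pattern checks:

```python
def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
    weight = get_input_node(node, self.weight_idx)
    if not is_dequant(weight):           # hypothetical helper
        return ConfigPrecisionType.FP32  # no q/dq pattern around the gemm
    act = get_input_node(node, self.act_idx)
    if is_dynamic_qdq(act):              # hypothetical helper
        return ConfigPrecisionType.DYNAMIC_QUANT
    return ConfigPrecisionType.STATIC_QUANT  # matches the context line above
```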
@@ -178,7 +183,7 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if len(node.all_input_nodes) > 2:
+        if len(node.all_input_nodes) > 2 and self.bias_idx:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
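Note that the new guard tests truthiness rather than `is not None`: `MMConfig` below sets `bias_idx = None` and `AddmmConfig` uses `bias_idx = 0`, so neither takes this branch; for addmm the bias dependency is presumably picked up through the source-partition union in `get_deps_from_src_partition` instead. A tiny illustration of the truthiness behavior:

```python
for bias_idx in (None, 0, 2):
    print(bias_idx, bool(bias_idx))  # None -> False, 0 -> False, 2 -> True
```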
@@ -251,11 +256,16 @@ def supported_precision_types(self):
         ]
 
 
-class AddmmConfig(GEMMConfig):
-    target_name = "addmm.default"
+class ConvolutionConfig(GEMMConfig):
+    target_name = "convolution.default"
 
     def __init__(self):
-        super().__init__(weight_idx=2, bias_idx=0, act_idx=1, fused_acts=[])
+        super().__init__(
+            weight_idx=1,
+            bias_idx=2,
+            act_idx=0,
+            fused_acts=["relu.default", "hardtanh.default"],
+        )
 
     def supported_precision_types(self):
         return [
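The swapped indices line up with the ATen schema for convolution, where the activation comes first: `convolution(Tensor input, Tensor weight, Tensor? bias, ...)`. A sketch of how the config's indices map onto a convolution node (`conv_node` is an illustrative stand-in):

```python
# aten.convolution(input, weight, bias, stride, padding, dilation,
#                   ^0     ^1      ^2   transposed, output_padding, groups)
act = get_input_node(conv_node, 0)     # act_idx = 0
weight = get_input_node(conv_node, 1)  # weight_idx = 1
bias = get_input_node(conv_node, 2)    # bias_idx = 2
```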
@@ -264,19 +274,126 @@ def supported_precision_types(self):
         ]
 
 
-class ConvolutionConfig(GEMMConfig):
-    target_name = "convolution.default"
+class AddmmConfig(GEMMConfig):
+    """
+    Handles the legacy form of addmm partitioning, which includes
+    partitioning based on source partitions.
+    """
+
+    target_name = "addmm.default"
 
     def __init__(self):
         super().__init__(
-            weight_idx=1,
-            bias_idx=2,
-            act_idx=0,
+            weight_idx=2,
+            bias_idx=0,
+            act_idx=1,
             fused_acts=["relu.default", "hardtanh.default"],
         )
+        self.src_partitions = None
+        self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
+
+    def get_deps(
+        self,
+        node: torch.fx.Node,
+        ep: ExportedProgram,
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        """
+        Gets all dependencies for this gemm partition. Returns a tuple of
+        a bool indicating if the deps are valid and a list of all the
+        dep nodes. This also handles the case where the addmm node belongs
+        to a linear source partition.
+        """
+        if self.src_partitions is None:
+            # Cache src partitions so we don't have to recompute them every time
+            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
+
+        # src_partition is None if node is not in source partition,
+        # otherwise gives us the linear source partition it belongs to
+        src_partition = None
+        for partition_list in self.src_partitions.values():
+            for partition in partition_list:
+                if node in partition.nodes:
+                    src_partition = partition
+
+        if src_partition:
+            # if addmm belongs to linear src partition, then partition the
+            # src partition and get its deps
+            return self.get_deps_from_src_partition(node, ep, src_partition)
+
+        return super().get_deps(node, ep)
+
+    def get_deps_from_src_partition(
+        self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
+    ):
+        """
+        Gets all the dependencies for the src partition. This is done by
+        simulating the linear node from the src partition. We find the
+        associated weights, act, bias from the linear src partition, and plug
+        those in as the addmm node's args. We also take the users of the src
+        partition's output node as the addmm node's users. Finally we run
+        GEMMConfig's get_deps method on this faked linear node. After getting
+        the deps, we restore the addmm node's original args and users.
+        """
+
+        def find_partition_args(input_node):
+            while (
+                len(input_node.all_input_nodes) != 0
+                and input_node not in src_partition.input_nodes
+            ):
+                input_node = input_node.all_input_nodes[0]
+            return input_node
+
+        old_args, old_users = node.args, node.users
+
+        fake_args = []
+        for arg in node.args:
+            # map addmm's args to the source partition's inputs,
+            # basically simulating what the args of the linear node would be
+            fake_args.append(find_partition_args(arg))
+
+        # validate source partition
+        if (
+            # bias must be in source partition
+            (self.bias_idx and fake_args[self.bias_idx] not in src_partition.nodes)
+            # activation input must be an input node to partition
+            or fake_args[self.act_idx] not in src_partition.input_nodes
+            # weight can either be in the nodes or input_nodes
+            or fake_args[self.weight_idx]
+            not in (src_partition.nodes + src_partition.input_nodes)
+            # there can only be a single output node in partition
+            or len(src_partition.output_nodes) != 1
+        ):
+            return (False, [])
+
+        # map addmm's args to the source partition linear's inputs and users
+        node.args = tuple(fake_args)
+        node.users = src_partition.output_nodes[0].users
+        valid_deps, deps = super().get_deps(node, ep)
+
+        # Reset addmm node back to old args and users
+        node.args = old_args
+        node.users = old_users
+
+        return valid_deps, list(set(deps) | set(src_partition.nodes))
 
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
             ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
+        ]
+
+
+class MMConfig(AddmmConfig):
+    target_name = "mm.default"
+
+    def __init__(self):
+        super().__init__()
+        self.bias_idx = None
+        self.weight_idx = 1
+        self.act_idx = 0
+
+    def supported_precision_types(self):
+        return [
+            ConfigPrecisionType.FP32,
+            ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
         ]
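For reference, the index choices track the ATen schemas as well: `addmm(Tensor self, Tensor mat1, Tensor mat2)` takes the bias first (`bias_idx=0, act_idx=1, weight_idx=2`), while `mm(Tensor self, Tensor mat2)` has no bias argument at all, which is why `MMConfig` sets `bias_idx = None`. A minimal sketch of where these nodes come from, assuming a current `torch.export` API (module and variable names are illustrative):

```python
import torch
from torch.export import export

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 8)                 # bias -> lowers to addmm
        self.fc_nb = torch.nn.Linear(8, 2, bias=False)  # no bias -> lowers to mm

    def forward(self, x):
        return self.fc_nb(self.fc(x))

# After decomposition the graph contains aten.addmm / aten.mm nodes whose
# source metadata still points at torch.nn.Linear, which is what
# get_source_partitions keys on when AddmmConfig / MMConfig look up the
# surrounding linear partition.
ep = export(TinyLinear(), (torch.randn(2, 4),)).run_decompositions()
print([n.target for n in ep.graph.nodes if n.op == "call_function"])
```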