@@ -337,6 +337,17 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        # TODO(maxren, T210537195):
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +447,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
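
For context, here is a minimal usage sketch of the flag these overrides consult. It assumes `XnnpackPartitioner` forwards extra keyword arguments such as `force_fp32_dynamic_linear` down to its gemm configs (the flag name comes from the diff above; the export and lowering calls are the standard ExecuTorch flow, and `LinearModule` is a hypothetical toy model used only for illustration):

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class LinearModule(torch.nn.Module):
    """Hypothetical toy model with a single fp32 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


model = LinearModule().eval()
example_inputs = (torch.randn(1, 32),)

# With force_fp32_dynamic_linear=True, the overridden _get_weight_deps above
# returns (True, []) for fp32 linears, so the weight node is not pulled into
# the XNNPACK partition as a static constant.
exported = torch.export.export(model, example_inputs)
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
executorch_program = edge.to_executorch()
```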