Commit 7924942

Allow addmm and mm to call dynamic fp32 kernels in XNNPACK
Differential Revision: D66898281
Pull Request resolved: #7232
1 parent: 3fcf0bd · commit: 7924942

1 file changed (+21, -0 lines)

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 21 additions & 0 deletions

@@ -337,6 +337,17 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        # TODO(maxren, T210537195):
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +447,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
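
Both hunks add the same override of _get_weight_deps: when the config is resolved to FP32 precision and force_fp32_dynamic_linear is set, it returns (True, []), so the weight node is not claimed as a partition dependency and addmm/mm can fall through to XNNPACK's dynamic fp32 fully-connected path. The sketch below shows how this flag might be exercised end to end; the TinyLinear module and the assumption that XnnpackPartitioner forwards extra keyword arguments such as force_fp32_dynamic_linear to its GEMM configs are illustrative, not taken from this commit.

```python
# Hypothetical usage sketch (not part of this commit): assumes XnnpackPartitioner
# forwards extra keyword arguments such as force_fp32_dynamic_linear to its
# GEMM partitioner configs.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge
from torch.export import export


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Lowers to an addmm (with bias) or mm (without bias) in the exported graph.
        self.fc = torch.nn.Linear(16, 32)

    def forward(self, x):
        return self.fc(x)


edge = to_edge(export(TinyLinear().eval(), (torch.randn(4, 16),)))

# With force_fp32_dynamic_linear=True, the new _get_weight_deps override returns
# (True, []) for fp32 addmm/mm, so the weight stays outside the delegated
# partition and the dynamic fp32 fully-connected kernel can be used at runtime.
edge = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))
```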
