fix: Bugfix in Linear-to-AddMM Fusion lowering pass
- Fix 2 bugs in linear-to-addmm lowering pass:
  - Lowering pass did not explore nested sub-blocks of a node, such as
    those contained in `prim::If` when `bias=None`
  - Lowering pass did not insert the fused linear code inside sub-blocks
    of `prim::If`, even when the original function call occurred within
    such a block
    - The latter causes issues when control flow switches between two
      versions of `aten::linear`, only one of which is a valid operation.
      Evaluating both branches can then crash compilation, as invalid
      Tensor shapes can be encountered
- Update implementation to run recursively through all nested blocks
within all nodes
- Update implementation to stop using the `RegisterRewritePattern`
paradigm for Tensor biases, as the rewrite does not always place the
subgraph in the desired location
- Add regression test cases to isolate both bugs
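
The recursive traversal described above can be sketched on a toy IR. This is a minimal illustration, not the actual pass: `Node`, `Block`, and `lower_linear` are hypothetical names, and the replacement node kinds (`aten::t`, `aten::addmm`) are placeholders rather than the exact sequence the pass emits. The point is that the rewrite descends into every sub-block of every node and places the replacement in the same block as the original call, so code inside a `prim::If` branch stays inside that branch.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical stand-in for a JIT graph node: an op kind plus nested blocks.
    kind: str
    blocks: List["Block"] = field(default_factory=list)


@dataclass
class Block:
    # Hypothetical stand-in for a JIT block: an ordered list of nodes.
    nodes: List[Node] = field(default_factory=list)


def lower_linear(block: Block) -> int:
    """Rewrite every `aten::linear` in `block` and all nested sub-blocks.

    The replacement nodes are inserted into the block that held the
    original node. Returns the number of nodes rewritten.
    """
    rewritten = 0
    new_nodes: List[Node] = []
    for node in block.nodes:
        # Recurse first so that sub-blocks (e.g. both arms of a prim::If)
        # are visited no matter what kind of node owns them.
        for sub_block in node.blocks:
            rewritten += lower_linear(sub_block)
        if node.kind == "aten::linear":
            # Placeholder replacement sequence (an assumption for illustration).
            new_nodes.extend([Node("aten::t"), Node("aten::addmm")])
            rewritten += 1
        else:
            new_nodes.append(node)
    block.nodes = new_nodes
    return rewritten


# A graph whose only aten::linear call sits inside one branch of a prim::If.
graph = Block([
    Node("prim::If", blocks=[
        Block([Node("aten::linear")]),
        Block([Node("aten::matmul")]),
    ]),
    Node("aten::relu"),
])
count = lower_linear(graph)
```

After the call, the top-level block still contains only `prim::If` and `aten::relu`; the replacement nodes live in the first branch, and the second branch is untouched.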