@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1795
1795
unsigned maxNumElementsToExtract = 0 ;
1796
1796
};
1797
1797
1798
+ // / Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
1799
+ // / into vector.outerproduct(A, B) such as :
1800
+ // / ```mlir
1801
+ // / %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1802
+ // / %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
1803
+ // / vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
1804
+ // / vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1805
+ // /```
1806
+ // / Becomes :
1807
+ // /```mlir
1808
+ // / %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1809
+ // /```
1810
+ // / Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
1811
+ // / %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1812
+ // / %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1813
+ // / %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1814
+
1815
+ template <typename MulOpType>
1816
+ struct ElementwiseToOuterproduct : public OpRewritePattern <MulOpType> {
1817
+ using OpRewritePattern<MulOpType>::OpRewritePattern;
1818
+
1819
+ LogicalResult matchAndRewrite (MulOpType mulOp,
1820
+ PatternRewriter &rewriter) const override {
1821
+ auto VT = llvm::cast<VectorType>(mulOp.getResult ().getType ());
1822
+ if (!VT)
1823
+ return failure ();
1824
+ if (VT.getRank () != 2 )
1825
+ return failure ();
1826
+
1827
+ auto canonicalize = [&](Value OperandA,
1828
+ Value OperandB) -> vector::OuterProductOp {
1829
+ vector::TransposeOp transposedLhs =
1830
+ dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp ());
1831
+ if (!transposedLhs)
1832
+ return vector::OuterProductOp ();
1833
+ // Fail unless this is a true 2-D matrix transpose.
1834
+ ArrayRef<int64_t > permutation = transposedLhs.getPermutation ();
1835
+ if (permutation[0 ] != 1 || permutation[1 ] != 0 )
1836
+ return vector::OuterProductOp ();
1837
+
1838
+ // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
1839
+ // generating shape_casts/broadcasts which do not belong in this pattern.
1840
+ vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
1841
+ transposedLhs.getVector ().getDefiningOp ());
1842
+ if (!broadcastedLhs ||
1843
+ !broadcastedLhs.computeBroadcastedUnitDims ().empty ())
1844
+ return vector::OuterProductOp ();
1845
+ // Avoid broadcast f32 or vector<f32> -> ResType
1846
+ auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType ());
1847
+ if (!srcVT || srcVT.getRank () != 1 )
1848
+ return vector::OuterProductOp ();
1849
+
1850
+ vector::BroadcastOp broadcastedRhs =
1851
+ dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp ());
1852
+ if (!broadcastedRhs || broadcastedRhs.getSourceType () != srcVT)
1853
+ return vector::OuterProductOp ();
1854
+
1855
+ return rewriter.replaceOpWithNewOp <vector::OuterProductOp>(
1856
+ mulOp, VT, broadcastedLhs.getSource (), broadcastedRhs.getSource (),
1857
+ Value (), vector::CombiningKind::ADD);
1858
+ };
1859
+ Value a = mulOp->getOperand (0 ), b = mulOp->getOperand (1 );
1860
+ vector::OuterProductOp outerP = canonicalize (a, b);
1861
+ // Handle commutativity, the transposed op is the outerproduct LHS.
1862
+ outerP = outerP ? outerP : canonicalize (b, a);
1863
+ return outerP ? success () : failure ();
1864
+ }
1865
+ };
1866
+
1798
1867
} // namespace
1799
1868
1800
1869
void mlir::vector::populateFoldArithExtensionPatterns (
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
1882
1951
maxNumElementsToExtract, benefit);
1883
1952
}
1884
1953
1954
+ void mlir::vector::populateElementwiseToVectorOpsPatterns (
1955
+ RewritePatternSet &patterns) {
1956
+ patterns.add <ElementwiseToOuterproduct<arith::MulFOp>,
1957
+ ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext ());
1958
+ }
1959
+
1885
1960
// ===----------------------------------------------------------------------===//
1886
1961
// TableGen'd enum attribute definitions
1887
1962
// ===----------------------------------------------------------------------===//
0 commit comments