@@ -1813,6 +1813,84 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1813
1813
unsigned maxNumElementsToExtract = 0 ;
1814
1814
};
1815
1815
1816
+ // / Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
1817
+ // / B)`.
1818
+ // / Example:
1819
+ // / %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1820
+ // / %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
1821
+ // / vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
1822
+ // / vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1823
+ // /
1824
+ // / Becomes :
1825
+ // /
1826
+ // / %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1827
+ // /
1828
+ // / Supports only 1D-to-2D broadcasts. The following cases are not supported.
1829
+ // / %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1830
+ // / %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1831
+ // / %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1832
+ template <typename MulOpType>
1833
+ struct FoldArithToVectorOuterProduct : public OpRewritePattern <MulOpType> {
1834
+ using OpRewritePattern<MulOpType>::OpRewritePattern;
1835
+ // Returns whether a vector.broadcast matches requirements for an outerproduct
1836
+ // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
1837
+ bool isValidBroadcastSource (vector::BroadcastOp broadcastOp) const {
1838
+ // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
1839
+ // shape_casts/broadcasts which does not belong in this pattern.
1840
+ if (!broadcastOp.computeBroadcastedUnitDims ().empty ())
1841
+ return false ;
1842
+ // Avoid broadcast like f32 or vector<f32> -> ResType
1843
+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1844
+ return srcType && srcType.getRank () != 2 ;
1845
+ }
1846
+
1847
+ LogicalResult matchAndRewrite (MulOpType mulOp,
1848
+ PatternRewriter &rewriter) const override {
1849
+ auto resType = llvm::cast<VectorType>(mulOp.getResult ().getType ());
1850
+ if (!resType)
1851
+ return failure ();
1852
+ if (resType.getRank () != 2 )
1853
+ return failure ();
1854
+ // / If operandA can be written as tr(broadcast(A)) and operandB as
1855
+ // / broadcast(B) where broadcasts are 1D-to-2D, create and return
1856
+ // / vector.outerproduct(A, B). Returns failure() otherwise.
1857
+ auto matchOuterProduct =
1858
+ [&](Value operandA,
1859
+ Value operandB) -> FailureOr<vector::OuterProductOp> {
1860
+ auto transposedLhs = operandA.getDefiningOp <vector::TransposeOp>();
1861
+ if (!transposedLhs)
1862
+ return failure ();
1863
+ // Fail unless this is a true 2-D matrix transpose.
1864
+ ArrayRef<int64_t > permutation = transposedLhs.getPermutation ();
1865
+ if (permutation.size () != 2 || permutation[0 ] != 1 || permutation[1 ] != 0 )
1866
+ return failure ();
1867
+
1868
+ auto broadcastedLhs =
1869
+ transposedLhs.getVector ().getDefiningOp <vector::BroadcastOp>();
1870
+ if (!broadcastedLhs || !isValidBroadcastSource (broadcastedLhs))
1871
+ return failure ();
1872
+
1873
+ auto broadcastedRhs = operandB.getDefiningOp <vector::BroadcastOp>();
1874
+ if (!broadcastedRhs || !isValidBroadcastSource (broadcastedRhs))
1875
+ return failure ();
1876
+
1877
+ return rewriter.create <vector::OuterProductOp>(
1878
+ mulOp->getLoc (), resType, broadcastedLhs.getSource (),
1879
+ broadcastedRhs.getSource (), Value (), vector::CombiningKind::ADD);
1880
+ };
1881
+
1882
+ Value lhs = mulOp->getOperand (0 ), rhs = mulOp->getOperand (1 );
1883
+ auto maybeOuterP = matchOuterProduct (lhs, rhs);
1884
+ // Handle commutativity, the transposed op is the outerproduct LHS.
1885
+ if (failed (maybeOuterP))
1886
+ maybeOuterP = matchOuterProduct (rhs, lhs);
1887
+ if (failed (maybeOuterP))
1888
+ return failure ();
1889
+ rewriter.replaceOp (mulOp, maybeOuterP->getResult ());
1890
+ return success ();
1891
+ }
1892
+ };
1893
+
1816
1894
} // namespace
1817
1895
1818
1896
void mlir::vector::populateFoldArithExtensionPatterns (
@@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
1900
1978
maxNumElementsToExtract, benefit);
1901
1979
}
1902
1980
1981
+ void mlir::vector::populateElementwiseToVectorOpsPatterns (
1982
+ RewritePatternSet &patterns) {
1983
+ patterns.add <FoldArithToVectorOuterProduct<arith::MulFOp>,
1984
+ FoldArithToVectorOuterProduct<arith::MulIOp>>(
1985
+ patterns.getContext ());
1986
+ }
1987
+
1903
1988
// ===----------------------------------------------------------------------===//
1904
1989
// TableGen'd enum attribute definitions
1905
1990
// ===----------------------------------------------------------------------===//
0 commit comments