Skip to content

Commit aadcd64

Browse files
nujaahanhanW
authored andcommitted
[mlir][vector] Add ElementwiseToOuterproduct (llvm#93664)
1D multi-reductions are lowered to arith, which can prevent some optimisations. I propose `ElementwiseToOuterproduct`, which matches a series of ops to generate `vector.outerproduct`. As part of the `ElementwiseToVectorOpsPatterns`, it could also allow other elementwise ops to be folded into the vector dialect. Originally discussed at https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24. Quoting @MacDue:
```
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```
Can be rewritten as:
```
%mul = vector.outerproduct %lhsCast, %rhs : vector<[4]xf32>, vector<[4]xf32>
```
---------
Co-authored-by: Han-Chung Wang <[email protected]>
1 parent 6484c44 commit aadcd64

File tree

5 files changed

+143
-0
lines changed

5 files changed

+143
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
8080
/// into vector contract for the backends with native support.
8181
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
8282

83+
/// Populates `patterns` with rewrites that fold elementwise ops on vectors
84+
/// into vector dialect ops.
85+
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
86+
8387
/// Returns the integer type required for subscripts in the vector dialect.
8488
IntegerType getVectorSubscriptType(Builder &builder);
8589

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
406406
let assemblyFormat = "attr-dict";
407407
}
408408

409+
// Transform dialect op with mnemonic
// `apply_patterns.vector.elementwise_to_vector`. It declares the
// `PatternDescriptorOpInterface` methods and its assembly format consists of
// the attribute dictionary only.
def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
410+
"apply_patterns.vector.elementwise_to_vector",
411+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
412+
let description = [{
413+
Collect a set of patterns that fold elementwise op on vectors to the vector
414+
dialect.
415+
}];
416+
417+
let assemblyFormat = "attr-dict";
418+
}
419+
409420
def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
410421
"apply_patterns.vector.reduction_to_contract",
411422
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
5959
vector::populateFoldArithExtensionPatterns(patterns);
6060
}
6161

62+
/// Fills `patterns` for this transform op by forwarding to
/// `vector::populateElementwiseToVectorOpsPatterns`.
void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
63+
RewritePatternSet &patterns) {
64+
vector::populateElementwiseToVectorOpsPatterns(patterns);
65+
}
66+
6267
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
6368
RewritePatternSet &patterns) {
6469
vector::populateVectorReductionToContractPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,6 +1813,84 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
18131813
unsigned maxNumElementsToExtract = 0;
18141814
};
18151815

1816+
/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
1817+
/// B)`.
1818+
/// Example:
1819+
/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1820+
/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
1821+
/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
1822+
/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1823+
///
1824+
/// Becomes :
1825+
///
1826+
/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1827+
///
1828+
/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
1829+
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1830+
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1831+
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1832+
template <typename MulOpType>
1833+
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
1834+
using OpRewritePattern<MulOpType>::OpRewritePattern;
1835+
// Returns whether a vector.broadcast matches requirements for an outerproduct
1836+
// pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
1837+
bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
1838+
// Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
1839+
// shape_casts/broadcasts which does not belong in this pattern.
1840+
if (!broadcastOp.computeBroadcastedUnitDims().empty())
1841+
return false;
1842+
// Avoid broadcast like f32 or vector<f32> -> ResType
1843+
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1844+
return srcType && srcType.getRank() != 2;
1845+
}
1846+
1847+
LogicalResult matchAndRewrite(MulOpType mulOp,
1848+
PatternRewriter &rewriter) const override {
1849+
auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
1850+
if (!resType)
1851+
return failure();
1852+
if (resType.getRank() != 2)
1853+
return failure();
1854+
/// If operandA can be written as tr(broadcast(A)) and operandB as
1855+
/// broadcast(B) where broadcasts are 1D-to-2D, create and return
1856+
/// vector.outerproduct(A, B). Returns failure() otherwise.
1857+
auto matchOuterProduct =
1858+
[&](Value operandA,
1859+
Value operandB) -> FailureOr<vector::OuterProductOp> {
1860+
auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
1861+
if (!transposedLhs)
1862+
return failure();
1863+
// Fail unless this is a true 2-D matrix transpose.
1864+
ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
1865+
if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
1866+
return failure();
1867+
1868+
auto broadcastedLhs =
1869+
transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
1870+
if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
1871+
return failure();
1872+
1873+
auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
1874+
if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
1875+
return failure();
1876+
1877+
return rewriter.create<vector::OuterProductOp>(
1878+
mulOp->getLoc(), resType, broadcastedLhs.getSource(),
1879+
broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
1880+
};
1881+
1882+
Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
1883+
auto maybeOuterP = matchOuterProduct(lhs, rhs);
1884+
// Handle commutativity, the transposed op is the outerproduct LHS.
1885+
if (failed(maybeOuterP))
1886+
maybeOuterP = matchOuterProduct(rhs, lhs);
1887+
if (failed(maybeOuterP))
1888+
return failure();
1889+
rewriter.replaceOp(mulOp, maybeOuterP->getResult());
1890+
return success();
1891+
}
1892+
};
1893+
18161894
} // namespace
18171895

18181896
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
19001978
maxNumElementsToExtract, benefit);
19011979
}
19021980

1981+
/// Adds the elementwise-to-vector folding patterns to `patterns`: the
/// `FoldArithToVectorOuterProduct` pattern instantiated for `arith::MulFOp`
/// and then for `arith::MulIOp`.
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>>(context);
  patterns.add<FoldArithToVectorOuterProduct<arith::MulIOp>>(context);
}
1987+
19031988
//===----------------------------------------------------------------------===//
19041989
// TableGen'd enum attribute definitions
19051990
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/transform-vector.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
9292
transform.yield
9393
}
9494
}
95+
96+
// -----
97+
98+
// Integer multiplication on scalable vectors where the transposed broadcast
// is the first multiplication operand: the broadcast/transpose/muli chain is
// expected to fold into a single outerproduct of the two function arguments.
// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
99+
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
100+
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
101+
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
102+
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
103+
func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
104+
%lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
105+
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
106+
%rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
107+
%mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
108+
return %mul: vector<[4]x[4]xi32>
109+
}
110+
111+
// Floating-point multiplication with non-square fixed-size shapes where the
// transposed broadcast is the second multiplication operand: the source of
// the transposed broadcast (%rhs) is expected to become the first
// outerproduct operand.
// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
112+
// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
113+
// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
114+
// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
115+
// CHECK: return %[[RES]] : vector<8x16xf32>
116+
func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
117+
%rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
118+
%rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
119+
%lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
120+
%mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
121+
return %mul: vector<8x16xf32>
122+
}
123+
124+
// Transform script for the tests above: matches every func.func in the
// payload module and applies the elementwise_to_vector pattern set to it.
module attributes {transform.with_named_sequence} {
125+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
126+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
127+
transform.apply_patterns to %func {
128+
transform.apply_patterns.vector.elementwise_to_vector
129+
} : !transform.any_op
130+
transform.yield
131+
}
132+
}

0 commit comments

Comments
 (0)