Skip to content

Commit dab57eb

Browse files
committed
[mlir][vector] Add ElementwiseToOuterproduct
1 parent 1e44a96 commit dab57eb

File tree

5 files changed

+133
-0
lines changed

5 files changed

+133
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
8080
/// into vector contract for the backends with native support.
8181
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
8282

83+
/// Collect a set of patterns that fold elementwise ops on vectors into vector
84+
/// dialect ops (e.g. a multiply of broadcasts into `vector.outerproduct`).
85+
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
86+
8387
/// Returns the integer type required for subscripts in the vector dialect.
8488
IntegerType getVectorSubscriptType(Builder &builder);
8589

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
392392
let assemblyFormat = "attr-dict";
393393
}
394394

395+
// Pattern descriptor op usable inside `transform.apply_patterns`; its
// populatePatterns implementation forwards to
// vector::populateElementwiseToVectorOpsPatterns.
def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
396+
"apply_patterns.vector.elementwise_to_vector",
397+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
398+
let description = [{
399+
Collect a set of patterns that fold elementwise op on vectors to the vector
400+
dialect.
401+
}];
402+
403+
let assemblyFormat = "attr-dict";
404+
}
405+
395406
def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
396407
"apply_patterns.vector.reduction_to_contract",
397408
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
5959
vector::populateFoldArithExtensionPatterns(patterns);
6060
}
6161

62+
/// Populates `patterns` by forwarding to
/// vector::populateElementwiseToVectorOpsPatterns.
void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
63+
RewritePatternSet &patterns) {
64+
vector::populateElementwiseToVectorOpsPatterns(patterns);
65+
}
66+
6267
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
6368
RewritePatternSet &patterns) {
6469
vector::populateVectorReductionToContractPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
17951795
unsigned maxNumElementsToExtract = 0;
17961796
};
17971797

1798+
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
1799+
/// into vector.outerproduct(A, B) such as :
1800+
/// ```mlir
1801+
/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1802+
/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
1803+
/// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
1804+
/// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1805+
///```
1806+
/// Becomes :
1807+
///```mlir
1808+
/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1809+
///```
1810+
/// Edge cases where the broadcast ops are not 1-D to 2-D, such as the
/// following, are not handled:
1811+
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1812+
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1813+
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1814+
1815+
template <typename MulOpType>
1816+
struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
1817+
using OpRewritePattern<MulOpType>::OpRewritePattern;
1818+
1819+
LogicalResult matchAndRewrite(MulOpType mulOp,
1820+
PatternRewriter &rewriter) const override {
1821+
auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
1822+
if (!VT)
1823+
return failure();
1824+
if (VT.getRank() != 2)
1825+
return failure();
1826+
1827+
auto canonicalize = [&](Value OperandA,
1828+
Value OperandB) -> vector::OuterProductOp {
1829+
vector::TransposeOp transposedLhs =
1830+
dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
1831+
if (!transposedLhs)
1832+
return vector::OuterProductOp();
1833+
// Fail unless this is a true 2-D matrix transpose.
1834+
ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
1835+
if (permutation[0] != 1 || permutation[1] != 0)
1836+
return vector::OuterProductOp();
1837+
1838+
// Fail in case it is not a 1-to-2 dimension to broadcast to avoid
1839+
// generating shape_casts/broadcasts which do not belong in this pattern.
1840+
vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
1841+
transposedLhs.getVector().getDefiningOp());
1842+
if (!broadcastedLhs ||
1843+
!broadcastedLhs.computeBroadcastedUnitDims().empty())
1844+
return vector::OuterProductOp();
1845+
// Avoid broadcast f32 or vector<f32> -> ResType
1846+
auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
1847+
if (!srcVT || srcVT.getRank() != 1)
1848+
return vector::OuterProductOp();
1849+
1850+
vector::BroadcastOp broadcastedRhs =
1851+
dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
1852+
if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
1853+
return vector::OuterProductOp();
1854+
1855+
return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
1856+
mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
1857+
Value(), vector::CombiningKind::ADD);
1858+
};
1859+
Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
1860+
vector::OuterProductOp outerP = canonicalize(a, b);
1861+
// Handle commutativity, the transposed op is the outerproduct LHS.
1862+
outerP = outerP ? outerP : canonicalize(b, a);
1863+
return outerP ? success() : failure();
1864+
}
1865+
};
1866+
17981867
} // namespace
17991868

18001869
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
18821951
maxNumElementsToExtract, benefit);
18831952
}
18841953

1954+
/// Registers the elementwise-to-outerproduct rewrite for both the
/// floating-point (`arith::MulFOp`) and integer (`arith::MulIOp`) multiplies.
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ElementwiseToOuterproduct<arith::MulFOp>>(context);
  patterns.add<ElementwiseToOuterproduct<arith::MulIOp>>(context);
}
1959+
18851960
//===----------------------------------------------------------------------===//
18861961
// TableGen'd enum attribute definitions
18871962
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/transform-vector.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
9292
transform.yield
9393
}
9494
}
95+
96+
// -----
97+
98+
// CHECK-LABEL: func.func @ewise_outerproduct
99+
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
100+
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
101+
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
102+
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
103+
// Integer case on scalable vectors: the transposed broadcast is the first
// operand of the multiply.
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
104+
%lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
105+
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
106+
%rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
107+
%mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
108+
return %mul: vector<[4]x[4]xi32>
109+
}
110+
111+
// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
112+
// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
113+
// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
114+
// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
115+
// CHECK: return %[[RES]] : vector<16x16xf32>
116+
// Floating-point case where the transposed broadcast is the second operand of
// the multiply, so the expected outerproduct takes %rhs first and %lhs second.
func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
117+
%rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
118+
%rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
119+
%lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
120+
%mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
121+
return %mul: vector<16x16xf32>
122+
}
123+
124+
// Transform script: matches every func.func in the payload module and applies
// the elementwise_to_vector patterns to it.
module attributes {transform.with_named_sequence} {
125+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
126+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
127+
transform.apply_patterns to %func {
128+
transform.apply_patterns.vector.elementwise_to_vector
129+
} : !transform.any_op
130+
transform.yield
131+
}
132+
}

0 commit comments

Comments
 (0)