Skip to content

Commit 6638112

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Add padding pass to strategy passes.
Add a strategy pass that pads and hoists after tiling and fusion. Depends on D112412. Reviewed by: nicolasvasilache. Differential Revision: https://reviews.llvm.org/D112480
1 parent 9668e19 commit 6638112

File tree

6 files changed

+179
-27
lines changed

6 files changed

+179
-27
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
8888
linalg::LinalgTransformationFilter filter =
8989
linalg::LinalgTransformationFilter());
9090

91+
/// Create a LinalgStrategyPadPass that pads and hoists the operands of the
/// linalg ops anchored on `opName` (all linalg ops when empty), using the
/// padding configuration `opt` and applying only to ops matching `filter`.
std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPadPass(
    StringRef opName = "",
    linalg::LinalgPaddingOptions opt = linalg::LinalgPaddingOptions(),
    linalg::LinalgTransformationFilter filter =
        linalg::LinalgTransformationFilter());
97+
9198
/// Create a LinalgStrategyPromotePass.
9299
std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
93100
StringRef opName = "",

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,19 @@ def LinalgStrategyTilePass
248248
];
249249
}
250250

251+
// Strategy pass applying padding and hoisting to the anchored linalg ops.
// Constructed with default arguments; the anchor function/op are configurable
// through the pass options below.
def LinalgStrategyPadPass
    : FunctionPass<"linalg-strategy-pad-pass"> {
  let summary = "Configurable pass to apply padding and hoisting.";
  let constructor = "mlir::createLinalgStrategyPadPass()";
  let dependentDialects = ["linalg::LinalgDialect"];
  let options = [
    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
      "Which func op is the anchor to latch on.">,
    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
      "Which linalg op within the func is the anchor to latch on.">,
  ];
}
263+
251264
def LinalgStrategyPromotePass
252265
: FunctionPass<"linalg-strategy-promote-pass"> {
253266
let summary = "Configurable pass to apply pattern-based linalg promotion.";

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ struct Tile : public Transformation {
4646
linalg::LinalgTilingOptions options;
4747
};
4848

49+
/// Represent one application of LinalgStrategyPadPass.
50+
struct Pad : public Transformation {
51+
Pad(StringRef name, linalg::LinalgPaddingOptions options,
52+
LinalgTransformationFilter::FilterFunction f = nullptr)
53+
: Transformation(f), opName(name), options(options) {}
54+
55+
void addToPassPipeline(OpPassManager &pm,
56+
LinalgTransformationFilter m) const override {
57+
pm.addPass(createLinalgStrategyPadPass(opName, options, m));
58+
}
59+
60+
private:
61+
std::string opName;
62+
linalg::LinalgPaddingOptions options;
63+
};
64+
4965
/// Represent one application of createLinalgStrategyPromotePass.
5066
struct Promote : public Transformation {
5167
Promote(StringRef name, linalg::LinalgPromotionOptions options,
@@ -147,6 +163,21 @@ struct CodegenStrategy {
147163
LinalgTransformationFilter::FilterFunction f = nullptr) {
148164
return b ? tile(opName, options) : *this;
149165
}
166+
/// Append a pattern to pad and hoist the operands of Op `opName` with padding
167+
/// `options`.
168+
CodegenStrategy &pad(StringRef opName, linalg::LinalgPaddingOptions options,
169+
LinalgTransformationFilter::FilterFunction f = nullptr) {
170+
transformationSequence.emplace_back(
171+
std::make_unique<Pad>(opName, options, f));
172+
return *this;
173+
}
174+
/// Conditionally append a pattern to pad and hoist the operands of Op
175+
/// `opName` with padding `options`.
176+
CodegenStrategy &
177+
padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
178+
LinalgTransformationFilter::FilterFunction f = nullptr) {
179+
return b ? pad(opName, options, f) : *this;
180+
}
150181
/// Append a pattern to add a level of promotion for `LinalgOpType` with
151182
/// promotion `options`.
152183
CodegenStrategy &

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,39 @@ struct LinalgStrategyTilePass
6868
LinalgTransformationFilter filter;
6969
};
7070

71+
/// Configurable pass to apply hoisting and padding.
72+
struct LinalgStrategyPadPass
73+
: public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
74+
75+
LinalgStrategyPadPass() = default;
76+
77+
LinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
78+
LinalgTransformationFilter filt)
79+
: options(opt), filter(filt) {
80+
this->anchorOpName.setValue(opName.str());
81+
}
82+
83+
void runOnFunction() override {
84+
auto funcOp = getFunction();
85+
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
86+
return;
87+
88+
RewritePatternSet paddingPattern(funcOp.getContext());
89+
if (!anchorOpName.empty()) {
90+
paddingPattern.add<LinalgPaddingPattern>(
91+
anchorOpName, funcOp.getContext(), options, filter);
92+
} else {
93+
paddingPattern.add<LinalgPaddingPattern>(funcOp.getContext(), options,
94+
filter);
95+
}
96+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern))))
97+
signalPassFailure();
98+
}
99+
100+
LinalgPaddingOptions options;
101+
LinalgTransformationFilter filter;
102+
};
103+
71104
/// Configurable pass to apply pattern-based linalg generalization.
72105
struct LinalgStrategyGeneralizePass
73106
: public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
@@ -332,6 +365,13 @@ mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
332365
return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
333366
}
334367

368+
/// Create a LinalgStrategyPadPass.
369+
std::unique_ptr<OperationPass<FuncOp>>
370+
mlir::createLinalgStrategyPadPass(StringRef opName, LinalgPaddingOptions opt,
371+
LinalgTransformationFilter filter) {
372+
return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
373+
}
374+
335375
/// Create a LinalgStrategyPromotePass.
336376
std::unique_ptr<OperationPass<FuncOp>>
337377
mlir::createLinalgStrategyPromotePass(StringRef opName,
Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,53 @@
1-
// Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result.
2-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
3-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
4-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
5-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
6-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
7-
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER
8-
9-
10-
// CHECK-LABEL: func @matmul(
11-
// OUTER-LABEL: func @matmul(
12-
// GENER-LABEL: func @matmul(
13-
func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
14-
linalg.matmul
15-
ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
16-
outs(%C: memref<1584x1584xf32>)
17-
18-
// CHECK: vector.matrix_multiply
19-
// CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
20-
// CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
21-
22-
// OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
23-
24-
// GENER: linalg.generic
25-
// GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"]
1+
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" -split-input-file | FileCheck %s --check-prefix=CHECK-INTRINSIC
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" -split-input-file | FileCheck %s --check-prefix=CHECK-OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD

// CHECK-INTRINSIC: func @matmul(
// CHECK-OUTER: func @matmul(
func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {

  // Check the matrix intrinsic lowering is triggered.
  // CHECK-INTRINSIC: vector.matrix_multiply
  // CHECK-INTRINSIC-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
  // CHECK-INTRINSIC-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>

  // Check the outer product lowering is triggered.
  // CHECK-OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
  linalg.matmul ins(%arg0, %arg1: memref<72x72xf32>, memref<72x72xf32>) outs(%arg2: memref<72x72xf32>)
  return
}

// -----

// CHECK-INTERCHANGE: func @matmul(
func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
  // CHECK-INTERCHANGE-DAG: %[[C16:.*]] = arith.constant 16
  // CHECK-INTERCHANGE-DAG: %[[C32:.*]] = arith.constant 32
  // CHECK-INTERCHANGE-DAG: %[[C64:.*]] = arith.constant 64

  // Check the tile loops are interchanged.
  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C32]]
  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C64]]
  // CHECK-INTERCHANGE: scf.for {{.*}} step %[[C16]]

  // Check the operation has been generalized and interchanged.
  // CHECK-INTERCHANGE: linalg.generic
  // CHECK-INTERCHANGE-SAME: iterator_types = ["parallel", "reduction", "parallel"]
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
  return %0 : tensor<72x72xf32>
}

// -----

// CHECK-PAD: func @matmul(
func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {

  // Check the padding of the input operands has been hoisted out of the tile loop nest.
  // NOTE: FileCheck's repeated-match directive is spelled CHECK-<prefix>-COUNT-<n>;
  // the previous "COUNT=<n>" spelling was not a recognized directive and was
  // silently ignored, so these checks never ran.
  // CHECK-PAD-COUNT-2: linalg.pad_tensor %{{.*}} nofold
  // CHECK-PAD-COUNT-3: scf.for
  // CHECK-PAD: linalg.matmul
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
  return %0 : tensor<72x72xf32>
}
53+

mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct TestLinalgCodegenStrategy
5454

5555
void runStrategy(LinalgTilingOptions tilingOptions,
5656
LinalgTilingOptions registerTilingOptions,
57+
LinalgPaddingOptions paddingOptions,
5758
vector::VectorContractLowering vectorContractLowering,
5859
vector::VectorTransferSplit vectorTransferSplit);
5960

@@ -86,6 +87,16 @@ struct TestLinalgCodegenStrategy
8687
*this, "register-promote-full-tile-pad",
8788
llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
8889
llvm::cl::init(false)};
90+
Option<bool> pad{*this, "pad", llvm::cl::desc("Pad the operands."),
91+
llvm::cl::init(false)};
92+
ListOption<int64_t> packPaddings{
93+
*this, "pack-paddings",
94+
llvm::cl::desc("Operand packing flags when test-pad-pattern"),
95+
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
96+
ListOption<int64_t> hoistPaddings{
97+
*this, "hoist-paddings",
98+
llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
99+
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
89100
Option<bool> generalize{*this, "generalize",
90101
llvm::cl::desc("Generalize named operations."),
91102
llvm::cl::init(false)};
@@ -132,9 +143,18 @@ struct TestLinalgCodegenStrategy
132143
llvm::cl::init("")};
133144
};
134145

146+
// For now, just assume it is the zero of type.
147+
// In the future, it should be the zero of type + op.
148+
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
149+
auto t = getElementTypeOrSelf(op.get());
150+
return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
151+
b.getZeroAttr(t));
152+
}
153+
135154
void TestLinalgCodegenStrategy::runStrategy(
136155
LinalgTilingOptions tilingOptions,
137156
LinalgTilingOptions registerTilingOptions,
157+
LinalgPaddingOptions paddingOptions,
138158
vector::VectorContractLowering vectorContractLowering,
139159
vector::VectorTransferSplit vectorTransferSplit) {
140160
assert(!anchorOpName.empty());
@@ -150,6 +170,7 @@ void TestLinalgCodegenStrategy::runStrategy(
150170
LinalgPromotionOptions()
151171
.setAlignment(16)
152172
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
173+
.padIf(pad, anchorOpName, paddingOptions)
153174
.generalizeIf(generalize, anchorOpName)
154175
.interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
155176
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
@@ -191,6 +212,21 @@ void TestLinalgCodegenStrategy::runOnFunction() {
191212
registerTilingOptions =
192213
registerTilingOptions.setTileSizes(registerTileSizes);
193214

215+
LinalgPaddingOptions paddingOptions;
216+
auto packFunc = [&](OpOperand &opOperand) {
217+
return opOperand.getOperandNumber() < packPaddings.size()
218+
? packPaddings[opOperand.getOperandNumber()]
219+
: false;
220+
};
221+
auto hoistingFunc = [&](OpOperand &opOperand) {
222+
return opOperand.getOperandNumber() < hoistPaddings.size()
223+
? hoistPaddings[opOperand.getOperandNumber()]
224+
: 0;
225+
};
226+
paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
227+
paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
228+
paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
229+
194230
vector::VectorContractLowering vectorContractLowering =
195231
llvm::StringSwitch<vector::VectorContractLowering>(
196232
vectorizeContractionTo.getValue())
@@ -206,8 +242,8 @@ void TestLinalgCodegenStrategy::runOnFunction() {
206242
.Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
207243
.Default(vector::VectorTransferSplit::None);
208244

209-
runStrategy(tilingOptions, registerTilingOptions, vectorContractLowering,
210-
vectorTransferSplit);
245+
runStrategy(tilingOptions, registerTilingOptions, paddingOptions,
246+
vectorContractLowering, vectorTransferSplit);
211247
}
212248

213249
namespace mlir {

0 commit comments

Comments
 (0)