Skip to content

Commit 4703a07

Browse files
[mlir][Linalg] NFC - Reorganize options nesting.
This removes duplication and makes nesting more clear. It also reduces the amount of changes necessary for exposing future options. Differential revision: https://reviews.llvm.org/D112344
1 parent 35553d4 commit 4703a07

File tree

6 files changed

+104
-113
lines changed

6 files changed

+104
-113
lines changed

mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,29 @@ class RewritePatternSet;
4646
///
4747
/// When applying the pattern a second time, the existing alloca() operation
4848
/// is reused and only a second vector.type_cast is added.
49-
5049
struct VectorTransferToSCFOptions {
50+
/// Minimal rank to which vector transfer are lowered.
5151
unsigned targetRank = 1;
52+
VectorTransferToSCFOptions &setTargetRank(unsigned r) {
53+
targetRank = r;
54+
return *this;
55+
}
56+
///
5257
bool lowerPermutationMaps = false;
53-
bool lowerTensors = false;
54-
bool unroll = false;
55-
56-
VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
58+
VectorTransferToSCFOptions &enableLowerPermutationMaps(bool l = true) {
5759
lowerPermutationMaps = l;
5860
return *this;
5961
}
60-
61-
VectorTransferToSCFOptions &setLowerTensors(bool l) {
62+
/// Allows vector transfers that operated on tensors to be lowered (this is an
63+
/// uncommon alternative).
64+
bool lowerTensors = false;
65+
VectorTransferToSCFOptions &enableLowerTensors(bool l = true) {
6266
lowerTensors = l;
6367
return *this;
6468
}
65-
66-
VectorTransferToSCFOptions &setTargetRank(unsigned r) {
67-
targetRank = r;
68-
return *this;
69-
}
70-
71-
VectorTransferToSCFOptions &setUnroll(bool u) {
69+
/// Triggers full unrolling (vs iterating with a loop) during transfer to scf.
70+
bool unroll = false;
71+
VectorTransferToSCFOptions &enableFullUnroll(bool u = true) {
7272
unroll = u;
7373
return *this;
7474
}

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -195,53 +195,16 @@ struct CodegenStrategy {
195195
return b ? vectorize(opName, f) : *this;
196196
return *this;
197197
}
198-
/// Configure the post staged-patterns late vector transformations.
198+
/// Configure the post staged-patterns late vector lowering options.
199199
CodegenStrategy &
200-
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
201-
vectorTransformOptions = options;
200+
setLinalgVectorLoweringOptions(LinalgVectorLoweringOptions options) {
201+
lateVectorLoweringOptions = options;
202202
return *this;
203203
}
204-
/// Configure the post staged-patterns late vector.transfer to scf
205-
/// conversion.
204+
/// Configure the post staged-patterns global enabling passes options.
206205
CodegenStrategy &
207-
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
208-
vectorToSCFOptions = options;
209-
return *this;
210-
}
211-
///
212-
/// Configure the application of late transformations.
213-
///
214-
CodegenStrategy &setEnableLICM(bool val) {
215-
this->lateCodegenStrategyOptions.enableLICM = val;
216-
return *this;
217-
}
218-
CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
219-
this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
220-
return *this;
221-
}
222-
CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
223-
this->lateCodegenStrategyOptions
224-
.enableHoistRedundantVectorTransfersOnTensor = val;
225-
return *this;
226-
}
227-
CodegenStrategy &setMaxTransferRank(int64_t val) {
228-
this->lateCodegenStrategyOptions.maxTransferRank = val;
229-
return *this;
230-
}
231-
CodegenStrategy &setEnableVectorTransferLowering(bool val) {
232-
this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
233-
return *this;
234-
}
235-
CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
236-
this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
237-
return *this;
238-
}
239-
CodegenStrategy &setEnableVectorContractLowering(bool val) {
240-
this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
241-
return *this;
242-
}
243-
CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
244-
this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
206+
setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
207+
linalgEnablingOptions = options;
245208
return *this;
246209
}
247210

@@ -252,10 +215,9 @@ struct CodegenStrategy {
252215
private:
253216
LogicalResult postPatternTransforms(Operation *func) const;
254217

255-
vector::VectorTransformsOptions vectorTransformOptions;
256-
VectorTransferToSCFOptions vectorToSCFOptions;
218+
LinalgEnablingOptions linalgEnablingOptions;
219+
LinalgVectorLoweringOptions lateVectorLoweringOptions;
257220
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
258-
LateCodegenStrategyOptions lateCodegenStrategyOptions;
259221
};
260222

261223
} // namespace linalg

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -846,41 +846,80 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
846846
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
847847
};
848848

849-
/// Options to control the application of late transformations.
850-
struct LateCodegenStrategyOptions {
851-
/// Hoisting transformations are always deemed beneficial and must disabled
852-
/// explicitly.
853-
bool enableLICM = true;
854-
bool enableHoistRedundantVectorTransfers = true;
855-
bool enableHoistRedundantVectorTransfersOnTensor = true;
856-
/// Vector lowering operations may result in surprising behavior when
857-
/// composing multiple codegen strategies and must be enabled explicitly.
858-
int64_t maxTransferRank = 1;
859-
bool enableVectorTransferLowering = true;
860-
bool enableVectorTransferPartialRewrite = false;
861-
bool enableVectorContractLowering = false;
862-
bool enableVectorToSCFConversion = false;
863-
};
864-
865849
/// Options to control the application of enabling transformations.
866850
/// Hoisting transformations are always deemed beneficial and must be disabled
867851
/// explicitly.
868852
struct LinalgEnablingOptions {
869-
bool enableLICM = true;
870-
bool enableHoistRedundantVectorTransfers = true;
871-
bool enableHoistRedundantVectorTransfersOnTensor = true;
853+
/// Enable LICM.
854+
bool licm = true;
855+
LinalgEnablingOptions &enableLICM(bool val = true) {
856+
licm = val;
857+
return *this;
858+
}
859+
/// Enable hoisting of redundant vector transfer ops.
860+
bool hoistRedundantVectorTransfers = true;
861+
LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
862+
hoistRedundantVectorTransfers = val;
863+
return *this;
864+
}
865+
/// Enable hoisting of redundant vector transfer ops on tensor.
866+
bool hoistRedundantVectorTransfersOnTensor = true;
867+
LinalgEnablingOptions &
868+
enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
869+
hoistRedundantVectorTransfersOnTensor = val;
870+
return *this;
871+
}
872872
};
873873

874874
/// Vector lowering options control how ops are lowered down to 1-D and scf.for
875875
/// form.
876876
struct LinalgVectorLoweringOptions {
877+
/// Maximal transfer rank under which we do not lower further.
877878
int64_t maxTransferRank = 1;
878-
bool enableVectorTransferLowering = true;
879-
bool enableVectorTransferPartialRewrite = false;
880-
bool enableVectorContractLowering = false;
881-
bool enableVectorToSCFConversion = false;
879+
LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
880+
maxTransferRank = val;
881+
return *this;
882+
}
883+
/// Vector lowering operations may result in surprising behavior when
884+
/// composing multiple codegen strategies and must be enabled explicitly.
885+
bool transferLowering = true;
886+
LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
887+
transferLowering = val;
888+
return *this;
889+
}
890+
/// Trigger full / partial vector.transfer splits.
891+
bool transferPartialRewrite = false;
892+
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
893+
transferPartialRewrite = val;
894+
return *this;
895+
}
896+
/// Enable lowering of vector.contract.
897+
bool contractionLowering = false;
898+
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
899+
contractionLowering = val;
900+
return *this;
901+
}
902+
/// Enable lowering of vector.transfer to scf.
903+
bool transferToSCFConversion = false;
904+
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
905+
transferToSCFConversion = val;
906+
return *this;
907+
}
908+
/// Configure late vector transformations.
882909
vector::VectorTransformsOptions vectorTransformOptions;
910+
LinalgVectorLoweringOptions &
911+
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
912+
vectorTransformOptions = options;
913+
return *this;
914+
}
915+
/// Configure the post staged-patterns late vector.transfer to scf
916+
/// conversion.
883917
VectorTransferToSCFOptions vectorTransferToSCFOptions;
918+
LinalgVectorLoweringOptions &
919+
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
920+
vectorTransferToSCFOptions = options;
921+
return *this;
922+
}
884923
};
885924

886925
/// Trait to check if T provides a `getOperationName` method.

mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,6 @@ void mlir::linalg::CodegenStrategy::configurePassPipeline(
4646
t->addToPassPipeline(pm, filter);
4747
pm.addPass(createLinalgStrategyEnablePass());
4848
}
49-
LinalgVectorLoweringOptions vectorLoweringOptions;
50-
vectorLoweringOptions.maxTransferRank =
51-
lateCodegenStrategyOptions.maxTransferRank;
52-
vectorLoweringOptions.enableVectorTransferLowering =
53-
lateCodegenStrategyOptions.enableVectorTransferLowering;
54-
vectorLoweringOptions.enableVectorTransferPartialRewrite =
55-
lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
56-
vectorLoweringOptions.enableVectorContractLowering =
57-
lateCodegenStrategyOptions.enableVectorContractLowering;
58-
vectorLoweringOptions.enableVectorToSCFConversion =
59-
lateCodegenStrategyOptions.enableVectorToSCFConversion;
60-
vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
61-
vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
62-
pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
49+
pm.addPass(createLinalgStrategyLowerVectorsPass(lateVectorLoweringOptions));
6350
pm.addPass(createLinalgStrategyRemoveMarkersPass());
6451
}

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ struct LinalgStrategyEnablePass
224224
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
225225
return signalPassFailure();
226226

227-
if (options.enableLICM) {
227+
if (options.licm) {
228228
if (funcOp
229229
->walk([&](LoopLikeOpInterface loopLike) {
230230
if (failed(moveLoopInvariantCode(loopLike)))
@@ -236,10 +236,10 @@ struct LinalgStrategyEnablePass
236236
}
237237

238238
promoteSingleIterationLoops(funcOp);
239-
if (options.enableHoistRedundantVectorTransfers)
239+
if (options.hoistRedundantVectorTransfers)
240240
hoistRedundantVectorTransfers(funcOp);
241241

242-
if (options.enableHoistRedundantVectorTransfersOnTensor)
242+
if (options.hoistRedundantVectorTransfersOnTensor)
243243
hoistRedundantVectorTransfersOnTensor(funcOp);
244244
}
245245

@@ -263,21 +263,21 @@ struct LinalgStrategyLowerVectorsPass
263263

264264
MLIRContext *context = funcOp.getContext();
265265
RewritePatternSet patterns(context);
266-
if (options.enableVectorTransferLowering) {
266+
if (options.transferLowering) {
267267
vector::populateVectorTransferLoweringPatterns(patterns,
268268
options.maxTransferRank);
269269
}
270-
if (options.enableVectorTransferPartialRewrite) {
270+
if (options.transferPartialRewrite) {
271271
patterns.add<vector::VectorTransferFullPartialRewriter>(
272272
context, options.vectorTransformOptions);
273273
}
274-
if (options.enableVectorContractLowering) {
274+
if (options.contractionLowering) {
275275
patterns.add<ContractionOpToOuterProductOpLowering,
276276
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
277277
options.vectorTransformOptions, context);
278278
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
279279
}
280-
if (options.enableVectorToSCFConversion) {
280+
if (options.transferToSCFConversion) {
281281
populateVectorToSCFConversionPatterns(patterns,
282282
options.vectorTransferToSCFOptions);
283283
}

mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,18 @@ void TestLinalgCodegenStrategy::runStrategy(
153153
.generalizeIf(generalize, anchorOpName)
154154
.interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
155155
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
156-
.setEnableVectorTransferPartialRewrite(true)
157-
.setEnableVectorContractLowering(true)
158-
.setEnableVectorToSCFConversion(true)
159-
.setVectorTransformsOptions(
160-
vector::VectorTransformsOptions()
161-
.setVectorTransformsOptions(vectorContractLowering)
162-
.setVectorTransferSplit(vectorTransferSplit))
163-
.setVectorTransferToSCFOptions(
164-
VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
156+
.setLinalgVectorLoweringOptions(
157+
LinalgVectorLoweringOptions()
158+
.setVectorTransformsOptions(
159+
vector::VectorTransformsOptions()
160+
.setVectorTransformsOptions(vectorContractLowering)
161+
.setVectorTransferSplit(vectorTransferSplit))
162+
.setVectorTransferToSCFOptions(
163+
VectorTransferToSCFOptions().enableFullUnroll(
164+
unrollVectorTransfers))
165+
.enableTransferPartialRewrite()
166+
.enableContractionLowering()
167+
.enableTransferToSCFConversion());
165168

166169
// Created a nested OpPassManager and run.
167170
FuncOp funcOp = getFunction();

0 commit comments

Comments
 (0)