Skip to content

Commit e03b443

Browse files
Revert "[mlir][Linalg] NFC - Reorganize options nesting."
This reverts commit 4703a07. Didnt' mean to push this yet, sorry about the noise.
1 parent 4f5e9a2 commit e03b443

File tree

6 files changed

+113
-104
lines changed

6 files changed

+113
-104
lines changed

mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,29 @@ class RewritePatternSet;
4646
///
4747
/// When applying the pattern a second time, the existing alloca() operation
4848
/// is reused and only a second vector.type_cast is added.
49+
4950
struct VectorTransferToSCFOptions {
50-
/// Minimal rank to which vector transfer are lowered.
5151
unsigned targetRank = 1;
52-
VectorTransferToSCFOptions &setTargetRank(unsigned r) {
53-
targetRank = r;
54-
return *this;
55-
}
56-
///
5752
bool lowerPermutationMaps = false;
58-
VectorTransferToSCFOptions &enableLowerPermutationMaps(bool l = true) {
53+
bool lowerTensors = false;
54+
bool unroll = false;
55+
56+
VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
5957
lowerPermutationMaps = l;
6058
return *this;
6159
}
62-
/// Allows vector transfers that operated on tensors to be lowered (this is an
63-
/// uncommon alternative).
64-
bool lowerTensors = false;
65-
VectorTransferToSCFOptions &enableLowerTensors(bool l = true) {
60+
61+
VectorTransferToSCFOptions &setLowerTensors(bool l) {
6662
lowerTensors = l;
6763
return *this;
6864
}
69-
/// Triggers full unrolling (vs iterating with a loop) during transfer to scf.
70-
bool unroll = false;
71-
VectorTransferToSCFOptions &enableFullUnroll(bool u = true) {
65+
66+
VectorTransferToSCFOptions &setTargetRank(unsigned r) {
67+
targetRank = r;
68+
return *this;
69+
}
70+
71+
VectorTransferToSCFOptions &setUnroll(bool u) {
7272
unroll = u;
7373
return *this;
7474
}

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,53 @@ struct CodegenStrategy {
195195
return b ? vectorize(opName, f) : *this;
196196
return *this;
197197
}
198-
/// Configure the post staged-patterns late vector lowering options.
198+
/// Configure the post staged-patterns late vector transformations.
199199
CodegenStrategy &
200-
setLinalgVectorLoweringOptions(LinalgVectorLoweringOptions options) {
201-
lateVectorLoweringOptions = options;
200+
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
201+
vectorTransformOptions = options;
202202
return *this;
203203
}
204-
/// Configure the post staged-patterns global enabling passes options.
204+
/// Configure the post staged-patterns late vector.transfer to scf
205+
/// conversion.
205206
CodegenStrategy &
206-
setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
207-
linalgEnablingOptions = options;
207+
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
208+
vectorToSCFOptions = options;
209+
return *this;
210+
}
211+
///
212+
/// Configure the application of late transformations.
213+
///
214+
CodegenStrategy &setEnableLICM(bool val) {
215+
this->lateCodegenStrategyOptions.enableLICM = val;
216+
return *this;
217+
}
218+
CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
219+
this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
220+
return *this;
221+
}
222+
CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
223+
this->lateCodegenStrategyOptions
224+
.enableHoistRedundantVectorTransfersOnTensor = val;
225+
return *this;
226+
}
227+
CodegenStrategy &setMaxTransferRank(int64_t val) {
228+
this->lateCodegenStrategyOptions.maxTransferRank = val;
229+
return *this;
230+
}
231+
CodegenStrategy &setEnableVectorTransferLowering(bool val) {
232+
this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
233+
return *this;
234+
}
235+
CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
236+
this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
237+
return *this;
238+
}
239+
CodegenStrategy &setEnableVectorContractLowering(bool val) {
240+
this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
241+
return *this;
242+
}
243+
CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
244+
this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
208245
return *this;
209246
}
210247

@@ -215,9 +252,10 @@ struct CodegenStrategy {
215252
private:
216253
LogicalResult postPatternTransforms(Operation *func) const;
217254

218-
LinalgEnablingOptions linalgEnablingOptions;
219-
LinalgVectorLoweringOptions lateVectorLoweringOptions;
255+
vector::VectorTransformsOptions vectorTransformOptions;
256+
VectorTransferToSCFOptions vectorToSCFOptions;
220257
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
258+
LateCodegenStrategyOptions lateCodegenStrategyOptions;
221259
};
222260

223261
} // namespace linalg

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 23 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -846,80 +846,41 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
846846
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
847847
};
848848

849+
/// Options to control the application of late transformations.
850+
struct LateCodegenStrategyOptions {
851+
/// Hoisting transformations are always deemed beneficial and must disabled
852+
/// explicitly.
853+
bool enableLICM = true;
854+
bool enableHoistRedundantVectorTransfers = true;
855+
bool enableHoistRedundantVectorTransfersOnTensor = true;
856+
/// Vector lowering operations may result in surprising behavior when
857+
/// composing multiple codegen strategies and must be enabled explicitly.
858+
int64_t maxTransferRank = 1;
859+
bool enableVectorTransferLowering = true;
860+
bool enableVectorTransferPartialRewrite = false;
861+
bool enableVectorContractLowering = false;
862+
bool enableVectorToSCFConversion = false;
863+
};
864+
849865
/// Options to control the application of enabling transformations.
850866
/// Hoisting transformations are always deemed beneficial and must be disabled
851867
/// explicitly.
852868
struct LinalgEnablingOptions {
853-
/// Enable LICM.
854-
bool licm = true;
855-
LinalgEnablingOptions &enableLICM(bool val = true) {
856-
licm = val;
857-
return *this;
858-
}
859-
/// Enable hoisting of redundant vector transfer ops.
860-
bool hoistRedundantVectorTransfers = true;
861-
LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
862-
hoistRedundantVectorTransfers = val;
863-
return *this;
864-
}
865-
/// Enable hoisting of redundant vector transfer ops on tensor.
866-
bool hoistRedundantVectorTransfersOnTensor = true;
867-
LinalgEnablingOptions &
868-
enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
869-
hoistRedundantVectorTransfersOnTensor = val;
870-
return *this;
871-
}
869+
bool enableLICM = true;
870+
bool enableHoistRedundantVectorTransfers = true;
871+
bool enableHoistRedundantVectorTransfersOnTensor = true;
872872
};
873873

874874
/// Vector lowering options control how ops are lowered down to 1-D and scf.for
875875
/// form.
876876
struct LinalgVectorLoweringOptions {
877-
/// Maximal transfer rank under which we do not lower further.
878877
int64_t maxTransferRank = 1;
879-
LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
880-
maxTransferRank = val;
881-
return *this;
882-
}
883-
/// Vector lowering operations may result in surprising behavior when
884-
/// composing multiple codegen strategies and must be enabled explicitly.
885-
bool transferLowering = true;
886-
LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
887-
transferLowering = val;
888-
return *this;
889-
}
890-
/// Trigger full / partial vector.transfer splits.
891-
bool transferPartialRewrite = false;
892-
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
893-
transferPartialRewrite = val;
894-
return *this;
895-
}
896-
/// Enable lowering of vector.contract.
897-
bool contractionLowering = false;
898-
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
899-
contractionLowering = val;
900-
return *this;
901-
}
902-
/// Enable lowering of vector.transfer to scf.
903-
bool transferToSCFConversion = false;
904-
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
905-
transferToSCFConversion = val;
906-
return *this;
907-
}
908-
/// Configure late vector transformations.
878+
bool enableVectorTransferLowering = true;
879+
bool enableVectorTransferPartialRewrite = false;
880+
bool enableVectorContractLowering = false;
881+
bool enableVectorToSCFConversion = false;
909882
vector::VectorTransformsOptions vectorTransformOptions;
910-
LinalgVectorLoweringOptions &
911-
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
912-
vectorTransformOptions = options;
913-
return *this;
914-
}
915-
/// Configure the post staged-patterns late vector.transfer to scf
916-
/// conversion.
917883
VectorTransferToSCFOptions vectorTransferToSCFOptions;
918-
LinalgVectorLoweringOptions &
919-
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
920-
vectorTransferToSCFOptions = options;
921-
return *this;
922-
}
923884
};
924885

925886
/// Trait to check if T provides a `getOperationName` method.

mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ void mlir::linalg::CodegenStrategy::configurePassPipeline(
4646
t->addToPassPipeline(pm, filter);
4747
pm.addPass(createLinalgStrategyEnablePass());
4848
}
49-
pm.addPass(createLinalgStrategyLowerVectorsPass(lateVectorLoweringOptions));
49+
LinalgVectorLoweringOptions vectorLoweringOptions;
50+
vectorLoweringOptions.maxTransferRank =
51+
lateCodegenStrategyOptions.maxTransferRank;
52+
vectorLoweringOptions.enableVectorTransferLowering =
53+
lateCodegenStrategyOptions.enableVectorTransferLowering;
54+
vectorLoweringOptions.enableVectorTransferPartialRewrite =
55+
lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
56+
vectorLoweringOptions.enableVectorContractLowering =
57+
lateCodegenStrategyOptions.enableVectorContractLowering;
58+
vectorLoweringOptions.enableVectorToSCFConversion =
59+
lateCodegenStrategyOptions.enableVectorToSCFConversion;
60+
vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
61+
vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
62+
pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
5063
pm.addPass(createLinalgStrategyRemoveMarkersPass());
5164
}

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ struct LinalgStrategyEnablePass
224224
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
225225
return signalPassFailure();
226226

227-
if (options.licm) {
227+
if (options.enableLICM) {
228228
if (funcOp
229229
->walk([&](LoopLikeOpInterface loopLike) {
230230
if (failed(moveLoopInvariantCode(loopLike)))
@@ -236,10 +236,10 @@ struct LinalgStrategyEnablePass
236236
}
237237

238238
promoteSingleIterationLoops(funcOp);
239-
if (options.hoistRedundantVectorTransfers)
239+
if (options.enableHoistRedundantVectorTransfers)
240240
hoistRedundantVectorTransfers(funcOp);
241241

242-
if (options.hoistRedundantVectorTransfersOnTensor)
242+
if (options.enableHoistRedundantVectorTransfersOnTensor)
243243
hoistRedundantVectorTransfersOnTensor(funcOp);
244244
}
245245

@@ -263,21 +263,21 @@ struct LinalgStrategyLowerVectorsPass
263263

264264
MLIRContext *context = funcOp.getContext();
265265
RewritePatternSet patterns(context);
266-
if (options.transferLowering) {
266+
if (options.enableVectorTransferLowering) {
267267
vector::populateVectorTransferLoweringPatterns(patterns,
268268
options.maxTransferRank);
269269
}
270-
if (options.transferPartialRewrite) {
270+
if (options.enableVectorTransferPartialRewrite) {
271271
patterns.add<vector::VectorTransferFullPartialRewriter>(
272272
context, options.vectorTransformOptions);
273273
}
274-
if (options.contractionLowering) {
274+
if (options.enableVectorContractLowering) {
275275
patterns.add<ContractionOpToOuterProductOpLowering,
276276
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
277277
options.vectorTransformOptions, context);
278278
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
279279
}
280-
if (options.transferToSCFConversion) {
280+
if (options.enableVectorToSCFConversion) {
281281
populateVectorToSCFConversionPatterns(patterns,
282282
options.vectorTransferToSCFOptions);
283283
}

mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,15 @@ void TestLinalgCodegenStrategy::runStrategy(
153153
.generalizeIf(generalize, anchorOpName)
154154
.interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
155155
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
156-
.setLinalgVectorLoweringOptions(
157-
LinalgVectorLoweringOptions()
158-
.setVectorTransformsOptions(
159-
vector::VectorTransformsOptions()
160-
.setVectorTransformsOptions(vectorContractLowering)
161-
.setVectorTransferSplit(vectorTransferSplit))
162-
.setVectorTransferToSCFOptions(
163-
VectorTransferToSCFOptions().enableFullUnroll(
164-
unrollVectorTransfers))
165-
.enableTransferPartialRewrite()
166-
.enableContractionLowering()
167-
.enableTransferToSCFConversion());
156+
.setEnableVectorTransferPartialRewrite(true)
157+
.setEnableVectorContractLowering(true)
158+
.setEnableVectorToSCFConversion(true)
159+
.setVectorTransformsOptions(
160+
vector::VectorTransformsOptions()
161+
.setVectorTransformsOptions(vectorContractLowering)
162+
.setVectorTransferSplit(vectorTransferSplit))
163+
.setVectorTransferToSCFOptions(
164+
VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
168165

169166
// Created a nested OpPassManager and run.
170167
FuncOp funcOp = getFunction();

0 commit comments

Comments
 (0)