Skip to content

Commit 32b7c1f

Browse files
committed
[mlir][TOSA] Set default TOSA validation level to 'None' for TOSA -> linalg
Unless otherwise specified this pass should not assume a level, as this rejects otherwise valid TOSA. This has caused build failures in IREE. The level (and other validation options) have now been made configurable. The pass options have been converted to enums to make them more type safe in C++. Reviewed By: Tai78641 Differential Revision: https://reviews.llvm.org/D157282
1 parent e7191fb commit 32b7c1f

File tree

5 files changed

+69
-21
lines changed

5 files changed

+69
-21
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
1515
#define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
1616

17+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
1718
#include "mlir/Pass/Pass.h"
1819

1920
namespace mlir {
@@ -31,8 +32,11 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
3132
/// the pass, the function will only contain linalg ops or standard ops if the
3233
/// pipeline succeeds. The option to disable decompositions is available for
3334
/// benchmarking performance improvements from the canonicalizations.
34-
void addTosaToLinalgPasses(OpPassManager &pm,
35-
bool disableTosaDecompositions = false);
35+
void addTosaToLinalgPasses(
36+
OpPassManager &pm, bool disableTosaDecompositions = false,
37+
// Note: Default to 'none' level unless otherwise specified.
38+
tosa::ValidationOptions const &validationOptions =
39+
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
3640

3741
/// Populates conversion passes from TOSA dialect to Linalg dialect.
3842
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,31 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
4040
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
4141
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
4242
std::unique_ptr<Pass> createTosaOptionalDecompositions();
43-
std::unique_ptr<Pass> createTosaValidationPass();
43+
44+
struct ValidationOptions {
45+
/// Validate if operations match for the given profile.
46+
TosaProfileEnum profile = TosaProfileEnum::Undefined;
47+
ValidationOptions &setProfile(TosaProfileEnum profile) {
48+
this->profile = profile;
49+
return *this;
50+
}
51+
/// Verify if the properties of certain operations align the spec requirement.
52+
bool strictOperationSpecAlignment = false;
53+
ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) {
54+
strictOperationSpecAlignment = enable;
55+
return *this;
56+
}
57+
/// Validate if operator parameters are within specfication for the given
58+
/// level.
59+
TosaLevelEnum level = TosaLevelEnum::EightK;
60+
ValidationOptions &setLevel(TosaLevelEnum level) {
61+
this->level = level;
62+
return *this;
63+
}
64+
};
65+
66+
std::unique_ptr<Pass> createTosaValidationPass(
67+
ValidationOptions const &options = ValidationOptions());
4468

4569
#define GEN_PASS_REGISTRATION
4670
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,32 @@ def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
9191
let constructor = "createTosaValidationPass()";
9292

9393
let options = [
94-
Option<"profileName", "profile", "std::string",
95-
/*default=*/"\"undefined\"",
96-
"Validate if operations match for the given profile">,
94+
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
95+
/*default=*/"mlir::tosa::TosaProfileEnum::Undefined",
96+
"Validate if operations match for the given profile",
97+
[{::llvm::cl::values(
98+
clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi",
99+
"Use Base Inference profile."),
100+
clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi",
101+
"Use Main Inference profile."),
102+
clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt",
103+
"Use Main Training profile."),
104+
clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "undefined",
105+
"Do not define a profile.")
106+
)}]>,
97107
Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
98108
/*default=*/"false",
99109
"Verify if the properties of certain operations align the spec requirement">,
100-
Option<"levelName", "level", "std::string",
101-
/*default=*/"\"8k\"",
102-
"Validate if operator parameters are within specfication for the given level">,
110+
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
111+
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
112+
"Validate if operator parameters are within specfication for the given level",
113+
[{::llvm::cl::values(
114+
clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
115+
"Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
116+
clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
117+
"Allows the full range of arguments specified by the operations according "
118+
"to the operation data types.")
119+
)}]>
103120
];
104121
}
105122

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
7474
return std::make_unique<TosaToLinalg>();
7575
}
7676

77-
void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
78-
bool disableTosaDecompositions) {
77+
void mlir::tosa::addTosaToLinalgPasses(
78+
OpPassManager &pm, bool disableTosaDecompositions,
79+
tosa::ValidationOptions const &validationOptions) {
7980
// Optional decompositions are designed to benefit linalg.
8081
if (!disableTosaDecompositions)
8182
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -88,6 +89,7 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
8889
// TODO: Remove pass that operates on const tensor and enable optionality
8990
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
9091
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
91-
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
92+
pm.addNestedPass<func::FuncOp>(
93+
tosa::createTosaValidationPass(validationOptions));
9294
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
9395
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
9696
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
9797
public:
9898
explicit TosaValidation() { populateConstantOperandChecks(); }
99+
explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
100+
this->profile = options.profile;
101+
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
102+
this->level = options.level;
103+
}
99104
void runOnOperation() override;
100105

101106
LogicalResult applyConstantOperandCheck(Operation *op) {
@@ -387,18 +392,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
387392
// configure profile and level values from pass options profileName and
388393
// levelName
389394
void configLevelAndProfile() {
390-
profileType = symbolizeEnum<TosaProfileEnum>(profileName);
391-
392-
auto levelType = symbolizeEnum<TosaLevelEnum>(levelName);
393-
394395
tosa_level = TOSA_LEVEL_NONE;
395-
if (levelType == TosaLevelEnum::EightK) {
396+
if (level == TosaLevelEnum::EightK) {
396397
tosa_level = TOSA_LEVEL_EIGHTK;
397398
}
398399
}
399400

400401
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
401-
std::optional<TosaProfileEnum> profileType;
402402
tosa_level_t tosa_level;
403403
};
404404

@@ -431,7 +431,7 @@ void TosaValidation::runOnOperation() {
431431
configLevelAndProfile();
432432
getOperation().walk([&](Operation *op) {
433433
for (Value operand : op->getOperands()) {
434-
if ((profileType == TosaProfileEnum::BaseInference) &&
434+
if ((profile == TosaProfileEnum::BaseInference) &&
435435
isa<FloatType>(getElementTypeOrSelf(operand))) {
436436
return signalPassFailure();
437437
}
@@ -451,6 +451,7 @@ void TosaValidation::runOnOperation() {
451451
}
452452
} // namespace
453453

454-
std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
455-
return std::make_unique<TosaValidation>();
454+
std::unique_ptr<Pass>
455+
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
456+
return std::make_unique<TosaValidation>(options);
456457
}

0 commit comments

Comments
 (0)