Skip to content

Commit 3bb61ab

Browse files
committed
[mlir][tosa] Disallow invalid datatype combinations in the validation pass
This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure. Signed-off-by: Luke Hutton <[email protected]> Change-Id: Iab36dd84cdbf188015c80b066c321edbb2efc0ff
1 parent 0f2fb2b commit 3bb61ab

File tree

11 files changed

+81
-25
lines changed

11 files changed

+81
-25
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
4040
// Note: Default to 'none' level unless otherwise specified.
4141
std::optional<tosa::TosaValidationOptions> validationOptions =
4242
tosa::TosaValidationOptions{
43-
{"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
43+
{"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
4444

4545
/// Populates TOSA to linalg pipelines
4646
/// Currently, this includes only the "tosa-to-linalg-pipeline".

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
115115
// environment.
116116
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
117117
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
118+
LogicalResult checkInvalid(Operation *op);
118119

119120
template <typename T>
120121
LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
163164
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
164165

165166
private:
167+
template <typename T>
168+
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
169+
CheckCondition &condition);
170+
166171
OperationProfileComplianceMap profileComplianceMap;
167172
OperationExtensionComplianceMap extensionComplianceMap;
168173
};

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
9494
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
9595
/*default=*/"false",
9696
"Verify if the properties of certain operations align the spec requirement">,
97+
Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
98+
/*default=*/"false",
99+
"Disable checks for operations that are determined to be invalid due to their "
100+
"operand/result datatypes not aligning with the 'Supported Data Types' "
101+
"sections of the specifciation">,
97102
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
98103
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
99104
"Validate if operator parameters are within specfication for the given level",

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
119119
validationOptions.profile = {"none"};
120120
validationOptions.extension = {"none"};
121121
validationOptions.strictOpSpecAlignment = false;
122+
validationOptions.allowInvalidOpDatatypeCombinations = false;
122123
validationOptions.level = tosa::TosaLevelEnum::EightK;
123124
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
124125
tosaToLinalgNamedOptions,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
300300
// Tosa Profile And Extension Compliance Checker
301301
//===----------------------------------------------------------------------===//
302302

303+
template <typename T>
304+
FailureOr<SmallVector<T>>
305+
TosaProfileCompliance::getOperatorDefinition(Operation *op,
306+
CheckCondition &condition) {
307+
const std::string opName = op->getName().getStringRef().str();
308+
const auto complianceMap = getProfileComplianceMap<T>();
309+
const auto it = complianceMap.find(opName);
310+
if (it == complianceMap.end())
311+
return {};
312+
313+
return findMatchedProfile<T>(op, it->second, condition);
314+
}
315+
303316
template <typename T>
304317
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
305318
Operation *op, const tosa::TargetEnv &targetEnv,
@@ -309,11 +322,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
309322
if (specRequiredModeSet.size() == 0)
310323
return success();
311324

312-
auto opName = op->getName().getStringRef().str();
313-
auto compMap = getProfileComplianceMap<T>();
314-
auto it = compMap.find(opName);
315-
316-
if (it == compMap.end()) {
325+
CheckCondition condition = CheckCondition::invalid;
326+
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
327+
if (failed(maybeOpRequiredMode)) {
317328
// Operators such as variable and shape ops do not have an operand type
318329
// restriction. When the profile compliance information of operation is not
319330
// found, confirm if the target have enabled the profile required from the
@@ -334,12 +345,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
334345
return failure();
335346
}
336347

337-
CheckCondition condition = CheckCondition::invalid;
338-
// Find the profiles or extensions requirement according to the signature of
339-
// type of the operand list.
340-
SmallVector<T> opRequiredMode =
341-
findMatchedProfile<T>(op, it->second, condition);
342-
348+
// Find the required profiles or extensions according to the operand type
349+
// combination.
350+
const auto opRequiredMode = maybeOpRequiredMode.value();
343351
if (opRequiredMode.size() == 0) {
344352
// No matched restriction found.
345353
return success();
@@ -419,6 +427,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
419427
return success();
420428
}
421429

430+
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
431+
CheckCondition condition = CheckCondition::invalid;
432+
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
433+
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
434+
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
435+
!maybeProfDef.value().size() && !maybeExtDef.value().size())
436+
return failure();
437+
438+
return success();
439+
}
440+
422441
// Find the profiles or extensions requirement according to the signature of
423442
// type of the operand list.
424443
template <typename T>

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
165165
this->profile = options.profile;
166166
this->extension = options.extension;
167167
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
168+
this->allowInvalidOpDatatypeCombinations =
169+
options.allowInvalidOpDatatypeCombinations;
168170
this->level = options.level;
169171
}
170172
void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
10421044
}
10431045
}
10441046

1047+
if (!allowInvalidOpDatatypeCombinations &&
1048+
failed(profileComp.checkInvalid(op))) {
1049+
op->emitOpError("illegal: operand/result data types not supported");
1050+
return signalPassFailure();
1051+
}
1052+
10451053
// Some uses of TOSA rely on the constant operands of particular
10461054
// operations.
10471055
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
1414
// -----
1515

1616
// check that -tosa-validate level checking kick in
17-
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
17+
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
1818
// expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
19-
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
20-
return %0 : tensor<*xi8>
19+
%0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
20+
return %0 : tensor<*xi32>
2121
}
2222

2323
// -----

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
// Check operations when the dynamic extension is enabled.
33
//--------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
66

77
// -----
88

9-
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
10-
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
11-
return %0 : tensor<13x21x3xi8>
9+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
10+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
11+
return %0 : tensor<13x21x3xi16>
1212
}
1313

1414
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
19211921
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
19221922
return %1 : tensor<f32>
19231923
}
1924+
1925+
// -----
1926+
1927+
// CHECK-LABEL: test_add_i1
1928+
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
1929+
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
1930+
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
1931+
return %0 : tensor<13x21x3xi1>
1932+
}
1933+
1934+
// -----
1935+
1936+
// CHECK-LABEL: test_mul_out_i16
1937+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
1938+
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
1939+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
1940+
return %0 : tensor<13x21x3xi16>
1941+
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
191191

192192
// -----
193193

194-
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
194+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
195195
// expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
196-
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
197-
return %0 : tensor<13x21x3xi8>
196+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
197+
return %0 : tensor<13x21x3xi32>
198198
}
199199

200200
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
169169

170170
// -----
171171

172-
func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
172+
func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
173173
// expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
174-
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
175-
return %0 : tensor<1x1x1x1x1x1x64xi16>
174+
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
175+
return %0 : tensor<1x1x1x1x1x1x64xi32>
176176
}
177177

178178
// -----

0 commit comments

Comments
 (0)