Skip to content

Commit d4570ea

Browse files
authored
[mlir][tosa] Disallow invalid datatype combinations in the validation pass (#131595)
This commit checks if the operands/results of an operator can be found in the profile compliance mapping, if it isn't the operator is considered invalid. As a result, operator datatype combinations that are not listed under the "Supported Data Types" of the TOSA specification are disallowed and the validation pass results in failure. Signed-off-by: Luke Hutton <[email protected]>
1 parent f8e908a commit d4570ea

File tree

11 files changed

+97
-40
lines changed

11 files changed

+97
-40
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
4040
// Note: Default to 'none' level unless otherwise specified.
4141
std::optional<tosa::TosaValidationOptions> validationOptions =
4242
tosa::TosaValidationOptions{
43-
{"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
43+
{"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
4444

4545
/// Populates TOSA to linalg pipelines
4646
/// Currently, this includes only the "tosa-to-linalg-pipeline".

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
115115
// environment.
116116
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
117117
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
118+
LogicalResult checkInvalid(Operation *op);
118119

119120
template <typename T>
120121
LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
163164
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
164165

165166
private:
167+
template <typename T>
168+
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
169+
CheckCondition &condition);
170+
166171
OperationProfileComplianceMap profileComplianceMap;
167172
OperationExtensionComplianceMap extensionComplianceMap;
168173
};

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
9494
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
9595
/*default=*/"false",
9696
"Verify if the properties of certain operations align the spec requirement">,
97+
Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
98+
/*default=*/"false",
99+
"Disable checks for operations that are determined to be invalid due to their "
100+
"operand/result datatypes not aligning with the 'Supported Data Types' "
101+
"sections of the specifciation">,
97102
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
98103
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
99104
"Validate if operator parameters are within specfication for the given level",

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
119119
validationOptions.profile = {"none"};
120120
validationOptions.extension = {"none"};
121121
validationOptions.strictOpSpecAlignment = false;
122+
validationOptions.allowInvalidOpDatatypeCombinations = false;
122123
validationOptions.level = tosa::TosaLevelEnum::EightK;
123124
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
124125
tosaToLinalgNamedOptions,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,15 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
140140
template <>
141141
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142142
addValue(op.getValues());
143+
addValue(op.getIndices());
143144
addValue(op.getOutput());
144145
return success();
145146
}
146147

147148
template <>
148149
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
149150
addValue(op.getValuesIn());
151+
addValue(op.getIndices());
150152
addValue(op.getInput());
151153
addValue(op.getValuesOut());
152154
return success();
@@ -347,6 +349,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
347349
// Tosa Profile And Extension Compliance Checker
348350
//===----------------------------------------------------------------------===//
349351

352+
template <typename T>
353+
FailureOr<SmallVector<T>>
354+
TosaProfileCompliance::getOperatorDefinition(Operation *op,
355+
CheckCondition &condition) {
356+
const std::string opName = op->getName().getStringRef().str();
357+
const auto complianceMap = getProfileComplianceMap<T>();
358+
const auto it = complianceMap.find(opName);
359+
if (it == complianceMap.end())
360+
return {};
361+
362+
return findMatchedProfile<T>(op, it->second, condition);
363+
}
364+
350365
template <typename T>
351366
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
352367
Operation *op, const tosa::TargetEnv &targetEnv,
@@ -356,11 +371,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
356371
if (specRequiredModeSet.size() == 0)
357372
return success();
358373

359-
auto opName = op->getName().getStringRef().str();
360-
auto compMap = getProfileComplianceMap<T>();
361-
auto it = compMap.find(opName);
362-
363-
if (it == compMap.end()) {
374+
CheckCondition condition = CheckCondition::invalid;
375+
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
376+
if (failed(maybeOpRequiredMode)) {
364377
// Operators such as control-flow and shape ops do not have an operand type
365378
// restriction. When the profile compliance information of operation is not
366379
// found, confirm if the target have enabled the profile required from the
@@ -381,12 +394,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
381394
return failure();
382395
}
383396

384-
CheckCondition condition = CheckCondition::invalid;
385-
// Find the profiles or extensions requirement according to the signature of
386-
// type of the operand list.
387-
SmallVector<T> opRequiredMode =
388-
findMatchedProfile<T>(op, it->second, condition);
389-
397+
// Find the required profiles or extensions according to the operand type
398+
// combination.
399+
const auto opRequiredMode = maybeOpRequiredMode.value();
390400
if (opRequiredMode.size() == 0) {
391401
// No matched restriction found.
392402
return success();
@@ -466,6 +476,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
466476
return success();
467477
}
468478

479+
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
480+
CheckCondition condition = CheckCondition::invalid;
481+
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
482+
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
483+
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
484+
!maybeProfDef.value().size() && !maybeExtDef.value().size())
485+
return failure();
486+
487+
return success();
488+
}
489+
469490
// Find the profiles or extensions requirement according to the signature of
470491
// type of the operand list.
471492
template <typename T>
@@ -483,7 +504,6 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
483504

484505
for (size_t i = 0; i < compInfo.size(); i++) {
485506
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
486-
487507
for (SmallVector<TypeInfo> expected : sets) {
488508
assert(present.size() == expected.size() &&
489509
"the entries for profile-based compliance do not match between "

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
165165
this->profile = options.profile;
166166
this->extension = options.extension;
167167
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
168+
this->allowInvalidOpDatatypeCombinations =
169+
options.allowInvalidOpDatatypeCombinations;
168170
this->level = options.level;
169171
}
170172
void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
10421044
}
10431045
}
10441046

1047+
if (!allowInvalidOpDatatypeCombinations &&
1048+
failed(profileComp.checkInvalid(op))) {
1049+
op->emitOpError("illegal: operand/result data types not supported");
1050+
return signalPassFailure();
1051+
}
1052+
10451053
// Some uses of TOSA rely on the constant operands of particular
10461054
// operations.
10471055
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44
// -----
55

66
// check that -tosa-validate of stateful ops kick in
7-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
7+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
8+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
99
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
10-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
10+
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
1111
return
1212
}
1313

1414
// -----
1515

1616
// check that -tosa-validate level checking kick in
17-
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
17+
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
1818
// expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
19-
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
20-
return %0 : tensor<*xi8>
19+
%0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
20+
return %0 : tensor<*xi32>
2121
}
2222

2323
// -----

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
// Check operations when the dynamic extension is enabled.
33
//--------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
66

77
// -----
88

9-
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
10-
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
11-
return %0 : tensor<13x21x3xi8>
9+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
10+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
11+
return %0 : tensor<13x21x3xi16>
1212
}
1313

1414
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -616,26 +616,26 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
616616

617617
// -----
618618

619-
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
620-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
619+
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
620+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
621621
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
622-
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32>
622+
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
623623
return
624624
}
625625

626626
// -----
627627

628-
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
629-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
628+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
629+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
630630
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
631631
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
632632
return
633633
}
634634

635635
// -----
636636

637-
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
638-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
637+
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
638+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
639639
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
640640
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
641641
return
@@ -644,18 +644,18 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
644644
// -----
645645

646646
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
647-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
647+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
648648
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
649649
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
650650
return
651651
}
652652

653653
// -----
654654

655-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
656-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
655+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
656+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
657657
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
658-
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
658+
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
659659
return
660660
}
661661

@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
19211921
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
19221922
return %1 : tensor<f32>
19231923
}
1924+
1925+
// -----
1926+
1927+
// CHECK-LABEL: test_add_i1
1928+
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
1929+
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
1930+
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
1931+
return %0 : tensor<13x21x3xi1>
1932+
}
1933+
1934+
// -----
1935+
1936+
// CHECK-LABEL: test_mul_out_i16
1937+
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
1938+
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
1939+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
1940+
return %0 : tensor<13x21x3xi16>
1941+
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
191191

192192
// -----
193193

194-
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
194+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
195195
// expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
196-
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
197-
return %0 : tensor<13x21x3xi8>
196+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
197+
return %0 : tensor<13x21x3xi32>
198198
}
199199

200200
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
169169

170170
// -----
171171

172-
func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
172+
func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
173173
// expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
174-
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
175-
return %0 : tensor<1x1x1x1x1x1x64xi16>
174+
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
175+
return %0 : tensor<1x1x1x1x1x1x64xi32>
176176
}
177177

178178
// -----

0 commit comments

Comments
 (0)