Skip to content

[mlir][tosa] Disallow invalid datatype combinations in the validation pass #131595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
tosa::TosaValidationOptions{
{"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
{"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});

/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class TosaProfileCompliance {
// environment.
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
LogicalResult checkInvalid(Operation *op);

template <typename T>
LogicalResult checkProfileOrExtension(
Expand Down Expand Up @@ -163,6 +164,10 @@ class TosaProfileCompliance {
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);

private:
template <typename T>
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
CheckCondition &condition);

OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
};
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
"sections of the specifciation">,
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
"Validate if operator parameters are within specfication for the given level",
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
validationOptions.profile = {"none"};
validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
Expand Down
44 changes: 32 additions & 12 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,15 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
addValue(op.getValues());
addValue(op.getIndices());
addValue(op.getOutput());
return success();
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
addValue(op.getValuesIn());
addValue(op.getIndices());
addValue(op.getInput());
addValue(op.getValuesOut());
return success();
Expand Down Expand Up @@ -347,6 +349,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//

template <typename T>
FailureOr<SmallVector<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op,
CheckCondition &condition) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};

return findMatchedProfile<T>(op, it->second, condition);
}

template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
Operation *op, const tosa::TargetEnv &targetEnv,
Expand All @@ -356,11 +371,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();

auto opName = op->getName().getStringRef().str();
auto compMap = getProfileComplianceMap<T>();
auto it = compMap.find(opName);

if (it == compMap.end()) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
if (failed(maybeOpRequiredMode)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
Expand All @@ -381,12 +394,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}

CheckCondition condition = CheckCondition::invalid;
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
SmallVector<T> opRequiredMode =
findMatchedProfile<T>(op, it->second, condition);

// Find the required profiles or extensions according to the operand type
// combination.
const auto opRequiredMode = maybeOpRequiredMode.value();
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
Expand Down Expand Up @@ -466,6 +476,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
return success();
}

LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
!maybeProfDef.value().size() && !maybeExtDef.value().size())
return failure();

return success();
}

// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
Expand All @@ -483,7 +504,6 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(

for (size_t i = 0; i < compInfo.size(); i++) {
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;

for (SmallVector<TypeInfo> expected : sets) {
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
this->profile = options.profile;
this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
this->level = options.level;
}
void runOnOperation() final;
Expand Down Expand Up @@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
}
}

if (!allowInvalidOpDatatypeCombinations &&
failed(profileComp.checkInvalid(op))) {
op->emitOpError("illegal: operand/result data types not supported");
return signalPassFailure();
}

// Some uses of TOSA rely on the constant operands of particular
// operations.
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
// -----

// check that -tosa-validate of stateful ops kick in
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}

// -----

// check that -tosa-validate level checking kick in
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
return %0 : tensor<*xi8>
%0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}

// -----
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Tosa/dynamic_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------

// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"

// -----

func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}

// -----
Expand Down
40 changes: 29 additions & 11 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -616,26 +616,26 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten

// -----

func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32>
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}

// -----

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
return
Expand All @@ -644,18 +644,18 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
// -----

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}

Expand Down Expand Up @@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
return %1 : tensor<f32>
}

// -----

// CHECK-LABEL: test_add_i1
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}

// -----

// CHECK-LABEL: test_mul_out_i16
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<

// -----

func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
// expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso

// -----

func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
// expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
return %0 : tensor<1x1x1x1x1x1x64xi16>
%0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
return %0 : tensor<1x1x1x1x1x1x64xi32>
}

// -----
Expand Down