Skip to content

[mlir][tosa] Check for compile time constants in the validation pass #131123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// CONTROLFLOW : Control Flow operations.
// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
//===----------------------------------------------------------------------===//

def Tosa_NONE : I32EnumAttrCase<"none", 0>;
Expand All @@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;

def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
]>;

def Tosa_ExtensionArrayAttr
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class TosaProfileCompliance {
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
case Extension::dynamic:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
Expand Down
112 changes: 100 additions & 12 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,91 @@ using namespace mlir::tosa;

namespace {

static LogicalResult checkConstantOperandPad(Operation *op) {
static LogicalResult
checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
for (const auto index : operandIndices) {
Attribute attr;
if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
return op->emitOpError("expected compile time resolvable constant, but "
"got variable value for operand #")
<< index;
}
}
return success();
}

static LogicalResult checkConstantOperandMul(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
// Check 'shift'
return checkConstantOperands(op, {2});
}
return success();
}

static LogicalResult checkConstantOperandTable(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
// Check 'table'
return checkConstantOperands(op, {1});
}
return success();
}

static LogicalResult checkConstantOperandPad(Operation *op,
const TargetEnv &env) {
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
DenseElementsAttr paddings;
if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
return op->emitOpError("padding of pad is not constant");
// Assume this op is zero-padding if padConst is not presented
if (!env.allows(Extension::dynamic) && padOp.getPadConst())
// Check 'pad_const'
// Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
return checkConstantOperands(op, {2});
}
return success();
}

static LogicalResult checkConstantOperandRescale(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
// Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2, 3, 4});
}
return success();
}

template <typename T>
static LogicalResult checkConstantOperandConvOps(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<T>(op)) {
// Check 'input_zp' and 'weight_zp'
return checkConstantOperands(op, {3, 4});
}
return success();
}

static LogicalResult checkConstantOperandMatMul(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
// Check 'A_zp' and 'B_zp'
return checkConstantOperands(op, {2, 3});
}
return success();
}

static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
// Check 'input_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2});
}
return success();
}

DenseElementsAttr padConst;
// Assume this op is zero-padding if padConst is not presented.
if (padOp.getPadConst() &&
!matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
return op->emitOpError("pad_const of pad is not constant");
static LogicalResult checkConstantOperandNegate(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
// Check 'input1_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2});
}
return success();
}
Expand Down Expand Up @@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : constCheckers) {
if (failed(checker(op)))
if (failed(checker(op, targetEnv)))
return failure();
}
return success();
Expand All @@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

private:
void populateConstantOperandChecks() {
constCheckers.emplace_back(checkConstantOperandMul);
constCheckers.emplace_back(checkConstantOperandTable);
constCheckers.emplace_back(checkConstantOperandPad);
constCheckers.emplace_back(checkConstantOperandRescale);
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
constCheckers.emplace_back(
checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
constCheckers.emplace_back(
checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
constCheckers.emplace_back(checkConstantOperandMatMul);
constCheckers.emplace_back(checkConstantOperandAvgPool2d);
constCheckers.emplace_back(checkConstantOperandNegate);
}

bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
Expand Down Expand Up @@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
<< "doubleround and inexactround\n";
<< "doubleround, inexactround and dynamic\n";
return signalPassFailure();
}
}
Expand All @@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type);

SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
Expand Down
87 changes: 87 additions & 0 deletions mlir/test/Dialect/Tosa/dynamic_extension.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//--------------------------------------------------------
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------

// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"

// -----

func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}

// -----

func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
return
}

// -----

func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}

// -----

func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}

// -----

func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}

// -----

func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}

// -----

func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}

// -----

func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}

// -----

func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
%0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}

// -----

func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
%0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)

func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
Expand Down
Loading