Skip to content

[mlir][tosa] Check for compile time constants in the validation pass #131123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2025

Conversation

lhutton1
Copy link
Contributor

This commit adds a concept of the 'dynamic' extension in the Dialect and checks that compile time constant (CTC) operands for each operator are constant if the dynamic extension is not loaded.

Operands labeled as CTC in the specification that are of tosa.shape (shape_t in the specification) type are not checked as they are always expected to be constant. This requirement is checked elsewhere in the dialect.

This commit adds a concept of the 'dynamic' extension in the Dialect
and checks that compile time constant (CTC) operands for each operator
are constant if the dynamic extension is not loaded.

Operands labeled as CTC in the specification that are of tosa.shape
(shape_t in the specification) type are not checked as they are always
expected to be constant. This requirement is checked elsewhere in the
dialect.

Change-Id: I2477e6a6cd1c01b6a66a5d9c4abc38eb064bab43
Signed-off-by: Luke Hutton <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Mar 13, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit adds a concept of the 'dynamic' extension in the Dialect and checks that compile time constant (CTC) operands for each operator are constant if the dynamic extension is not loaded.

Operands labeled as CTC in the specification that are of tosa.shape (shape_t in the specification) type are not checked as they are always expected to be constant. This requirement is checked elsewhere in the dialect.


Patch is 25.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131123.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+100-12)
  • (added) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+87)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+130-2)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 52130d31c248d..4cb2e8006ca57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // CONTROLFLOW  : Control Flow operations.
 // DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
+// DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
 def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
 def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
+def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
       Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
-      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 88f454f63e6f9..69b827fe14dee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -146,6 +146,7 @@ class TosaProfileCompliance {
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
+    case Extension::dynamic:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6b1b651a90cfb..79c13793d7713 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -41,17 +41,91 @@ using namespace mlir::tosa;
 
 namespace {
 
-static LogicalResult checkConstantOperandPad(Operation *op) {
+static LogicalResult
+checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
+  for (const auto index : operandIndices) {
+    Attribute attr;
+    if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
+      return op->emitOpError("expected compile time resolvable constant, but "
+                             "got variable value for operand #")
+             << index;
+    }
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMul(Operation *op,
+                                             const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
+    // Check 'shift'
+    return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandTable(Operation *op,
+                                               const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
+    // Check 'table'
+    return checkConstantOperands(op, {1});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandPad(Operation *op,
+                                             const TargetEnv &env) {
   if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
-    DenseElementsAttr paddings;
-    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
-      return op->emitOpError("padding of pad is not constant");
+    // Assume this op is zero-padding if padConst is not presented
+    if (!env.allows(Extension::dynamic) && padOp.getPadConst())
+      // Check 'pad_const'
+      // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
+      return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandRescale(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
+    // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2, 3, 4});
+  }
+  return success();
+}
+
+template <typename T>
+static LogicalResult checkConstantOperandConvOps(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
+    // Check 'input_zp' and 'weight_zp'
+    return checkConstantOperands(op, {3, 4});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMatMul(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
+    // Check 'A_zp' and 'B_zp'
+    return checkConstantOperands(op, {2, 3});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
+                                                   const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
+    // Check 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
+  }
+  return success();
+}
 
-    DenseElementsAttr padConst;
-    // Assume this op is zero-padding if padConst is not presented.
-    if (padOp.getPadConst() &&
-        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
-      return op->emitOpError("pad_const of pad is not constant");
+static LogicalResult checkConstantOperandNegate(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
+    // Check 'input1_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
   }
   return success();
 }
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
     for (auto &checker : constCheckers) {
-      if (failed(checker(op)))
+      if (failed(checker(op, targetEnv)))
         return failure();
     }
     return success();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
 private:
   void populateConstantOperandChecks() {
+    constCheckers.emplace_back(checkConstantOperandMul);
+    constCheckers.emplace_back(checkConstantOperandTable);
     constCheckers.emplace_back(checkConstantOperandPad);
+    constCheckers.emplace_back(checkConstantOperandRescale);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandMatMul);
+    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
+    constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
   bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
                        << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
-                       << "doubleround and inexactround\n";
+                       << "doubleround, inexactround and dynamic\n";
           return signalPassFailure();
         }
       }
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type);
 
-  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
+  SmallVector<
+      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
+      constCheckers;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
   TosaProfileCompliance profileComp;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
new file mode 100644
index 0000000000000..fd9b3d5f23483
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -0,0 +1,87 @@
+//--------------------------------------------------------
+// Check operations when the dynamic extension is enabled.
+//--------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+
+// -----
+
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+  return
+}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
+  %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..5b591e3c5f45c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
   %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1a8366528b35c..13952716a9611 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -91,7 +91,7 @@ func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
 
@@ -103,6 +103,134 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+  %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op expected compile time resolvable constant, but got variable value for operand #3}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+  return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}}
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (t...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 13, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit adds a concept of the 'dynamic' extension in the Dialect and checks that compile time constant (CTC) operands for each operator are constant if the dynamic extension is not loaded.

Operands labeled as CTC in the specification that are of tosa.shape (shape_t in the specification) type are not checked as they are always expected to be constant. This requirement is checked elsewhere in the dialect.


Patch is 25.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131123.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+100-12)
  • (added) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+87)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+130-2)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 52130d31c248d..4cb2e8006ca57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // CONTROLFLOW  : Control Flow operations.
 // DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
+// DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
 def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
 def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
+def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
       Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
-      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 88f454f63e6f9..69b827fe14dee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -146,6 +146,7 @@ class TosaProfileCompliance {
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
+    case Extension::dynamic:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6b1b651a90cfb..79c13793d7713 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -41,17 +41,91 @@ using namespace mlir::tosa;
 
 namespace {
 
-static LogicalResult checkConstantOperandPad(Operation *op) {
+static LogicalResult
+checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
+  for (const auto index : operandIndices) {
+    Attribute attr;
+    if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
+      return op->emitOpError("expected compile time resolvable constant, but "
+                             "got variable value for operand #")
+             << index;
+    }
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMul(Operation *op,
+                                             const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
+    // Check 'shift'
+    return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandTable(Operation *op,
+                                               const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
+    // Check 'table'
+    return checkConstantOperands(op, {1});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandPad(Operation *op,
+                                             const TargetEnv &env) {
   if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
-    DenseElementsAttr paddings;
-    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
-      return op->emitOpError("padding of pad is not constant");
+    // Assume this op is zero-padding if padConst is not presented
+    if (!env.allows(Extension::dynamic) && padOp.getPadConst())
+      // Check 'pad_const'
+      // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
+      return checkConstantOperands(op, {2});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandRescale(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
+    // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2, 3, 4});
+  }
+  return success();
+}
+
+template <typename T>
+static LogicalResult checkConstantOperandConvOps(Operation *op,
+                                                 const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
+    // Check 'input_zp' and 'weight_zp'
+    return checkConstantOperands(op, {3, 4});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandMatMul(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
+    // Check 'A_zp' and 'B_zp'
+    return checkConstantOperands(op, {2, 3});
+  }
+  return success();
+}
+
+static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
+                                                   const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
+    // Check 'input_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
+  }
+  return success();
+}
 
-    DenseElementsAttr padConst;
-    // Assume this op is zero-padding if padConst is not presented.
-    if (padOp.getPadConst() &&
-        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
-      return op->emitOpError("pad_const of pad is not constant");
+static LogicalResult checkConstantOperandNegate(Operation *op,
+                                                const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
+    // Check 'input1_zp' and 'output_zp'
+    return checkConstantOperands(op, {1, 2});
   }
   return success();
 }
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
     for (auto &checker : constCheckers) {
-      if (failed(checker(op)))
+      if (failed(checker(op, targetEnv)))
         return failure();
     }
     return success();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
 private:
   void populateConstantOperandChecks() {
+    constCheckers.emplace_back(checkConstantOperandMul);
+    constCheckers.emplace_back(checkConstantOperandTable);
     constCheckers.emplace_back(checkConstantOperandPad);
+    constCheckers.emplace_back(checkConstantOperandRescale);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
+    constCheckers.emplace_back(
+        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
+    constCheckers.emplace_back(checkConstantOperandMatMul);
+    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
+    constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
   bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
                        << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
-                       << "doubleround and inexactround\n";
+                       << "doubleround, inexactround and dynamic\n";
           return signalPassFailure();
         }
       }
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type);
 
-  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
+  SmallVector<
+      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
+      constCheckers;
   TosaLevel tosaLevel;
   DenseMap<StringAttr, mlir::Type> variablesMap;
   TosaProfileCompliance profileComp;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
new file mode 100644
index 0000000000000..fd9b3d5f23483
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -0,0 +1,87 @@
+//--------------------------------------------------------
+// Check operations when the dynamic extension is enabled.
+//--------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+
+// -----
+
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+  return
+}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
+  %0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
+  %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
+         (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..5b591e3c5f45c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
   %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
-  // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
   %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1a8366528b35c..13952716a9611 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -91,7 +91,7 @@ func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
 
@@ -103,6 +103,134 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %arg2: tensor<8xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x4x8xi32> {
+  %weight_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op expected compile time resolvable constant, but got variable value for operand #3}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi32>
+  return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+
+func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> {
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}}
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (t...
[truncated]

@CoTinker CoTinker requested a review from FranklandJack March 13, 2025 12:32
Copy link
Member

@Jerry-Ge Jerry-Ge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Jerry-Ge Jerry-Ge merged commit 955c02d into llvm:main Mar 14, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants