Skip to content

Commit 955c02d

Browse files
authored
[mlir][tosa] Check for compile time constants in the validation pass (#131123)
This commit adds a concept of the 'dynamic' extension in the Dialect and checks that compile time constant (CTC) operands for each operator are constant if the dynamic extension is not loaded. Operands labeled as CTC in the specification that are of tosa.shape (shape_t in the specification) type are not checked as they are always expected to be constant. This requirement is checked elsewhere in the dialect. Signed-off-by: Luke Hutton <[email protected]>
1 parent 7a7c33d commit 955c02d

File tree

6 files changed

+322
-16
lines changed

6 files changed

+322
-16
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
228228
// CONTROLFLOW : Control Flow operations.
229229
// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
230230
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
231+
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
231232
//===----------------------------------------------------------------------===//
232233

233234
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -245,12 +246,13 @@ def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
245246
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
246247
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
247248
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
249+
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
248250

249251
def Tosa_ExtensionAttr
250252
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
251253
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
252254
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
253-
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
255+
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC, Tosa_EXT_NONE
254256
]>;
255257

256258
def Tosa_ExtensionArrayAttr

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class TosaProfileCompliance {
146146
return {Profile::pro_fp};
147147
case Extension::variable:
148148
case Extension::controlflow:
149+
case Extension::dynamic:
149150
return {Profile::pro_fp, Profile::pro_int};
150151
case Extension::none:
151152
return {};

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,91 @@ using namespace mlir::tosa;
4141

4242
namespace {
4343

44-
static LogicalResult checkConstantOperandPad(Operation *op) {
44+
static LogicalResult
45+
checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46+
for (const auto index : operandIndices) {
47+
Attribute attr;
48+
if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
49+
return op->emitOpError("expected compile time resolvable constant, but "
50+
"got variable value for operand #")
51+
<< index;
52+
}
53+
}
54+
return success();
55+
}
56+
57+
static LogicalResult checkConstantOperandMul(Operation *op,
58+
const TargetEnv &env) {
59+
if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60+
// Check 'shift'
61+
return checkConstantOperands(op, {2});
62+
}
63+
return success();
64+
}
65+
66+
static LogicalResult checkConstantOperandTable(Operation *op,
67+
const TargetEnv &env) {
68+
if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69+
// Check 'table'
70+
return checkConstantOperands(op, {1});
71+
}
72+
return success();
73+
}
74+
75+
static LogicalResult checkConstantOperandPad(Operation *op,
76+
const TargetEnv &env) {
4577
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
46-
DenseElementsAttr paddings;
47-
if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
48-
return op->emitOpError("padding of pad is not constant");
78+
// Assume this op is zero-padding if padConst is not presented
79+
if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80+
// Check 'pad_const'
81+
// Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82+
return checkConstantOperands(op, {2});
83+
}
84+
return success();
85+
}
86+
87+
static LogicalResult checkConstantOperandRescale(Operation *op,
88+
const TargetEnv &env) {
89+
if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90+
// Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91+
return checkConstantOperands(op, {1, 2, 3, 4});
92+
}
93+
return success();
94+
}
95+
96+
template <typename T>
97+
static LogicalResult checkConstantOperandConvOps(Operation *op,
98+
const TargetEnv &env) {
99+
if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100+
// Check 'input_zp' and 'weight_zp'
101+
return checkConstantOperands(op, {3, 4});
102+
}
103+
return success();
104+
}
105+
106+
static LogicalResult checkConstantOperandMatMul(Operation *op,
107+
const TargetEnv &env) {
108+
if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109+
// Check 'A_zp' and 'B_zp'
110+
return checkConstantOperands(op, {2, 3});
111+
}
112+
return success();
113+
}
114+
115+
static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116+
const TargetEnv &env) {
117+
if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118+
// Check 'input_zp' and 'output_zp'
119+
return checkConstantOperands(op, {1, 2});
120+
}
121+
return success();
122+
}
49123

50-
DenseElementsAttr padConst;
51-
// Assume this op is zero-padding if padConst is not presented.
52-
if (padOp.getPadConst() &&
53-
!matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
54-
return op->emitOpError("pad_const of pad is not constant");
124+
static LogicalResult checkConstantOperandNegate(Operation *op,
125+
const TargetEnv &env) {
126+
if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127+
// Check 'input1_zp' and 'output_zp'
128+
return checkConstantOperands(op, {1, 2});
55129
}
56130
return success();
57131
}
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
97171

98172
LogicalResult applyConstantOperandCheck(Operation *op) {
99173
for (auto &checker : constCheckers) {
100-
if (failed(checker(op)))
174+
if (failed(checker(op, targetEnv)))
101175
return failure();
102176
}
103177
return success();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
114188

115189
private:
116190
void populateConstantOperandChecks() {
191+
constCheckers.emplace_back(checkConstantOperandMul);
192+
constCheckers.emplace_back(checkConstantOperandTable);
117193
constCheckers.emplace_back(checkConstantOperandPad);
194+
constCheckers.emplace_back(checkConstantOperandRescale);
195+
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
196+
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
197+
constCheckers.emplace_back(
198+
checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
199+
constCheckers.emplace_back(
200+
checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
201+
constCheckers.emplace_back(checkConstantOperandMatMul);
202+
constCheckers.emplace_back(checkConstantOperandAvgPool2d);
203+
constCheckers.emplace_back(checkConstantOperandNegate);
118204
}
119205

120206
bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
436522
llvm::errs() << "unknown TOSA extension name passed in: " << ext
437523
<< ", supported extension are int16, int4, bf16, "
438524
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
439-
<< "doubleround and inexactround\n";
525+
<< "doubleround, inexactround and dynamic\n";
440526
return signalPassFailure();
441527
}
442528
}
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
447533
bool CheckVariableReadOrWrite(Operation *op);
448534
bool isValidElementType(Type type);
449535

450-
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
536+
SmallVector<
537+
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
538+
constCheckers;
451539
TosaLevel tosaLevel;
452540
DenseMap<StringAttr, mlir::Type> variablesMap;
453541
TosaProfileCompliance profileComp;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//--------------------------------------------------------
2+
// Check operations when the dynamic extension is enabled.
3+
//--------------------------------------------------------
4+
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
6+
7+
// -----
8+
9+
func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
10+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
11+
return %0 : tensor<13x21x3xi8>
12+
}
13+
14+
// -----
15+
16+
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
17+
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
18+
return
19+
}
20+
21+
// -----
22+
23+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
24+
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
25+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
26+
return %1 : tensor<13x21x3xi8>
27+
}
28+
29+
// -----
30+
31+
func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
32+
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
33+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
34+
%0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
35+
return %0 : tensor<13x21x3xi32>
36+
}
37+
38+
// -----
39+
40+
func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
41+
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
42+
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
43+
%0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
44+
return %0 : tensor<13x21x3xi32>
45+
}
46+
47+
// -----
48+
49+
func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
50+
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
51+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
52+
%output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
53+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
54+
return %0 : tensor<13x21x3xi32>
55+
}
56+
57+
// -----
58+
59+
func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_zp: tensor<1xi32>) -> tensor<13x21x3xi32> {
60+
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
61+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
62+
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
63+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
64+
return %0 : tensor<13x21x3xi32>
65+
}
66+
67+
// -----
68+
69+
func.func @test_matmul_non_const_zps(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
70+
%0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
71+
return %0 : tensor<1x14x28xf32>
72+
}
73+
74+
// -----
75+
76+
func.func @test_negate_non_const_zps(%arg0: tensor<1xf32>, %input1_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1xf32> {
77+
%0 = tosa.negate %arg0, %input1_zp, %output_zp {} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
78+
return %0 : tensor<1xf32>
79+
}
80+
81+
// -----
82+
83+
func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
84+
%0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32} :
85+
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
86+
return %0 : tensor<1x32x32x8xf32>
87+
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
242242

243243
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
244244
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
245-
// expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
245+
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
246246
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
247247
return %1 : tensor<13x21x3xi8>
248248
}

0 commit comments

Comments
 (0)