@@ -41,17 +41,91 @@ using namespace mlir::tosa;
41
41
42
42
namespace {
43
43
44
- static LogicalResult checkConstantOperandPad (Operation *op) {
44
+ static LogicalResult
45
+ checkConstantOperands (Operation *op, ArrayRef<unsigned int > operandIndices) {
46
+ for (const auto index : operandIndices) {
47
+ Attribute attr;
48
+ if (!matchPattern (op->getOperand (index), m_Constant (&attr))) {
49
+ return op->emitOpError (" expected compile time resolvable constant, but "
50
+ " got variable value for operand #" )
51
+ << index;
52
+ }
53
+ }
54
+ return success ();
55
+ }
56
+
57
+ static LogicalResult checkConstantOperandMul (Operation *op,
58
+ const TargetEnv &env) {
59
+ if (!env.allows (Extension::dynamic) && isa<tosa::MulOp>(op)) {
60
+ // Check 'shift'
61
+ return checkConstantOperands (op, {2 });
62
+ }
63
+ return success ();
64
+ }
65
+
66
+ static LogicalResult checkConstantOperandTable (Operation *op,
67
+ const TargetEnv &env) {
68
+ if (!env.allows (Extension::dynamic) && isa<tosa::TableOp>(op)) {
69
+ // Check 'table'
70
+ return checkConstantOperands (op, {1 });
71
+ }
72
+ return success ();
73
+ }
74
+
75
+ static LogicalResult checkConstantOperandPad (Operation *op,
76
+ const TargetEnv &env) {
45
77
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
46
- DenseElementsAttr paddings;
47
- if (!matchPattern (padOp.getPadding (), m_Constant (&paddings)))
48
- return op->emitOpError (" padding of pad is not constant" );
78
+ // Assume this op is zero-padding if padConst is not presented
79
+ if (!env.allows (Extension::dynamic) && padOp.getPadConst ())
80
+ // Check 'pad_const'
81
+ // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82
+ return checkConstantOperands (op, {2 });
83
+ }
84
+ return success ();
85
+ }
86
+
87
+ static LogicalResult checkConstantOperandRescale (Operation *op,
88
+ const TargetEnv &env) {
89
+ if (!env.allows (Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90
+ // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91
+ return checkConstantOperands (op, {1 , 2 , 3 , 4 });
92
+ }
93
+ return success ();
94
+ }
95
+
96
+ template <typename T>
97
+ static LogicalResult checkConstantOperandConvOps (Operation *op,
98
+ const TargetEnv &env) {
99
+ if (!env.allows (Extension::dynamic) && isa<T>(op)) {
100
+ // Check 'input_zp' and 'weight_zp'
101
+ return checkConstantOperands (op, {3 , 4 });
102
+ }
103
+ return success ();
104
+ }
105
+
106
+ static LogicalResult checkConstantOperandMatMul (Operation *op,
107
+ const TargetEnv &env) {
108
+ if (!env.allows (Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109
+ // Check 'A_zp' and 'B_zp'
110
+ return checkConstantOperands (op, {2 , 3 });
111
+ }
112
+ return success ();
113
+ }
114
+
115
+ static LogicalResult checkConstantOperandAvgPool2d (Operation *op,
116
+ const TargetEnv &env) {
117
+ if (!env.allows (Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118
+ // Check 'input_zp' and 'output_zp'
119
+ return checkConstantOperands (op, {1 , 2 });
120
+ }
121
+ return success ();
122
+ }
49
123
50
- DenseElementsAttr padConst;
51
- // Assume this op is zero-padding if padConst is not presented.
52
- if (padOp. getPadConst ( ) &&
53
- ! matchPattern (padOp. getPadConst (), m_Constant (&padConst)))
54
- return op-> emitOpError ( " pad_const of pad is not constant " );
124
+ static LogicalResult checkConstantOperandNegate (Operation *op,
125
+ const TargetEnv &env) {
126
+ if (!env. allows (Extension::dynamic ) && isa<tosa::NegateOp>(op)) {
127
+ // Check 'input1_zp' and 'output_zp'
128
+ return checkConstantOperands (op, { 1 , 2 } );
55
129
}
56
130
return success ();
57
131
}
@@ -97,7 +171,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
97
171
98
172
LogicalResult applyConstantOperandCheck (Operation *op) {
99
173
for (auto &checker : constCheckers) {
100
- if (failed (checker (op)))
174
+ if (failed (checker (op, targetEnv )))
101
175
return failure ();
102
176
}
103
177
return success ();
@@ -114,7 +188,19 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
114
188
115
189
private:
116
190
void populateConstantOperandChecks () {
191
+ constCheckers.emplace_back (checkConstantOperandMul);
192
+ constCheckers.emplace_back (checkConstantOperandTable);
117
193
constCheckers.emplace_back (checkConstantOperandPad);
194
+ constCheckers.emplace_back (checkConstantOperandRescale);
195
+ constCheckers.emplace_back (checkConstantOperandConvOps<tosa::Conv2DOp>);
196
+ constCheckers.emplace_back (checkConstantOperandConvOps<tosa::Conv3DOp>);
197
+ constCheckers.emplace_back (
198
+ checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
199
+ constCheckers.emplace_back (
200
+ checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
201
+ constCheckers.emplace_back (checkConstantOperandMatMul);
202
+ constCheckers.emplace_back (checkConstantOperandAvgPool2d);
203
+ constCheckers.emplace_back (checkConstantOperandNegate);
118
204
}
119
205
120
206
bool levelCheckKernel (Operation *op, int32_t v, const StringRef checkDesc) {
@@ -436,7 +522,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
436
522
llvm::errs () << " unknown TOSA extension name passed in: " << ext
437
523
<< " , supported extension are int16, int4, bf16, "
438
524
<< " fp8e4m3, fp8e5m2, fft, variable, controlflow, "
439
- << " doubleround and inexactround \n " ;
525
+ << " doubleround, inexactround and dynamic \n " ;
440
526
return signalPassFailure ();
441
527
}
442
528
}
@@ -447,7 +533,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
447
533
bool CheckVariableReadOrWrite (Operation *op);
448
534
bool isValidElementType (Type type);
449
535
450
- SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
536
+ SmallVector<
537
+ std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
538
+ constCheckers;
451
539
TosaLevel tosaLevel;
452
540
DenseMap<StringAttr, mlir::Type> variablesMap;
453
541
TosaProfileCompliance profileComp;
0 commit comments