@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
562
562
563
563
bool CheckVariable (Operation *op);
564
564
bool CheckVariableReadOrWrite (Operation *op);
565
- bool isValidElementType (Type type);
565
+ bool isValidElementType (Type type, const bool allowUnsigned = false );
566
566
567
567
SmallVector<
568
568
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1176
1176
return success ();
1177
1177
}
1178
1178
1179
- bool TosaValidation::isValidElementType (Type type) {
1179
+ bool TosaValidation::isValidElementType (Type type, const bool allowUnsigned ) {
1180
1180
if (isa<FloatType>(type)) {
1181
1181
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1182
1182
Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
1191
1191
case 48 :
1192
1192
return true ;
1193
1193
}
1194
+ } else if (allowUnsigned && intTy.isUnsigned ()) {
1195
+ switch (intTy.getWidth ()) {
1196
+ case 8 :
1197
+ case 16 :
1198
+ case 32 :
1199
+ return true ;
1200
+ }
1194
1201
}
1195
1202
} else if (mlir::isa<tosa::shapeType>(type)) {
1196
1203
return true ;
@@ -1209,19 +1216,23 @@ void TosaValidation::runOnOperation() {
1209
1216
if (op->getDialect () != tosaDialect)
1210
1217
return ;
1211
1218
1212
- // perform valid element type check at the beginning to
1213
- // protect rest of code against quantized element types
1219
+ // validate operator element types:
1220
+ // - rescale operator is allowed to have ui8/ui16/ui32
1221
+ // operands/results
1222
+ // - perform valid element type check at the beginning to
1223
+ // protect rest of code against quantized element types
1224
+ const bool opIsRescale = isa<tosa::RescaleOp>(op);
1214
1225
for (Value operand : op->getOperands ()) {
1215
1226
auto elementTy = getElementTypeOrSelf (operand);
1216
- if (!isValidElementType (elementTy)) {
1227
+ if (!isValidElementType (elementTy, opIsRescale )) {
1217
1228
op->emitOpError () << " is not profile-aligned: element type "
1218
1229
<< elementTy << " is not legal" ;
1219
1230
return signalPassFailure ();
1220
1231
}
1221
1232
}
1222
1233
for (Type resultTy : op->getResultTypes ()) {
1223
1234
auto elementTy = getElementTypeOrSelf (resultTy);
1224
- if (!isValidElementType (elementTy)) {
1235
+ if (!isValidElementType (elementTy, opIsRescale )) {
1225
1236
op->emitOpError () << " is not profile-aligned: element type "
1226
1237
<< elementTy << " is not legal" ;
1227
1238
return signalPassFailure ();
0 commit comments