Skip to content

Commit 3fb8cb6

Browse files
authored
[mlir][tosa] Add support for EXT-DOUBLEROUND and EXT-INEXACTROUND (#130337)
1 parent 51850db commit 3fb8cb6

25 files changed

+186
-63
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
226226
// FFT : Fast Fourier Transform operations.
227227
// VARIABLE : Stateful variable operations.
228228
// CONTROLFLOW : Control Flow operations.
229+
// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
230+
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
229231
//===----------------------------------------------------------------------===//
230232

231233
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
241243
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
242244
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
243245
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
246+
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
247+
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
244248

245249
def Tosa_ExtensionAttr
246250
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
247251
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
248-
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
252+
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
253+
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
249254
]>;
250255

251256
def Tosa_ExtensionArrayAttr

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2360,7 +2360,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
23602360
I32Attr:$input_zp,
23612361
I32Attr:$output_zp,
23622362
BoolAttr:$scale32,
2363-
BoolAttr:$double_round,
2363+
Tosa_RoundingTypeAttr:$rounding_mode,
23642364
BoolAttr:$per_channel,
23652365
BoolAttr: $input_unsigned,
23662366
BoolAttr: $output_unsigned

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
136136
switch (ext) {
137137
case Extension::int16:
138138
case Extension::int4:
139+
case Extension::doubleround:
140+
case Extension::inexactround:
139141
return {Profile::pro_int};
140142
case Extension::bf16:
141143
case Extension::fp8e4m3:

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
247247
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
248248
"Supported NaN propagation strategies">;
249249

250+
// Rounding mode string attribute used by tosa.rescale and tosa.apply_scale.
// Accepted values: "SINGLE_ROUND", "INEXACT_ROUND", "DOUBLE_ROUND".
251+
def Tosa_RoundingTypeAttr : StringBasedAttr<
252+
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
253+
"::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
254+
"::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
255+
"Supported rounding modes">;
256+
250257
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
251258

252259
// Tensor to buffer types.

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
4444
Tosa_IntLike:$value,
4545
Tosa_IntLike:$multiplier,
4646
Tosa_Int8Like:$shift,
47-
BoolAttr:$double_round
47+
Tosa_RoundingTypeAttr:$rounding_mode
4848
);
4949

5050
let results = (outs

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
6565

6666
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
6767
PatternRewriter &rewriter) const final {
68+
StringRef roundingMode = op.getRoundingMode();
69+
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
70+
return failure();
71+
}
72+
6873
Location loc = op.getLoc();
6974
Value value = op.getValue();
7075
Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
96101
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
97102

98103
// Apply double rounding if necessary.
99-
if (op.getDoubleRound()) {
104+
if (op.getRoundingMode() == "DOUBLE_ROUND") {
100105
int64_t roundInt = 1 << 30;
101106
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
102107
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
125130

126131
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
127132
PatternRewriter &rewriter) const final {
133+
StringRef roundingMode = op.getRoundingMode();
134+
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
135+
return failure();
136+
}
137+
128138
Location loc = op.getLoc();
129139

130140
Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
170180
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
171181

172182
// Conditionally perform our double round.
173-
if (op.getDoubleRound()) {
183+
if (op.getRoundingMode() == "DOUBLE_ROUND") {
174184
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
175185
Value valuePositive = rewriter.create<arith::CmpIOp>(
176186
loc, arith::CmpIPredicate::sge, value32, zero32);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
170170

171171
auto result = rewriter.create<tosa::ApplyScaleOp>(
172172
loc, rewriter.getI32Type(), a, b, shiftConst,
173-
rewriter.getBoolAttr(false));
173+
rewriter.getStringAttr("SINGLE_ROUND"));
174174

175175
if (elementTy.isInteger(32))
176176
return result;
@@ -1385,7 +1385,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13851385
unsigned rank = inputTy.getRank();
13861386

13871387
// Reject unsupported or illegal rounding configurations: report a match failure and stop.
1388-
if (op.getDoubleRound() && !op.getScale32())
1388+
if (op.getRoundingMode() == "INEXACT_ROUND")
1389+
return rewriter.notifyMatchFailure(
1390+
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1391+
"currently supported");
1392+
if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
13891393
return rewriter.notifyMatchFailure(
13901394
op, "tosa.rescale requires scale32 for double_round to be true");
13911395

@@ -1429,9 +1433,13 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14291433

14301434
// Double round only occurs if shift is greater than 31, check that this
14311435
// is ever true.
1436+
14321437
bool doubleRound =
1433-
op.getDoubleRound() &&
1438+
op.getRoundingMode() == "DOUBLE_ROUND" &&
14341439
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1440+
StringAttr roundingMode = doubleRound
1441+
? rewriter.getStringAttr("DOUBLE_ROUND")
1442+
: rewriter.getStringAttr("SINGLE_ROUND");
14351443

14361444
SmallVector<AffineMap> indexingMaps = {
14371445
rewriter.getMultiDimIdentityMap(rank)};
@@ -1527,7 +1535,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15271535

15281536
value = nestedBuilder.create<tosa::ApplyScaleOp>(
15291537
loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1530-
nestedBuilder.getBoolAttr(doubleRound));
1538+
roundingMode);
15311539

15321540
// Move to the new zero-point.
15331541
value =

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,9 +1031,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10311031

10321032
auto scaled =
10331033
rewriter
1034-
.create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
1035-
poolVal, multiplier, shift,
1036-
rewriter.getBoolAttr(false))
1034+
.create<tosa::ApplyScaleOp>(
1035+
loc, rewriter.getI32Type(), poolVal, multiplier, shift,
1036+
rewriter.getStringAttr("SINGLE_ROUND"))
10371037
.getResult();
10381038

10391039
// If we have quantization information we need to apply output

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
104104
}
105105

106106
LogicalResult applyLevelCheck(Operation *op);
107+
LogicalResult applyAttributeCheck(Operation *op);
107108

108109
// check variable read/write data types against variable declarations
109110
LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,25 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
386387
return true;
387388
}
388389

390+
// Checks the extension requirements implied by a tosa.rescale op's
// rounding_mode attribute. Returns false, after emitting an op error, when the
// mode is DOUBLE_ROUND or INEXACT_ROUND and targetEnv does not allow the
// matching extension. Returns true otherwise, including for any op that is not
// a tosa.rescale.
bool attributeCheckRescale(Operation *op) {
391+
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
392+
// DOUBLE_ROUND is only accepted when the doubleround extension is allowed.
if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
393+
!targetEnv.allows(Extension::doubleround)) {
394+
op->emitOpError()
395+
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
396+
<< "requires extension [doubleround]";
397+
return false;
398+
// INEXACT_ROUND is only accepted when the inexactround extension is allowed.
} else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
399+
!targetEnv.allows(Extension::inexactround)) {
400+
op->emitOpError()
401+
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
402+
<< "requires extension [inexactround]";
403+
return false;
404+
}
405+
}
406+
// Not a rescale op, or its rounding mode does not need a disallowed extension.
return true;
407+
}
408+
389409
// configure profile and level values from pass options profileName and
390410
// levelName
391411
void configLevelAndProfile() {
@@ -415,7 +435,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
415435
} else {
416436
llvm::errs() << "unknown TOSA extension name passed in: " << ext
417437
<< ", supported extension are int16, int4, bf16, "
418-
<< "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
438+
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
439+
<< "doubleround and inexactround\n";
419440
return signalPassFailure();
420441
}
421442
}
@@ -642,6 +663,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
642663
return success();
643664
}
644665

666+
// Runs the attribute-level validation checks on `op`. At present this is only
// the rescale rounding-mode check: returns failure() when
// attributeCheckRescale(op) reports a violation, success() otherwise.
LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
667+
if (!attributeCheckRescale(op))
668+
return failure();
669+
return success();
670+
}
671+
645672
inline bool CompatibleTypes(const mlir::Type &type,
646673
const mlir::Type &declaredType) {
647674
// for now, simply use type equality comparison
@@ -936,6 +963,10 @@ void TosaValidation::runOnOperation() {
936963
if (failed(applyLevelCheck(op)))
937964
signalPassFailure();
938965

966+
// check additional attribute restrictions
967+
if (failed(applyAttributeCheck(op)))
968+
signalPassFailure();
969+
939970
// do variable type checks
940971
if (failed(applyVariableCheck(op)))
941972
signalPassFailure();
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
2+
3+
// CHECK-LABEL: @apply_scale_unsupported_inexact_round
4+
// Negative test: tosa.apply_scale with rounding_mode = "INEXACT_ROUND" is
// expected to fail legalization under --tosa-to-arith.
func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
5+
// expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
6+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
7+
return %res : i32
8+
}

mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
6767
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
6868
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
6969
// CHECK: return %[[RESULT]]
70-
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
70+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
7171
return %res : i32
7272
}
7373

@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
7777
// SCALE: tosa.apply_scale
7878
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
7979
// CHECK-NOT: "tosa.apply_scale"
80-
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
80+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
8181
return %res : vector<4xi32>
8282
}
8383

@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
115115
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
116116
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
117117
// CHECK: return %[[TRUNC]]
118-
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
118+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
119119
return %res : i32
120120
}
121121

@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
152152
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
153153
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
154154
// CHECK: return %[[TRUNC]]
155-
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
155+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
156156
return %res : i32
157157
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
3636
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
3737
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
3838
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
39-
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
39+
%0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4040
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4141
}
4242

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
423423
// CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
424424
// CHECK: %[[C30:.+]] = arith.constant 30
425425
// CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
426-
// CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
426+
// CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
427427

428428
// Perform the normalization.
429429
// CHECK: %[[CMIN:.+]] = arith.constant -128

0 commit comments

Comments
 (0)