Skip to content

Commit 7208649

Browse files
authored
[TOSA] Use attributes for unsigned rescale (#118075)
Unsigned integer types are uncommon enough in MLIR that there is no operation to cast a scalar from signless to unsigned and vice versa. Currently tosa.rescale uses builtin.unrealized_conversion_cast which does not lower. Instead, this commit introduces optional attributes to indicate unsigned input or output, named similarly to those in the TOSA specification. This is more in line with the rest of MLIR where specific operations rather than values are signed/unsigned.
1 parent bba2507 commit 7208649

File tree

3 files changed

+63
-48
lines changed

3 files changed

+63
-48
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,22 +1869,23 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
18691869
let description = [{
18701870
Rescale quantized values into a new domain. Supported rescalings are:
18711871

1872-
| Mode | Input | Output |
1873-
|------------------------|-------|--------|
1874-
| signed 8 to 8 | int8 | int8 |
1875-
| signed 8 to 16 | int8 | int16 |
1876-
| signed 8 to 32 | int8 | int32 |
1877-
| signed 16 to 8 | int16 | int8 |
1878-
| signed 16 to 16 | int16 | int16 |
1879-
| signed 16 to 32 | int16 | int32 |
1880-
| signed 32 to 8 | int32 | int8 |
1881-
| signed 32 to 16 | int32 | int16 |
1882-
| signed 32 to 32 | int32 | int32 |
1883-
| signed 48 to 8 | int48 | int8 |
1884-
| signed 48 to 16 | int48 | int16 |
1885-
| signed 48 to 32 | int48 | int32 |
1886-
| unsigned 8 to signed 8 | uint8 | int8 |
1887-
| signed 8 to unsigned 8 | int8 | uint8 |
1872+
| Mode | Input | Output | Unsigned | Unsigned |
1873+
| | | | input | output |
1874+
|------------------------|-------|--------|----------|----------|
1875+
| signed 8 to 8 | int8 | int8 | false | false |
1876+
| signed 8 to 16 | int8 | int16 | false | false |
1877+
| signed 8 to 32 | int8 | int32 | false | false |
1878+
| signed 16 to 8 | int16 | int8 | false | false |
1879+
| signed 16 to 16 | int16 | int16 | false | false |
1880+
| signed 16 to 32 | int16 | int32 | false | false |
1881+
| signed 32 to 8 | int32 | int8 | false | false |
1882+
| signed 32 to 16 | int32 | int16 | false | false |
1883+
| signed 32 to 32 | int32 | int32 | false | false |
1884+
| signed 48 to 8 | int48 | int8 | false | false |
1885+
| signed 48 to 16 | int48 | int16 | false | false |
1886+
| signed 48 to 32 | int48 | int32 | false | false |
1887+
| unsigned 8 to signed 8 | uint8 | int8 | true | false |
1888+
| signed 8 to unsigned 8 | int8 | uint8 | false | true |
18881889
}];
18891890

18901891
let arguments = (ins
@@ -1895,13 +1896,33 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
18951896
DenseI8ArrayAttr:$shift,
18961897
BoolAttr:$scale32,
18971898
BoolAttr:$double_round,
1898-
BoolAttr:$per_channel
1899+
BoolAttr:$per_channel,
1900+
DefaultValuedOptionalAttr<BoolAttr, "false">:$input_unsigned,
1901+
DefaultValuedOptionalAttr<BoolAttr, "false">:$output_unsigned
18991902
);
19001903

19011904
let results = (outs
19021905
Tosa_Tensor:$output
19031906
);
19041907

1908+
// Custom builder that does not require optional input_unsigned and
1909+
// output_unsigned.
1910+
let builders = [
1911+
OpBuilder<(ins "::mlir::Type":$output,
1912+
"::mlir::Value":$input,
1913+
"::mlir::IntegerAttr":$input_zp,
1914+
"::mlir::IntegerAttr":$output_zp,
1915+
"::mlir::DenseI32ArrayAttr":$multiplier,
1916+
"::mlir::DenseI8ArrayAttr":$shift,
1917+
"::mlir::BoolAttr":$scale32,
1918+
"::mlir::BoolAttr":$double_round,
1919+
"::mlir::BoolAttr":$per_channel), [{
1920+
auto FalseAttr = BoolAttr::get($_builder.getContext(), false);
1921+
build($_builder, $_state, output, input, input_zp, output_zp, multiplier,
1922+
shift, scale32, double_round, per_channel, FalseAttr, FalseAttr);
1923+
}]>
1924+
];
1925+
19051926
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
19061927
}
19071928

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,14 +1261,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
12611261
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
12621262

12631263
if (valueTy.getIntOrFloatBitWidth() < 32) {
1264-
if (valueTy.isUnsignedInteger()) {
1265-
value = nestedBuilder
1266-
.create<UnrealizedConversionCastOp>(
1267-
nestedLoc,
1268-
nestedBuilder.getIntegerType(
1269-
valueTy.getIntOrFloatBitWidth()),
1270-
value)
1271-
.getResult(0);
1264+
if (op.getInputUnsigned()) {
12721265
value = nestedBuilder.create<arith::ExtUIOp>(
12731266
nestedLoc, nestedBuilder.getI32Type(), value);
12741267
} else {
@@ -1297,7 +1290,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
12971290
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
12981291

12991292
// Unsigned integers have a difference output value.
1300-
if (outIntType.isUnsignedInteger()) {
1293+
if (op.getOutputUnsigned()) {
13011294
intMin = 0;
13021295
intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
13031296
}
@@ -1314,13 +1307,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13141307
value = nestedBuilder.create<arith::TruncIOp>(
13151308
nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
13161309
value);
1317-
1318-
if (outIntType.isUnsignedInteger()) {
1319-
value = nestedBuilder
1320-
.create<UnrealizedConversionCastOp>(nestedLoc,
1321-
outIntType, value)
1322-
.getResult(0);
1323-
}
13241310
}
13251311

13261312
nestedBuilder.create<linalg::YieldOp>(loc, value);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,11 +1132,21 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11321132
// CHECK-DAG: linalg.yield [[TRUNC]]
11331133
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
11341134

1135+
// CHECK: return
1136+
return
1137+
}
1138+
1139+
// -----
1140+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1141+
1142+
// CHECK-LABEL: @rescale_i8_unsigned_output
1143+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1144+
func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
11351145
// CHECK: [[C0:%.+]] = arith.constant 19689
11361146
// CHECK: [[C1:%.+]] = arith.constant 15
11371147
// CHECK: [[INIT:%.+]] = tensor.empty()
1138-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
1139-
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
1148+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
1149+
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
11401150
// CHECK: [[C17:%.+]] = arith.constant 17
11411151
// CHECK: [[C22:%.+]] = arith.constant 22
11421152
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
@@ -1148,9 +1158,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11481158
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
11491159
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
11501160
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1151-
// CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1152-
// CHECK: linalg.yield [[CAST]]
1153-
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
1161+
// CHECK: linalg.yield [[TRUNC]]
1162+
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
11541163

11551164
// CHECK: return
11561165
return
@@ -1171,9 +1180,9 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
11711180

11721181
// CHECK: %[[C0:.+]] = arith.constant 0
11731182
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1174-
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8>
1175-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
1176-
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
1183+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
1184+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
1185+
%1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
11771186

11781187
return
11791188
}
@@ -1199,18 +1208,17 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
11991208

12001209
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12011210

1202-
// CHECK-LABEL: @rescale_ui8
1211+
// CHECK-LABEL: @rescale_i8_unsigned_input
12031212
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1204-
func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
1213+
func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12051214
// CHECK: [[C0:%.+]] = arith.constant 19689
12061215
// CHECK: [[C1:%.+]] = arith.constant 15
12071216
// CHECK: [[INIT:%.+]] = tensor.empty()
1208-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
1209-
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
1217+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
1218+
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
12101219
// CHECK: [[C17:%.+]] = arith.constant 17
12111220
// CHECK: [[C22:%.+]] = arith.constant 22
1212-
// CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1213-
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[CAST]]
1221+
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
12141222
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
12151223
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
12161224
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
@@ -1220,7 +1228,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
12201228
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
12211229
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
12221230
// CHECK: linalg.yield [[TRUNC]]
1223-
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
1231+
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
12241232

12251233
return
12261234
}

0 commit comments

Comments
 (0)