Skip to content

Commit cd1e65d

Browse files
committed
[MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp
Lowering of tosa.rescale to Linalg unconditionally sign-extend the input zero-point value, even when unsigned_input is true. This commit refactor zeropoint handling to share the same logic between input and output zeropoint.
1 parent 4e81ee4 commit cd1e65d

File tree

2 files changed

+53
-29
lines changed

2 files changed

+53
-29
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
8282
rhsOrResult);
8383
}
8484

85-
template <typename T>
85+
// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
8686
static arith::ConstantOp
87-
createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
88-
OpBuilder &rewriter) {
89-
auto castedN = static_cast<T>(zp);
87+
createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
88+
bool isSigned, Location loc, OpBuilder &rewriter) {
89+
90+
// Zero the signed-extended bits if isSigned is false.
91+
zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
92+
9093
return rewriter.create<arith::ConstantOp>(
91-
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
94+
loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
9295
}
9396

9497
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,20 +1470,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14671470
Value value = blockArgs[0];
14681471
Type valueTy = value.getType();
14691472

1470-
// For now we do all of our math in 64-bit. This is not optimal but
1471-
// should be correct for now, consider computing correct bit depth
1472-
// later.
1473-
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1474-
14751473
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
14761474
if (failed(maybeIZp)) {
14771475
(void)rewriter.notifyMatchFailure(
14781476
op, "input zero point cannot be statically determined");
14791477
return;
14801478
}
14811479

1482-
auto inputZp = createConstOpFromZpVal<int32_t>(
1483-
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
1480+
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1481+
const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
1482+
auto inputZp = createConstOpFromSExtZp(
1483+
*maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
14841484
nestedBuilder);
14851485

14861486
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14901490
return;
14911491
};
14921492

1493-
// pre-process OutputZP as it can be unsigned
1494-
auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
1495-
APInt OZp(outBitwidth, !op.getOutputUnsigned());
1496-
OZp = static_cast<int64_t>(*maybeOZp);
1497-
*maybeOZp = op.getOutputUnsigned()
1498-
? static_cast<int64_t>(OZp.getZExtValue())
1499-
: OZp.getSExtValue();
1500-
1501-
auto outputZp = createConstOpFromZpVal<int32_t>(
1502-
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
1493+
IntegerType outIntType =
1494+
cast<IntegerType>(blockArgs.back().getType());
1495+
unsigned outBitWidth = outIntType.getWidth();
1496+
auto outputZp = createConstOpFromSExtZp(
1497+
*maybeOZp, outBitWidth, /*attrBitwidth=*/32,
1498+
!op.getOutputUnsigned(), loc, nestedBuilder);
15031499

15041500
Value multiplier = multiplierConstant ? multiplierConstant
15051501
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15271523
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
15281524

15291525
// Saturate to the output size.
1530-
IntegerType outIntType =
1531-
cast<IntegerType>(blockArgs.back().getType());
1532-
unsigned outBitWidth = outIntType.getWidth();
1533-
15341526
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
15351527
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
15361528

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12411241
// CHECK: [[INIT:%.+]] = tensor.empty()
12421242
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
12431243
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1244-
// CHECK: [[C17:%.+]] = arith.constant 17
1244+
// CHECK: [[C128:%.+]] = arith.constant 128
12451245
// CHECK: [[C22:%.+]] = arith.constant 22
12461246
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
1247-
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1247+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
12481248
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
12491249
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
12501250
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12551255
// CHECK: linalg.yield [[TRUNC]]
12561256
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
12571257
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1258-
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1258+
%input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
12591259
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
12601260
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
12611261

12621262
return
12631263
}
12641264

1265+
// -----
1266+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1267+
1268+
// CHECK-LABEL: @rescale_i48_unsigned_output
1269+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1270+
func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
1271+
// CHECK: [[C19689:%.+]] = arith.constant 19689
1272+
// CHECK: [[C15:%.+]] = arith.constant 15
1273+
// CHECK: [[INIT:%.+]] = tensor.empty()
1274+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
1275+
// CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
1276+
// CHECK: [[C0:%.+]] = arith.constant 0
1277+
// CHECK: [[C234:%.+]] = arith.constant 234
1278+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
1279+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
1280+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1281+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1282+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1283+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1284+
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1285+
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1286+
// CHECK: linalg.yield [[TRUNC]]
1287+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1288+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1289+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
1290+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1291+
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
1292+
1293+
// CHECK: return
1294+
return
1295+
}
1296+
12651297
// -----
12661298

12671299
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

0 commit comments

Comments
 (0)