Skip to content

Commit 1c36395

Browse files
committed
Deal with zeropoint sign in getter
Also clarify zeropoint extension rules.
1 parent cd1e65d commit 1c36395

File tree

3 files changed

+34
-40
lines changed

3 files changed

+34
-40
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
8282
rhsOrResult);
8383
}
8484

85-
// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
86-
static arith::ConstantOp
87-
createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
88-
bool isSigned, Location loc, OpBuilder &rewriter) {
89-
90-
// Zero the signed-extended bits if isSigned is false.
91-
zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
92-
93-
return rewriter.create<arith::ConstantOp>(
94-
loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
95-
}
96-
9785
static Value createLinalgBodyCalculationForElementwiseOp(
9886
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
9987
ConversionPatternRewriter &rewriter) {
@@ -1478,10 +1466,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14781466
}
14791467

14801468
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1481-
const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
1482-
auto inputZp = createConstOpFromSExtZp(
1483-
*maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
1484-
nestedBuilder);
1469+
// Extend zeropoint for sub-32bits widths.
1470+
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1471+
auto inputZp = nestedBuilder.create<arith::ConstantOp>(loc,
1472+
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1473+
*maybeIZp));
14851474

14861475
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
14871476
if (failed(maybeOZp)) {
@@ -1493,9 +1482,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14931482
IntegerType outIntType =
14941483
cast<IntegerType>(blockArgs.back().getType());
14951484
unsigned outBitWidth = outIntType.getWidth();
1496-
auto outputZp = createConstOpFromSExtZp(
1497-
*maybeOZp, outBitWidth, /*attrBitwidth=*/32,
1498-
!op.getOutputUnsigned(), loc, nestedBuilder);
1485+
const int32_t outAttrBitwidth = 32;
1486+
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1487+
auto outputZp = nestedBuilder.create<arith::ConstantOp>(loc,
1488+
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1489+
*maybeOZp));
14991490

15001491
Value multiplier = multiplierConstant ? multiplierConstant
15011492
: blockArgs[multiplierArg];

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
21182118
// return failure if val is not a constant
21192119
// set zp to -1 if val is non-zero float or val is not integer nor float
21202120
// otherwise set zp to val's constant value
2121-
static FailureOr<int64_t> getZeroPoint(Value val) {
2121+
static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
21222122
ElementsAttr zpAttr;
21232123
if (!matchPattern(val, m_Constant(&zpAttr))) {
21242124
return failure();
@@ -2135,7 +2135,10 @@ static FailureOr<int64_t> getZeroPoint(Value val) {
21352135
}
21362136

21372137
if (llvm::isa<IntegerType>(zpElemType)) {
2138-
return zpAttr.getValues<APInt>()[0].getSExtValue();
2138+
if (signExtend)
2139+
return zpAttr.getValues<APInt>()[0].getSExtValue();
2140+
else
2141+
return zpAttr.getValues<APInt>()[0].getZExtValue();
21392142
}
21402143

21412144
// return non-zero value to trigger error check
@@ -2186,30 +2189,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
21862189
return success();
21872190
}
21882191

2189-
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
2192+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
21902193
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2191-
return getZeroPoint(get##OPERAND_NAME##Zp()); \
2194+
return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
21922195
} \
21932196
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
21942197
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
21952198
}
21962199

2197-
ZERO_POINT_HELPER(Conv2DOp, Input)
2198-
ZERO_POINT_HELPER(Conv2DOp, Weight)
2199-
ZERO_POINT_HELPER(Conv3DOp, Input)
2200-
ZERO_POINT_HELPER(Conv3DOp, Weight)
2201-
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
2202-
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
2203-
ZERO_POINT_HELPER(TransposeConv2DOp, Input)
2204-
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
2205-
ZERO_POINT_HELPER(AvgPool2dOp, Input)
2206-
ZERO_POINT_HELPER(AvgPool2dOp, Output)
2207-
ZERO_POINT_HELPER(MatMulOp, A)
2208-
ZERO_POINT_HELPER(MatMulOp, B)
2209-
ZERO_POINT_HELPER(NegateOp, Input1)
2210-
ZERO_POINT_HELPER(NegateOp, Output)
2211-
ZERO_POINT_HELPER(RescaleOp, Input)
2212-
ZERO_POINT_HELPER(RescaleOp, Output)
2200+
ZERO_POINT_HELPER(Conv2DOp, Input, true)
2201+
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2202+
ZERO_POINT_HELPER(Conv3DOp, Input, true)
2203+
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2204+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2205+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2206+
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2207+
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2208+
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2209+
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2210+
ZERO_POINT_HELPER(MatMulOp, A, true)
2211+
ZERO_POINT_HELPER(MatMulOp, B, true)
2212+
ZERO_POINT_HELPER(NegateOp, Input1, true)
2213+
ZERO_POINT_HELPER(NegateOp, Output, true)
2214+
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2215+
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
22132216
#undef ZERO_POINT_HELPER
22142217

22152218
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
15171517
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
15181518
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
15191519
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
1520-
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
1520+
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
15211521
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
15221522
return %0 : tensor<13x21x3xi16>
15231523
}

0 commit comments

Comments
 (0)