@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
82
82
rhsOrResult);
83
83
}
84
84
85
- template < typename T>
85
+ // Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
86
86
static arith::ConstantOp
87
- createConstOpFromZpVal (Operation *op, const int64_t &zp, Type requiredAttrType,
88
- OpBuilder &rewriter) {
89
- auto castedN = static_cast <T>(zp);
87
+ createConstOpFromSExtZp (int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
88
+ bool isSigned, Location loc, OpBuilder &rewriter) {
89
+
90
+ // Zero the signed-extended bits if isSigned is false.
91
+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1 );
92
+
90
93
return rewriter.create <arith::ConstantOp>(
91
- op-> getLoc () , IntegerAttr::get (requiredAttrType, castedN ));
94
+ loc , IntegerAttr::get (rewriter. getIntegerType (attrBitwidth), zp ));
92
95
}
93
96
94
97
static Value createLinalgBodyCalculationForElementwiseOp (
@@ -1467,20 +1470,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1467
1470
Value value = blockArgs[0 ];
1468
1471
Type valueTy = value.getType ();
1469
1472
1470
- // For now we do all of our math in 64-bit. This is not optimal but
1471
- // should be correct for now, consider computing correct bit depth
1472
- // later.
1473
- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth () > 32 ? 48 : 32 ;
1474
-
1475
1473
FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1476
1474
if (failed (maybeIZp)) {
1477
1475
(void )rewriter.notifyMatchFailure (
1478
1476
op, " input zero point cannot be statically determined" );
1479
1477
return ;
1480
1478
}
1481
1479
1482
- auto inputZp = createConstOpFromZpVal<int32_t >(
1483
- op, *maybeIZp, nestedBuilder.getIntegerType (inBitwidth),
1480
+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth ();
1481
+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32 ;
1482
+ auto inputZp = createConstOpFromSExtZp (
1483
+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned (), loc,
1484
1484
nestedBuilder);
1485
1485
1486
1486
FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1490
1490
return ;
1491
1491
};
1492
1492
1493
- // pre-process OutputZP as it can be unsigned
1494
- auto outBitwidth = outputTy.getElementType ().getIntOrFloatBitWidth ();
1495
- APInt OZp (outBitwidth, !op.getOutputUnsigned ());
1496
- OZp = static_cast <int64_t >(*maybeOZp);
1497
- *maybeOZp = op.getOutputUnsigned ()
1498
- ? static_cast <int64_t >(OZp.getZExtValue ())
1499
- : OZp.getSExtValue ();
1500
-
1501
- auto outputZp = createConstOpFromZpVal<int32_t >(
1502
- op, *maybeOZp, nestedBuilder.getI32Type (), nestedBuilder);
1493
+ IntegerType outIntType =
1494
+ cast<IntegerType>(blockArgs.back ().getType ());
1495
+ unsigned outBitWidth = outIntType.getWidth ();
1496
+ auto outputZp = createConstOpFromSExtZp (
1497
+ *maybeOZp, outBitWidth, /* attrBitwidth=*/ 32 ,
1498
+ !op.getOutputUnsigned (), loc, nestedBuilder);
1503
1499
1504
1500
Value multiplier = multiplierConstant ? multiplierConstant
1505
1501
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1527
1523
nestedBuilder.create <arith::AddIOp>(nestedLoc, value, outputZp);
1528
1524
1529
1525
// Saturate to the output size.
1530
- IntegerType outIntType =
1531
- cast<IntegerType>(blockArgs.back ().getType ());
1532
- unsigned outBitWidth = outIntType.getWidth ();
1533
-
1534
1526
int32_t intMin = APInt::getSignedMinValue (outBitWidth).getSExtValue ();
1535
1527
int32_t intMax = APInt::getSignedMaxValue (outBitWidth).getSExtValue ();
1536
1528
0 commit comments