Skip to content

Commit dcfcd7e

Browse files
psunnTai78641
authored andcommitted
[mlir][tosa] Change Rescale zero points to be inputs
*Update RescaleOp to use zero-point as operands instead of attributes. *Check input_zp data type against the input and output_zp data type against the output. Change-Id: I2cf0106eb9f9ec88e16de5efc93b651053e5fc92 Signed-off-by: Peng Sun <[email protected]>
1 parent b0baa1d commit dcfcd7e

File tree

15 files changed

+270
-113
lines changed

15 files changed

+270
-113
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,15 @@ profileComplianceMap = {
216216
{fp32T, fp16T}}}}},
217217
{"tosa.rescale",
218218
{{{Profile::pro_int},
219-
{{i8T, i8T},
220-
{i8T, i16T},
221-
{i8T, i32T},
222-
{i16T, i8T},
223-
{i16T, i16T},
224-
{i16T, i32T},
225-
{i32T, i8T},
226-
{i32T, i16T},
227-
{i32T, i32T}}}}},
219+
{{i8T, i8T, i8T, i8T},
220+
{i8T, i8T, i16T, i16T},
221+
{i8T, i8T, i32T, i32T},
222+
{i16T, i16T, i8T, i8T},
223+
{i16T, i16T, i16T, i16T},
224+
{i16T, i16T, i32T, i32T},
225+
{i32T, i32T, i8T, i8T},
226+
{i32T, i32T, i16T, i16T},
227+
{i32T, i32T, i32T, i32T}}}}},
228228
{"tosa.const",
229229
{{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
230230
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
384384
{fp16T, fp8e5m2T},
385385
{fp32T, fp8e5m2T}}}}},
386386
{"tosa.rescale",
387-
{{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
387+
{{{Extension::int16},
388+
{{i48T, i48T, i8T, i8T},
389+
{i48T, i48T, i16T, i16T},
390+
{i48T, i48T, i32T, i32T}}}}},
388391
{"tosa.const",
389392
{{{Extension::int4}, {{i4T}}},
390393
{{Extension::int16}, {{i48T}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
23372337
Tosa_Tensor:$input,
23382338
Tosa_1DInt16Or32Tensor:$multiplier,
23392339
Tosa_1DInt8Tensor:$shift,
2340-
I32Attr:$input_zp,
2341-
I32Attr:$output_zp,
2340+
Tosa_ScalarIntOrFloatTensor:$input_zp,
2341+
Tosa_ScalarIntOrFloatTensor:$output_zp,
23422342
BoolAttr:$scale32,
23432343
BoolAttr:$double_round,
23442344
BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
23552355
Extension<[Tosa_EXT_INT16]>,
23562356
];
23572357

2358+
let extraClassDeclaration = [{
2359+
FailureOr<int64_t> getInputZeroPoint();
2360+
FailureOr<int64_t> getOutputZeroPoint();
2361+
LogicalResult verifyInputZeroPoint(int64_t zp);
2362+
LogicalResult verifyOutputZeroPoint(int64_t zp);
2363+
}];
2364+
23582365
let hasVerifier = 1;
23592366

23602367
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
8484

8585
template <typename T>
8686
static arith::ConstantOp
87-
createConstFromIntAttribute(Operation *op, const std::string &attrName,
88-
Type requiredAttrType, OpBuilder &rewriter) {
89-
auto castedN = static_cast<T>(
90-
cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
87+
createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
88+
OpBuilder &rewriter) {
89+
auto castedN = static_cast<T>(zp);
9190
return rewriter.create<arith::ConstantOp>(
9291
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
9392
}
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14911490
// later.
14921491
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
14931492

1494-
auto inputZp = createConstFromIntAttribute<int32_t>(
1495-
op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1493+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1494+
if (failed(maybeIZp)) {
1495+
(void)rewriter.notifyMatchFailure(
1496+
op, "input zero point cannot be statically determined");
1497+
return;
1498+
}
1499+
1500+
auto inputZp = createConstOpFromZpVal<int32_t>(
1501+
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
14961502
nestedBuilder);
1497-
auto outputZp = createConstFromIntAttribute<int32_t>(
1498-
op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1503+
1504+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1505+
if (failed(maybeOZp)) {
1506+
(void)rewriter.notifyMatchFailure(
1507+
op, "output zero point cannot be statically determined");
1508+
return;
1509+
};
1510+
1511+
auto outputZp = createConstOpFromZpVal<int32_t>(
1512+
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
14991513

15001514
Value multiplier = multiplierConstant ? multiplierConstant
15011515
: blockArgs[multiplierArg];

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
254254
return elementType;
255255
}
256256

257+
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
258+
Value valZp, StringRef name) {
259+
Type eType = getStorageElementTypeOrSelf(val.getType());
260+
Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
261+
262+
bool bothInts =
263+
mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
264+
bool bothFloats =
265+
mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
266+
bool sameBitWidth =
267+
(eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
268+
269+
if ((!bothInts && !bothFloats) || !sameBitWidth) {
270+
return op->emitOpError()
271+
<< "expected " << name << " and " << name
272+
<< "_zp to both be integer or float of the same bitwidth, but got "
273+
<< eType << " vs. " << eZpType;
274+
}
275+
return success();
276+
}
277+
257278
//===----------------------------------------------------------------------===//
258279
// TOSA Operator Verifiers.
259280
//===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
16961717
return success();
16971718
}
16981719

1720+
static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
1721+
const int64_t &zp,
1722+
const std::string &operand) {
1723+
bool isInputZp = (zpVal == op.getInputZp());
1724+
bool isOutputZp = (zpVal == op.getOutputZp());
1725+
if (!isInputZp && !isOutputZp) {
1726+
return op.emitOpError("internal error: zero-point operand is neither "
1727+
"inputZp nor outputZp");
1728+
}
1729+
1730+
bool tensorUnsigned =
1731+
isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
1732+
StringRef tensorName = isInputZp ? "input" : "output";
1733+
1734+
Type zpElemType = getElementTypeOrSelf(zpVal);
1735+
1736+
if (zp != 0) {
1737+
if (!zpElemType.isInteger(8) &&
1738+
!(zpElemType.isInteger(16) && tensorUnsigned)) {
1739+
return op.emitOpError()
1740+
<< "expect " << tensorName << "_zp of 0, got " << zp;
1741+
}
1742+
if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
1743+
return op.emitOpError() << "expect " << tensorName
1744+
<< "_zp of 0 or 32768 for unsigned int16 "
1745+
<< tensorName << ", got " << zp;
1746+
}
1747+
}
1748+
1749+
return success();
1750+
}
1751+
16991752
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
17001753
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
17011754
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
17141767
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
17151768
ZERO_POINT_HELPER(AvgPool2dOp, Input)
17161769
ZERO_POINT_HELPER(AvgPool2dOp, Output)
1770+
ZERO_POINT_HELPER(RescaleOp, Input)
1771+
ZERO_POINT_HELPER(RescaleOp, Output)
17171772
#undef ZERO_POINT_HELPER
17181773

17191774
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
26982753
return failure();
26992754
}
27002755

2701-
auto input_zp = getInputZpAttr().getInt();
2702-
if (input_zp != 0) {
2703-
// only int8/uint8 and uint16 input can have non-zero input_zp
2704-
if (!inputElementType.isInteger(8) &&
2705-
!(inputElementType.isInteger(16) && getInputUnsigned())) {
2706-
emitOpError("expect input_zp of 0, got ") << input_zp;
2707-
return failure();
2708-
}
2709-
// input_zp must be either 0 or 32768 for uint16 input
2710-
if (inputElementType.isInteger(16) && getInputUnsigned() &&
2711-
input_zp != 32768) {
2712-
emitOpError(
2713-
"expect input_zp of 0 or 32768 for unsigned int16 input, got ")
2714-
<< input_zp;
2715-
return failure();
2716-
}
2717-
}
2756+
if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
2757+
.failed())
2758+
return failure();
27182759

2719-
auto output_zp = getOutputZpAttr().getInt();
2720-
if (output_zp != 0) {
2721-
// only int8/uint8 and uint16 output can have non-zero output_zp
2722-
if (!outputElementType.isInteger(8) &&
2723-
!(outputElementType.isInteger(16) && getOutputUnsigned())) {
2724-
emitOpError("expect output_zp of 0, got ") << output_zp;
2725-
return failure();
2726-
}
2727-
// output_zp must be either 0 or 32768 for uint16 output
2728-
if (outputElementType.isInteger(16) && getOutputUnsigned() &&
2729-
output_zp != 32768) {
2730-
emitOpError(
2731-
"expect output_zp of 0 or 32768 for unsigned int16 output, got ")
2732-
<< output_zp;
2733-
return failure();
2734-
}
2735-
}
2760+
if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
2761+
.failed())
2762+
return failure();
2763+
2764+
FailureOr<int64_t> maybeIZp = getInputZeroPoint();
2765+
if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
2766+
return failure();
2767+
2768+
FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2769+
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2770+
return failure();
27362771

27372772
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
27382773
if (!multiplierType) {

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
175175
template <>
176176
void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
177177
addValue(op.getInput());
178+
addValue(op.getInputZp());
179+
addValue(op.getOutputZp());
178180
addValue(op.getOutput());
179181
}
180182

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
3535
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
3636
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
3737
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
38+
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
39+
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
40+
3841
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
39-
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
42+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4043
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4144
}
4245

0 commit comments

Comments
 (0)