
Commit 64f54d3

psunn and Tai78641 committed
[TOSA] Convert RESCALE op multiplier and shift from attributes to inputs
This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification. Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions.

Co-authored-by: Tai Ly <[email protected]>
Signed-off-by: Peng Sun <[email protected]>
Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
1 parent e27b8b2 commit 64f54d3
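
In textual IR the change looks roughly like the following before/after sketch, adapted from the lit test updated in this commit (the plain-i8 tensor types, zero points, and scale values here are illustrative, not copied verbatim from any one test):

  // Before: multiplier and shift were dense array attributes on the op.
  %0 = tosa.rescale %arg0 {scale32 = true, double_round = false, per_channel = false, input_zp = 0 : i32, output_zp = 0 : i32, input_unsigned = false, output_unsigned = false, multiplier = array<i32: 1073741824>, shift = array<i8: 30>} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>

  // After: multiplier and shift are rank-1 tensor operands, typically produced by tosa.const.
  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
  %1 = tosa.rescale %arg0, %multiplier, %shift {scale32 = true, double_round = false, per_channel = false, input_zp = 0 : i32, output_zp = 0 : i32, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>

With scale32 = true the multiplier tensor must be i32 (i16 when scale32 = false), the shift tensor is always i8, and both must have shape {1} unless per_channel = true.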

File tree

15 files changed: +473, -41 lines


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
@@ -2262,9 +2262,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp: Tosa_InferShapedTypeOp<"rescale"> {
   let summary = "Tosa rescale operator";

   let description = [{
@@ -2290,10 +2288,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,

   let arguments = (ins
     Tosa_Tensor:$input,
+    Tosa_1DInt16Or32Tensor:$multiplier,
+    Tosa_1DInt8Tensor:$shift,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
-    DenseI32ArrayAttr:$multiplier,
-    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2310,6 +2308,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     Extension<[Tosa_EXT_INT16]>,
   ];

+  let hasVerifier = 1;
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 26 additions & 6 deletions
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>

 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
 def Tosa_Int64 : I<64>;

@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
                          AnySignlessInteger]>;

 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
-                                Tosa_Int64]>;
+                                Tosa_Int64]>;
+
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+                                Tosa_Int32]>;

 //===----------------------------------------------------------------------===//
 // Quantized Integer Types.
@@ -68,11 +72,23 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
 // int8 : symmetric per tensor/per channel, signed
 // int16 : symmetric per tensor, signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
-                                   Tosa_QuantizedType<"int4", [4, 0], 1>,
-                                   Tosa_QuantizedType<"int8", [8, 0], 1>,
-                                   Tosa_QuantizedType<"int16", [16, 0], 1>,
-                                   Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
+                                    Tosa_QuantizedType<"int4", [4, 0], 1>,
+                                    Tosa_QuantizedType<"int8", [8, 0], 1>,
+                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
+                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[
+                 F32,
+                 F16,
+                 BF16]>;
+
+def Tosa_F8 : AnyTypeOf<[
+              F8E4M3FN,
+              F8E5M2]>;

 //===----------------------------------------------------------------------===//
@@ -162,6 +178,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;

+// 1D tensor of specific types
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 12 additions & 0 deletions
@@ -31,6 +31,18 @@ namespace tosa {
 void computeMultiplierAndShift(double scale, int32_t &multiplier,
                                int32_t &shift, int32_t scaleWidth);

+// Return a const value for array of int8 vec
+Value getConstTensorInt8(OpBuilder &builder, Location loc,
+                         ArrayRef<int8_t> vec);
+
+// Return a const value for array of int16 vec
+Value getConstTensorInt16(OpBuilder &builder, Location loc,
+                          ArrayRef<int16_t> vec);
+
+// Return a const value for array of int32 vec
+Value getConstTensorInt32(OpBuilder &builder, Location loc,
+                          ArrayRef<int32_t> vec);
+
 //// Builds ConvOpQuantizationAttr from input and weight.
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 18 additions & 2 deletions
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     }

     // The shift and multiplier values.
-    SmallVector<int32_t> multiplierValues(op.getMultiplier());
-    SmallVector<int8_t> shiftValues(op.getShift());
+    ElementsAttr shiftElems;
+    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant shift input values");
+
+    ElementsAttr multiplierElems;
+    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant multiplier input values");
+
+    SmallVector<int8_t> shiftValues;
+    for (auto idx : shiftElems.getValues<IntegerAttr>()) {
+      shiftValues.push_back(static_cast<int8_t>(idx.getInt()));
+    }
+    SmallVector<int32_t> multiplierValues;
+    for (auto idx : multiplierElems.getValues<IntegerAttr>()) {
+      multiplierValues.push_back(static_cast<int32_t>(idx.getInt()));
+    }

     // If we shift by more than the bitwidth, this just sets to 0.
     for (int i = 0, s = multiplierValues.size(); i < s; i++) {
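
The lowering reads the scale values at pattern-match time, so the converter now bails out with notifyMatchFailure when either new operand does not fold to a constant. A hypothetical IR shape (names and types invented for illustration) that the pattern would now reject:

  func.func @rescale_non_const_scales(%arg0: tensor<13x21x3xi8>, %m: tensor<1xi32>, %s: tensor<1xi8>) -> tensor<13x21x3xi8> {
    // %m and %s are block arguments rather than tosa.const results, so
    // matchPattern(..., m_Constant(...)) fails and the pattern emits
    // "tosa.rescale requires constant shift input values" (shift is checked first).
    %0 = tosa.rescale %arg0, %m, %s {scale32 = true, double_round = false, per_channel = false, input_zp = 0 : i32, output_zp = 0 : i32, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
    return %0 : tensor<13x21x3xi8>
  }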

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 133 additions & 1 deletion
@@ -2024,7 +2024,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SinOp)
@@ -2469,6 +2468,139 @@ LogicalResult TransposeConv2DOp::verify() {
   return success();
 }

+LogicalResult RescaleOp::verify() {
+  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+  if (!inputType) {
+    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  if (auto inputQType =
+          llvm::dyn_cast<quant::QuantizedType>(inputElementType)) {
+    inputElementType = inputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(inputElementType)) {
+    emitOpError("expect input to have integer element type, got ")
+        << inputElementType;
+    return failure();
+  }
+  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+  if (!outputType) {
+    emitOpError("expect shaped tensor for output, got ")
+        << getOutput().getType();
+    return failure();
+  }
+  auto outputElementType = outputType.getElementType();
+  if (auto outputQType =
+          llvm::dyn_cast<quant::QuantizedType>(outputElementType)) {
+    outputElementType = outputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(outputElementType)) {
+    emitOpError("expect output to have integer element type, got ")
+        << outputElementType;
+    return failure();
+  }
+  auto input_zp = getInputZpAttr().getInt();
+  if (input_zp != 0) {
+    // only int8/uint8 and uint16 input can have non-zero input_zp
+    if (!inputElementType.isInteger(8) &&
+        !(inputElementType.isInteger(16) && getInputUnsigned())) {
+      emitOpError("expect input_zp of 0, got ") << input_zp;
+      return failure();
+    }
+    // input_zp must be either 0 or 32768 for uint16 input
+    if (inputElementType.isInteger(16) && getInputUnsigned() &&
+        input_zp != 32768) {
+      emitOpError(
+          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+          << input_zp;
+      return failure();
+    }
+  }
+
+  auto output_zp = getOutputZpAttr().getInt();
+  if (output_zp != 0) {
+    // only int8/uint8 and uint16 output can have non-zero output_zp
+    if (!outputElementType.isInteger(8) &&
+        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
+      emitOpError("expect output_zp of 0, got ") << output_zp;
+      return failure();
+    }
+    // output_zp must be either 0 or 32768 for uint16 output
+    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
+        output_zp != 32768) {
+      emitOpError(
+          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
+          << output_zp;
+      return failure();
+    }
+  }
+
+  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+  if (!multiplierType) {
+    emitOpError("expect shaped tensor for multiplier, got ")
+        << getMultiplier().getType();
+    return failure();
+  }
+  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+  if (!shiftType) {
+    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
+    return failure();
+  }
+
+  // multiplier element type must be i32 for scale32 = true
+  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
+    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier element type must be i16 for scale32 = false
+  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
+    emitOpError(
+        "expect i16 element type for multiplier for scale32=false, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier/shift must have shape = {numChannels},
+  // where numChannel is 1 if per_channel = false
+  // otherwise numChannel is dimension in input shape's last axis
+  int64_t numChannels = 1;
+  if (getPerChannel()) {
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    numChannels = inputShape[inputShape.size() - 1];
+  }
+
+  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+  // multiplier input has rank 1 by dialect definition
+  if (multiplierShape[0] != numChannels) {
+    emitOpError("expect shape of { ")
+        << numChannels << " } for multiplier input, got { "
+        << multiplierShape[0] << " }";
+    return failure();
+  }
+
+  ArrayRef<int64_t> shiftShape = shiftType.getShape();
+  // shift input has rank 1 by dialect definition
+  if (shiftShape[0] != numChannels) {
+    emitOpError("expect shape of { ")
+        << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
+    return failure();
+  }
+
+  return success();
+}
+
+LogicalResult RescaleOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RescaleOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
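
For reference, a hypothetical per-channel example (not part of this commit's tests) that satisfies the new RescaleOp::verify() checks: with per_channel = true the channel count is the input's last dimension (3 here), so with scale32 = true the multiplier and shift operands must be tensor<3xi32> and tensor<3xi8> respectively.

  func.func @rescale_per_channel(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
    %multiplier = "tosa.const"() {value = dense<[1073741824, 1073741824, 1073741824]> : tensor<3xi32>} : () -> tensor<3xi32>
    %shift = "tosa.const"() {value = dense<[30, 30, 30]> : tensor<3xi8>} : () -> tensor<3xi8>
    // numChannels = 3 (last axis of %arg0); multiplier shape {3}, shift shape {3},
    // and i32 multiplier elements match scale32 = true, so verification succeeds.
    %0 = tosa.rescale %arg0, %multiplier, %shift {scale32 = true, double_round = true, per_channel = true, input_zp = 0 : i32, output_zp = 0 : i32, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<3xi32>, tensor<3xi8>) -> tensor<13x21x3xi8>
    return %0 : tensor<13x21x3xi8>
  }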

mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp

Lines changed: 39 additions & 0 deletions
@@ -107,6 +107,45 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
   }
 }

+// Return a const value for array of int8 vec
+Value mlir::tosa::getConstTensorInt8(OpBuilder &builder, Location loc,
+                                     ArrayRef<int8_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI8Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Return a const value for array of int16 vec
+Value mlir::tosa::getConstTensorInt16(OpBuilder &builder, Location loc,
+                                      ArrayRef<int16_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI16Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Return a const value for array of int32 vec
+Value mlir::tosa::getConstTensorInt32(OpBuilder &builder, Location loc,
+                                      ArrayRef<int32_t> vec) {
+  int64_t count = vec.size();
+  assert(count > 0);
+  auto element_type = builder.getI32Type();
+  mlir::RankedTensorType const_type =
+      RankedTensorType::get({count}, element_type);
+  mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+  auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+  return const_op.getResult();
+}
+
 #define GET_UQTYPE(inputType) \
   (llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
 #define GET_QTYPE(inputType) \

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 3 additions & 1 deletion
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %

 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
