Skip to content

Commit 1081656

Browse files
psunnTai78641
andcommitted
[TOSA] Convert RESCALE op multiplier and shift from attributes to inputs
This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification. Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions. Co-authored-by: Tai Ly <[email protected]> Signed-off-by: Peng Sun <[email protected]> Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
1 parent e27b8b2 commit 1081656

File tree

14 files changed

+436
-36
lines changed

14 files changed

+436
-36
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,9 +2262,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
22622262
//===----------------------------------------------------------------------===//
22632263
// Operator: rescale
22642264
//===----------------------------------------------------------------------===//
2265-
def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
2266-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
2267-
["inferReturnTypeComponents"]>]> {
2265+
def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
22682266
let summary = "Tosa rescale operator";
22692267

22702268
let description = [{
@@ -2290,10 +2288,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
22902288

22912289
let arguments = (ins
22922290
Tosa_Tensor:$input,
2291+
Tosa_1DInt16Or32Tensor:$multiplier,
2292+
Tosa_1DInt8Tensor:$shift,
22932293
I32Attr:$input_zp,
22942294
I32Attr:$output_zp,
2295-
DenseI32ArrayAttr:$multiplier,
2296-
DenseI8ArrayAttr:$shift,
22972295
BoolAttr:$scale32,
22982296
BoolAttr:$double_round,
22992297
BoolAttr:$per_channel,
@@ -2310,6 +2308,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23102308
Extension<[Tosa_EXT_INT16]>,
23112309
];
23122310

2311+
let hasVerifier = 1;
2312+
23132313
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
23142314
}
23152315

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
4343

4444
def Tosa_Int4 : I<4>;
4545
def Tosa_Int8 : I<8>;
46+
def Tosa_Int16 : I<16>;
4647
def Tosa_Int32 : I<32>;
4748
def Tosa_Int64 : I<64>;
4849

@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
5455
AnySignlessInteger]>;
5556

5657
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
57-
Tosa_Int64]>;
58+
Tosa_Int64]>;
59+
60+
def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
61+
Tosa_Int32]>;
5862

5963
//===----------------------------------------------------------------------===//
6064
// Quantized Integer Types.
@@ -74,6 +78,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
7478
Tosa_QuantizedType<"int16", [16, 0], 1>,
7579
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
7680

81+
//===----------------------------------------------------------------------===//
82+
// Floating-point types.
83+
//===----------------------------------------------------------------------===//
84+
def Tosa_Float : AnyTypeOf<[F32,
85+
F16,
86+
BF16]>;
87+
88+
def Tosa_F8 : AnyTypeOf<[F8E4M3FN,
89+
F8E5M2]>;
90+
7791
//===----------------------------------------------------------------------===//
7892
// Multi-category types.
7993
//===----------------------------------------------------------------------===//
@@ -162,6 +176,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
162176
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
163177
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
164178

179+
// 1D tensor of specific types
180+
def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
181+
def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
182+
165183
// Ranked tensors up to given rank.
166184
def Tosa_Tensor1Dto4D : AnyTypeOf<[
167185
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,26 @@ namespace tosa {
3131
void computeMultiplierAndShift(double scale, int32_t &multiplier,
3232
int32_t &shift, int32_t scaleWidth);
3333

34+
// Return a const value for array of IntType vec
35+
template <typename IntType>
36+
Value getConstTensorInt(OpBuilder &builder, Location loc,
37+
ArrayRef<IntType> vec) {
38+
static_assert(
39+
std::is_same<IntType, int8_t>::value ||
40+
std::is_same<IntType, int16_t>::value ||
41+
std::is_same<IntType, int32_t>::value,
42+
"getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
43+
44+
int64_t count = vec.size();
45+
assert(count > 0 && "Vector must not be empty");
46+
auto element_type = builder.getIntegerType(sizeof(IntType) * 8);
47+
mlir::RankedTensorType const_type =
48+
RankedTensorType::get({count}, element_type);
49+
mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
50+
auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
51+
return const_op.getResult();
52+
}
53+
3454
//// Builds ConvOpQuantizationAttr from input and weight.
3555
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3656
Value input, Value weight);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
138138
// tosa::MulOp
139139
if (isa<tosa::MulOp>(op)) {
140140
auto shift_val = cast<tosa::MulOp>(op).getShift();
141-
ElementsAttr shift_elem;
141+
DenseElementsAttr shift_elem;
142142
if (!shift_val.getImpl() ||
143143
!matchPattern(shift_val, m_Constant(&shift_elem))) {
144144
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13891389
}
13901390

13911391
// The shift and multiplier values.
1392-
SmallVector<int32_t> multiplierValues(op.getMultiplier());
1393-
SmallVector<int8_t> shiftValues(op.getShift());
1392+
DenseElementsAttr shiftElems;
1393+
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1394+
return rewriter.notifyMatchFailure(
1395+
op, "tosa.rescale requires constant shift input values");
1396+
1397+
DenseElementsAttr multiplierElems;
1398+
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1399+
return rewriter.notifyMatchFailure(
1400+
op, "tosa.rescale requires constant multiplier input values");
1401+
1402+
llvm::SmallVector<int8_t> shiftValues =
1403+
llvm::to_vector(shiftElems.getValues<int8_t>());
1404+
// explicit cast is required here
1405+
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1406+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1407+
[](IntegerAttr attr) -> int32_t {
1408+
return static_cast<int32_t>(attr.getInt());
1409+
}));
13941410

13951411
// If we shift by more than the bitwidth, this just sets to 0.
13961412
for (int i = 0, s = multiplierValues.size(); i < s; i++) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,7 +2024,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
20242024
NARY_SHAPE_INFER(tosa::NegateOp)
20252025
NARY_SHAPE_INFER(tosa::PowOp)
20262026
NARY_SHAPE_INFER(tosa::ReciprocalOp)
2027-
NARY_SHAPE_INFER(tosa::RescaleOp)
20282027
NARY_SHAPE_INFER(tosa::ReverseOp)
20292028
NARY_SHAPE_INFER(tosa::RsqrtOp)
20302029
NARY_SHAPE_INFER(tosa::SinOp)
@@ -2469,6 +2468,138 @@ LogicalResult TransposeConv2DOp::verify() {
24692468
return success();
24702469
}
24712470

2471+
LogicalResult RescaleOp::verify() {
2472+
auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
2473+
if (!inputType) {
2474+
emitOpError("expect shaped tensor for input, got ") << getInput().getType();
2475+
return failure();
2476+
}
2477+
2478+
auto inputElementType =
2479+
getStorageElementTypeOrSelf(inputType.getElementType());
2480+
if (!mlir::isa<IntegerType>(inputElementType)) {
2481+
emitOpError("expect input to have integer element type, got ")
2482+
<< inputElementType;
2483+
return failure();
2484+
}
2485+
2486+
auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
2487+
if (!outputType) {
2488+
emitOpError("expect shaped tensor for output, got ")
2489+
<< getOutput().getType();
2490+
return failure();
2491+
}
2492+
2493+
auto outputElementType =
2494+
getStorageElementTypeOrSelf(outputType.getElementType());
2495+
if (!mlir::isa<IntegerType>(outputElementType)) {
2496+
emitOpError("expect output to have integer element type, got ")
2497+
<< outputElementType;
2498+
return failure();
2499+
}
2500+
2501+
auto input_zp = getInputZpAttr().getInt();
2502+
if (input_zp != 0) {
2503+
// only int8/uint8 and uint16 input can have non-zero input_zp
2504+
if (!inputElementType.isInteger(8) &&
2505+
!(inputElementType.isInteger(16) && getInputUnsigned())) {
2506+
emitOpError("expect input_zp of 0, got ") << input_zp;
2507+
return failure();
2508+
}
2509+
// input_zp must be either 0 or 32768 for uint16 input
2510+
if (inputElementType.isInteger(16) && getInputUnsigned() &&
2511+
input_zp != 32768) {
2512+
emitOpError(
2513+
"expect input_zp of 0 or 32768 for unsigned int16 input, got ")
2514+
<< input_zp;
2515+
return failure();
2516+
}
2517+
}
2518+
2519+
auto output_zp = getOutputZpAttr().getInt();
2520+
if (output_zp != 0) {
2521+
// only int8/uint8 and uint16 output can have non-zero output_zp
2522+
if (!outputElementType.isInteger(8) &&
2523+
!(outputElementType.isInteger(16) && getOutputUnsigned())) {
2524+
emitOpError("expect output_zp of 0, got ") << output_zp;
2525+
return failure();
2526+
}
2527+
// output_zp must be either 0 or 32768 for uint16 output
2528+
if (outputElementType.isInteger(16) && getOutputUnsigned() &&
2529+
output_zp != 32768) {
2530+
emitOpError(
2531+
"expect output_zp of 0 or 32768 for unsigned int16 output, got ")
2532+
<< output_zp;
2533+
return failure();
2534+
}
2535+
}
2536+
2537+
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
2538+
if (!multiplierType) {
2539+
emitOpError("expect shaped tensor for multiplier, got ")
2540+
<< getMultiplier().getType();
2541+
return failure();
2542+
}
2543+
2544+
auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
2545+
if (!shiftType) {
2546+
emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
2547+
return failure();
2548+
}
2549+
2550+
// multiplier element type must be i32 for scale32 = true
2551+
if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
2552+
emitOpError("expect i32 element type for multiplier for scale32=true, got ")
2553+
<< multiplierType.getElementType();
2554+
return failure();
2555+
}
2556+
2557+
// multiplier element type must be i16 for scale32 = false
2558+
if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
2559+
emitOpError(
2560+
"expect i16 element type for multiplier for scale32=false, got ")
2561+
<< multiplierType.getElementType();
2562+
return failure();
2563+
}
2564+
2565+
// multiplier/shift must have shape = {numChannels},
2566+
// where numChannel is 1 if per_channel = false
2567+
// otherwise numChannel is dimension in input shape's last axis
2568+
int64_t numChannels = 1;
2569+
if (getPerChannel()) {
2570+
ArrayRef<int64_t> inputShape = inputType.getShape();
2571+
numChannels = inputShape[inputShape.size() - 1];
2572+
}
2573+
2574+
ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
2575+
// multiplier input has rank 1 by dialect definition
2576+
if (multiplierShape[0] != numChannels) {
2577+
emitOpError("expect shape of { ")
2578+
<< numChannels << " } for multiplier input, got { "
2579+
<< multiplierShape[0] << " }";
2580+
return failure();
2581+
}
2582+
2583+
ArrayRef<int64_t> shiftShape = shiftType.getShape();
2584+
// shift input has rank 1 by dialect definition
2585+
if (shiftShape[0] != numChannels) {
2586+
emitOpError("expect shape of { ")
2587+
<< numChannels << " } for shift input, got { " << shiftShape[0] << " }";
2588+
return failure();
2589+
}
2590+
2591+
return success();
2592+
}
2593+
2594+
LogicalResult RescaleOp::inferReturnTypeComponents(
2595+
MLIRContext *context, ::std::optional<Location> location,
2596+
RescaleOp::Adaptor adaptor,
2597+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2598+
ShapeAdaptor inputShape(adaptor.getInput().getType());
2599+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2600+
return success();
2601+
}
2602+
24722603
LogicalResult IfOp::inferReturnTypeComponents(
24732604
MLIRContext *context, ::std::optional<Location> location,
24742605
IfOp::Adaptor adaptor,

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
3333

3434
// CHECK-LABEL: @rescale_unsupported_type
3535
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
36+
%multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
37+
%shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
3638
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
37-
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
39+
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3840
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3941
}
4042

0 commit comments

Comments
 (0)