Skip to content

Commit c30014d

Browse files
psunn and Tai78641
committed
[TOSA] Convert RESCALE op multiplier and shift from attributes to inputs
This patch updates the TOSA RescaleOp by converting its multiplier and shift parameters from attributes to explicit inputs, aligning the op with the TOSA V1.0 specification. Additionally, this commit adds RescaleOp-specific implementations of inferReturnTypeComponents and verify functions. Co-authored-by: Tai Ly <[email protected]> Signed-off-by: Peng Sun <[email protected]> Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
1 parent e9de91e commit c30014d

File tree

14 files changed

+435
-36
lines changed

14 files changed

+435
-36
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,9 +2279,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
22792279
//===----------------------------------------------------------------------===//
22802280
// Operator: rescale
22812281
//===----------------------------------------------------------------------===//
2282-
def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
2283-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
2284-
["inferReturnTypeComponents"]>]> {
2282+
def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
22852283
let summary = "Tosa rescale operator";
22862284

22872285
let description = [{
@@ -2307,10 +2305,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23072305

23082306
let arguments = (ins
23092307
Tosa_Tensor:$input,
2308+
Tosa_1DInt16Or32Tensor:$multiplier,
2309+
Tosa_1DInt8Tensor:$shift,
23102310
I32Attr:$input_zp,
23112311
I32Attr:$output_zp,
2312-
DenseI32ArrayAttr:$multiplier,
2313-
DenseI8ArrayAttr:$shift,
23142312
BoolAttr:$scale32,
23152313
BoolAttr:$double_round,
23162314
BoolAttr:$per_channel,
@@ -2327,6 +2325,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23272325
Extension<[Tosa_EXT_INT16]>,
23282326
];
23292327

2328+
let hasVerifier = 1;
2329+
23302330
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
23312331
}
23322332

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
4343

4444
def Tosa_Int4 : I<4>;
4545
def Tosa_Int8 : I<8>;
46+
def Tosa_Int16 : I<16>;
4647
def Tosa_Int32 : I<32>;
4748
def Tosa_Int64 : I<64>;
4849

@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
5455
AnySignlessInteger]>;
5556

5657
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
57-
Tosa_Int64]>;
58+
Tosa_Int64]>;
59+
60+
def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
61+
Tosa_Int32]>;
5862

5963
//===----------------------------------------------------------------------===//
6064
// Quantized Integer Types.
@@ -163,6 +167,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
163167
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
164168
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
165169

170+
// 1D tensor of specific types
171+
def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
172+
def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
173+
166174
// Ranked tensors up to given rank.
167175
def Tosa_Tensor1Dto4D : AnyTypeOf<[
168176
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,26 @@ namespace tosa {
3131
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
3232
int32_t &shift, int32_t scaleWidth);
3333

34+
// Return a const value for array of IntType vec
35+
template <typename IntType>
36+
Value getConstTensorInt(OpBuilder &builder, Location loc,
37+
ArrayRef<IntType> vec) {
38+
static_assert(
39+
std::is_same<IntType, int8_t>::value ||
40+
std::is_same<IntType, int16_t>::value ||
41+
std::is_same<IntType, int32_t>::value,
42+
"getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
43+
44+
int64_t count = vec.size();
45+
assert(count > 0 && "Vector must not be empty");
46+
auto element_type = builder.getIntegerType(sizeof(IntType) * 8);
47+
mlir::RankedTensorType const_type =
48+
RankedTensorType::get({count}, element_type);
49+
mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
50+
auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
51+
return const_op.getResult();
52+
}
53+
3454
//// Builds ConvOpQuantizationAttr from input and weight.
3555
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3656
Value input, Value weight);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
138138
// tosa::MulOp
139139
if (isa<tosa::MulOp>(op)) {
140140
auto shift_val = cast<tosa::MulOp>(op).getShift();
141-
ElementsAttr shift_elem;
141+
DenseElementsAttr shift_elem;
142142
if (!shift_val.getImpl() ||
143143
!matchPattern(shift_val, m_Constant(&shift_elem))) {
144144
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13891389
}
13901390

13911391
// The shift and multiplier values.
1392-
SmallVector<int32_t> multiplierValues(op.getMultiplier());
1393-
SmallVector<int8_t> shiftValues(op.getShift());
1392+
DenseElementsAttr shiftElems;
1393+
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1394+
return rewriter.notifyMatchFailure(
1395+
op, "tosa.rescale requires constant shift input values");
1396+
1397+
DenseElementsAttr multiplierElems;
1398+
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1399+
return rewriter.notifyMatchFailure(
1400+
op, "tosa.rescale requires constant multiplier input values");
1401+
1402+
llvm::SmallVector<int8_t> shiftValues =
1403+
llvm::to_vector(shiftElems.getValues<int8_t>());
1404+
// explicit cast is required here
1405+
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1406+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1407+
[](IntegerAttr attr) -> int32_t {
1408+
return static_cast<int32_t>(attr.getInt());
1409+
}));
13941410

13951411
// If we shift by more than the bitwidth, this just sets to 0.
13961412
for (int i = 0, s = multiplierValues.size(); i < s; i++) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
22262226
NARY_SHAPE_INFER(tosa::NegateOp)
22272227
NARY_SHAPE_INFER(tosa::PowOp)
22282228
NARY_SHAPE_INFER(tosa::ReciprocalOp)
2229-
NARY_SHAPE_INFER(tosa::RescaleOp)
22302229
NARY_SHAPE_INFER(tosa::ReverseOp)
22312230
NARY_SHAPE_INFER(tosa::RsqrtOp)
22322231
NARY_SHAPE_INFER(tosa::SinOp)
@@ -2676,6 +2675,147 @@ LogicalResult TransposeConv2DOp::verify() {
26762675
return success();
26772676
}
26782677

2678+
LogicalResult RescaleOp::verify() {
2679+
auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
2680+
if (!inputType) {
2681+
emitOpError("expect shaped tensor for input, got ") << getInput().getType();
2682+
return failure();
2683+
}
2684+
2685+
auto inputElementType =
2686+
getStorageElementTypeOrSelf(inputType.getElementType());
2687+
if (!mlir::isa<IntegerType>(inputElementType)) {
2688+
emitOpError("expect input to have integer element type, got ")
2689+
<< inputElementType;
2690+
return failure();
2691+
}
2692+
2693+
auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
2694+
if (!outputType) {
2695+
emitOpError("expect shaped tensor for output, got ")
2696+
<< getOutput().getType();
2697+
return failure();
2698+
}
2699+
2700+
auto outputElementType =
2701+
getStorageElementTypeOrSelf(outputType.getElementType());
2702+
if (!mlir::isa<IntegerType>(outputElementType)) {
2703+
emitOpError("expect output to have integer element type, got ")
2704+
<< outputElementType;
2705+
return failure();
2706+
}
2707+
2708+
auto input_zp = getInputZpAttr().getInt();
2709+
if (input_zp != 0) {
2710+
// only int8/uint8 and uint16 input can have non-zero input_zp
2711+
if (!inputElementType.isInteger(8) &&
2712+
!(inputElementType.isInteger(16) && getInputUnsigned())) {
2713+
emitOpError("expect input_zp of 0, got ") << input_zp;
2714+
return failure();
2715+
}
2716+
// input_zp must be either 0 or 32768 for uint16 input
2717+
if (inputElementType.isInteger(16) && getInputUnsigned() &&
2718+
input_zp != 32768) {
2719+
emitOpError(
2720+
"expect input_zp of 0 or 32768 for unsigned int16 input, got ")
2721+
<< input_zp;
2722+
return failure();
2723+
}
2724+
}
2725+
2726+
auto output_zp = getOutputZpAttr().getInt();
2727+
if (output_zp != 0) {
2728+
// only int8/uint8 and uint16 output can have non-zero output_zp
2729+
if (!outputElementType.isInteger(8) &&
2730+
!(outputElementType.isInteger(16) && getOutputUnsigned())) {
2731+
emitOpError("expect output_zp of 0, got ") << output_zp;
2732+
return failure();
2733+
}
2734+
// output_zp must be either 0 or 32768 for uint16 output
2735+
if (outputElementType.isInteger(16) && getOutputUnsigned() &&
2736+
output_zp != 32768) {
2737+
emitOpError(
2738+
"expect output_zp of 0 or 32768 for unsigned int16 output, got ")
2739+
<< output_zp;
2740+
return failure();
2741+
}
2742+
}
2743+
2744+
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
2745+
if (!multiplierType) {
2746+
emitOpError("expect shaped tensor for multiplier, got ")
2747+
<< getMultiplier().getType();
2748+
return failure();
2749+
}
2750+
2751+
auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
2752+
if (!shiftType) {
2753+
emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
2754+
return failure();
2755+
}
2756+
2757+
// multiplier element type must be i32 for scale32 = true
2758+
if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
2759+
emitOpError("expect i32 element type for multiplier for scale32=true, got ")
2760+
<< multiplierType.getElementType();
2761+
return failure();
2762+
}
2763+
2764+
// multiplier element type must be i16 for scale32 = false
2765+
if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
2766+
emitOpError(
2767+
"expect i16 element type for multiplier for scale32=false, got ")
2768+
<< multiplierType.getElementType();
2769+
return failure();
2770+
}
2771+
2772+
if (!inputType.hasRank())
2773+
return success();
2774+
2775+
// multiplier/shift must have shape = {numChannels},
2776+
// where numChannel is 1 if per_channel = false
2777+
// otherwise numChannel is dimension in input shape's last axis
2778+
int64_t numChannels = 1;
2779+
if (getPerChannel()) {
2780+
numChannels = inputType.getDimSize(inputType.getRank() - 1);
2781+
}
2782+
2783+
if (!multiplierType.hasRank())
2784+
return success();
2785+
2786+
ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
2787+
// multiplier input has rank 1 by dialect definition
2788+
if (multiplierShape[0] != ShapedType::kDynamic &&
2789+
multiplierShape[0] != numChannels) {
2790+
emitOpError("expect shape of { ")
2791+
<< numChannels << " } for multiplier input, got { "
2792+
<< multiplierShape[0] << " }";
2793+
return failure();
2794+
}
2795+
2796+
if (!shiftType.hasRank())
2797+
return success();
2798+
2799+
ArrayRef<int64_t> shiftShape = shiftType.getShape();
2800+
// shift input has rank 1 by dialect definition
2801+
if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
2802+
emitOpError("expect shape of { ")
2803+
<< numChannels << " } for shift input, got { " << shiftShape[0] << " }";
2804+
return failure();
2805+
}
2806+
2807+
return success();
2808+
}
2809+
2810+
LogicalResult RescaleOp::inferReturnTypeComponents(
2811+
MLIRContext *context, ::std::optional<Location> location,
2812+
RescaleOp::Adaptor adaptor,
2813+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2814+
ShapeAdaptor inputShape(adaptor.getInput().getType());
2815+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2816+
return success();
2817+
}
2818+
26792819
LogicalResult IfOp::inferReturnTypeComponents(
26802820
MLIRContext *context, ::std::optional<Location> location,
26812821
IfOp::Adaptor adaptor,

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
3333

3434
// CHECK-LABEL: @rescale_unsupported_type
3535
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
36+
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
37+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
3638
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
37-
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
39+
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3840
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3941
}
4042

0 commit comments

Comments (0)