Skip to content

Commit 5685def

Browse files
authored
[mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (#129720)
1 parent d1bd1c7 commit 5685def

File tree

14 files changed

+435
-36
lines changed

14 files changed

+435
-36
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,9 +2279,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
22792279
//===----------------------------------------------------------------------===//
22802280
// Operator: rescale
22812281
//===----------------------------------------------------------------------===//
2282-
def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
2283-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
2284-
["inferReturnTypeComponents"]>]> {
2282+
def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
22852283
let summary = "Tosa rescale operator";
22862284

22872285
let description = [{
@@ -2307,10 +2305,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23072305

23082306
let arguments = (ins
23092307
Tosa_Tensor:$input,
2308+
Tosa_1DInt16Or32Tensor:$multiplier,
2309+
Tosa_1DInt8Tensor:$shift,
23102310
I32Attr:$input_zp,
23112311
I32Attr:$output_zp,
2312-
DenseI32ArrayAttr:$multiplier,
2313-
DenseI8ArrayAttr:$shift,
23142312
BoolAttr:$scale32,
23152313
BoolAttr:$double_round,
23162314
BoolAttr:$per_channel,
@@ -2327,6 +2325,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
23272325
Extension<[Tosa_EXT_INT16]>,
23282326
];
23292327

2328+
let hasVerifier = 1;
2329+
23302330
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
23312331
}
23322332

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
4343

4444
def Tosa_Int4 : I<4>;
4545
def Tosa_Int8 : I<8>;
46+
def Tosa_Int16 : I<16>;
4647
def Tosa_Int32 : I<32>;
4748
def Tosa_Int64 : I<64>;
4849

@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
5455
AnySignlessInteger]>;
5556

5657
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
57-
Tosa_Int64]>;
58+
Tosa_Int64]>;
59+
60+
def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
61+
Tosa_Int32]>;
5862

5963
//===----------------------------------------------------------------------===//
6064
// Quantized Integer Types.
@@ -163,6 +167,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
163167
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
164168
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
165169

170+
// 1D tensor of specific types
171+
def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
172+
def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
173+
166174
// Ranked tensors up to given rank.
167175
def Tosa_Tensor1Dto4D : AnyTypeOf<[
168176
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,26 @@ namespace tosa {
3131
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
3232
int32_t &shift, int32_t scaleWidth);
3333

34+
// Return a const value for array of IntType vec
35+
template <typename IntType>
36+
Value getConstTensorInt(OpBuilder &builder, Location loc,
37+
ArrayRef<IntType> vec) {
38+
static_assert(
39+
std::is_same<IntType, int8_t>::value ||
40+
std::is_same<IntType, int16_t>::value ||
41+
std::is_same<IntType, int32_t>::value,
42+
"getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
43+
44+
int64_t count = vec.size();
45+
assert(count > 0 && "Vector must not be empty");
46+
auto element_type = builder.getIntegerType(sizeof(IntType) * 8);
47+
mlir::RankedTensorType const_type =
48+
RankedTensorType::get({count}, element_type);
49+
mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
50+
auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
51+
return const_op.getResult();
52+
}
53+
3454
//// Builds ConvOpQuantizationAttr from input and weight.
3555
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3656
Value input, Value weight);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
138138
// tosa::MulOp
139139
if (isa<tosa::MulOp>(op)) {
140140
auto shift_val = cast<tosa::MulOp>(op).getShift();
141-
ElementsAttr shift_elem;
141+
DenseElementsAttr shift_elem;
142142
if (!shift_val.getImpl() ||
143143
!matchPattern(shift_val, m_Constant(&shift_elem))) {
144144
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13891389
}
13901390

13911391
// The shift and multiplier values.
1392-
SmallVector<int32_t> multiplierValues(op.getMultiplier());
1393-
SmallVector<int8_t> shiftValues(op.getShift());
1392+
DenseElementsAttr shiftElems;
1393+
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1394+
return rewriter.notifyMatchFailure(
1395+
op, "tosa.rescale requires constant shift input values");
1396+
1397+
DenseElementsAttr multiplierElems;
1398+
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1399+
return rewriter.notifyMatchFailure(
1400+
op, "tosa.rescale requires constant multiplier input values");
1401+
1402+
llvm::SmallVector<int8_t> shiftValues =
1403+
llvm::to_vector(shiftElems.getValues<int8_t>());
1404+
// explicit cast is required here
1405+
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1406+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1407+
[](IntegerAttr attr) -> int32_t {
1408+
return static_cast<int32_t>(attr.getInt());
1409+
}));
13941410

13951411
// If we shift by more than the bitwidth, this just sets to 0.
13961412
for (int i = 0, s = multiplierValues.size(); i < s; i++) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
22262226
NARY_SHAPE_INFER(tosa::NegateOp)
22272227
NARY_SHAPE_INFER(tosa::PowOp)
22282228
NARY_SHAPE_INFER(tosa::ReciprocalOp)
2229-
NARY_SHAPE_INFER(tosa::RescaleOp)
22302229
NARY_SHAPE_INFER(tosa::ReverseOp)
22312230
NARY_SHAPE_INFER(tosa::RsqrtOp)
22322231
NARY_SHAPE_INFER(tosa::SinOp)
@@ -2676,6 +2675,147 @@ LogicalResult TransposeConv2DOp::verify() {
26762675
return success();
26772676
}
26782677

2678+
/// Verifier for tosa.rescale.
///
/// Checks, in order:
///  - input/output are shaped tensors with integer (or quantized-integer
///    storage) element types;
///  - input_zp / output_zp are only non-zero for i8, or for unsigned i16
///    where the sole legal non-zero value is 32768;
///  - multiplier/shift are shaped tensors whose multiplier element width
///    matches the scale32 flag (i32 when true, i16 when false);
///  - when ranks are known, multiplier/shift are length-1 (per-tensor) or
///    length-numChannels (per-channel) 1-D tensors.
/// Rank-dependent checks are skipped (success) for unranked operands.
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType)
    return emitOpError("expect shaped tensor for input, got ")
           << getInput().getType();

  // Quantized types are checked via their storage type (e.g. i8 under
  // !quant.uniform<i8:...>).
  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType))
    return emitOpError("expect input to have integer element type, got ")
           << inputElementType;

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType)
    return emitOpError("expect shaped tensor for output, got ")
           << getOutput().getType();

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType))
    return emitOpError("expect output to have integer element type, got ")
           << outputElementType;

  const int64_t inputZp = getInputZpAttr().getInt();
  if (inputZp != 0) {
    // Only int8/uint8 and uint16 input can have non-zero input_zp.
    if (!inputElementType.isInteger(8) &&
        !(inputElementType.isInteger(16) && getInputUnsigned()))
      return emitOpError("expect input_zp of 0, got ") << inputZp;
    // input_zp must be either 0 or 32768 for uint16 input.
    if (inputElementType.isInteger(16) && getInputUnsigned() &&
        inputZp != 32768)
      return emitOpError(
                 "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
             << inputZp;
  }

  const int64_t outputZp = getOutputZpAttr().getInt();
  if (outputZp != 0) {
    // Only int8/uint8 and uint16 output can have non-zero output_zp.
    if (!outputElementType.isInteger(8) &&
        !(outputElementType.isInteger(16) && getOutputUnsigned()))
      return emitOpError("expect output_zp of 0, got ") << outputZp;
    // output_zp must be either 0 or 32768 for uint16 output.
    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
        outputZp != 32768)
      return emitOpError("expect output_zp of 0 or 32768 for unsigned int16 "
                         "output, got ")
             << outputZp;
  }

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType)
    return emitOpError("expect shaped tensor for multiplier, got ")
           << getMultiplier().getType();

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType)
    return emitOpError("expect shaped tensor for shift, got ")
           << getShift().getType();

  // Multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32))
    return emitOpError("expect i32 element type for multiplier for "
                       "scale32=true, got ")
           << multiplierType.getElementType();

  // Multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16))
    return emitOpError("expect i16 element type for multiplier for "
                       "scale32=false, got ")
           << multiplierType.getElementType();

  // Without an input rank the channel count is unknown; nothing further
  // can be checked.
  if (!inputType.hasRank())
    return success();

  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
  // if per_channel = false, otherwise the size of the input's last axis.
  int64_t numChannels = 1;
  if (getPerChannel())
    numChannels = inputType.getDimSize(inputType.getRank() - 1);

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  // Multiplier input has rank 1 by dialect definition; only its extent
  // needs checking (dynamic extent is accepted).
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels)
    return emitOpError("expect shape of { ")
           << numChannels << " } for multiplier input, got { "
           << multiplierShape[0] << " }";

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  // Shift input has rank 1 by dialect definition.
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels)
    return emitOpError("expect shape of { ")
           << numChannels << " } for shift input, got { " << shiftShape[0]
           << " }";

  return success();
}
2809+
2810+
/// Shape inference for tosa.rescale: the result has the same shape as the
/// input operand; only the element type differs, which this hook does not
/// constrain.
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.emplace_back(
      ShapeAdaptor(adaptor.getInput().getType()));
  return success();
}
2818+
26792819
LogicalResult IfOp::inferReturnTypeComponents(
26802820
MLIRContext *context, ::std::optional<Location> location,
26812821
IfOp::Adaptor adaptor,

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
3333

3434
// CHECK-LABEL: @rescale_unsupported_type
3535
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
36+
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
37+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
3638
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
37-
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
39+
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3840
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
3941
}
4042

0 commit comments

Comments
 (0)