Skip to content

[mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs #129720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged — 1 commit merged on Mar 7, 2025.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2279,9 +2279,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>]> {
def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
let summary = "Tosa rescale operator";

let description = [{
Expand All @@ -2307,10 +2305,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,

let arguments = (ins
Tosa_Tensor:$input,
Tosa_1DInt16Or32Tensor:$multiplier,
Tosa_1DInt8Tensor:$shift,
I32Attr:$input_zp,
I32Attr:$output_zp,
DenseI32ArrayAttr:$multiplier,
DenseI8ArrayAttr:$shift,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
Expand All @@ -2327,6 +2325,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
Extension<[Tosa_EXT_INT16]>,
];

let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>

def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int64 : I<64>;

Expand All @@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
AnySignlessInteger]>;

def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
Tosa_Int64]>;
Tosa_Int64]>;

def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
Tosa_Int32]>;

//===----------------------------------------------------------------------===//
// Quantized Integer Types.
Expand Down Expand Up @@ -163,6 +167,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;

// 1D tensor of specific types
def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;

// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ namespace tosa {
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
int32_t &shift, int32_t scaleWidth);

/// Materialize a rank-1 `tosa.const` op holding the values in `vec` and
/// return its result Value.
///
/// \p IntType must be int8_t, int16_t, or int32_t; the element type of the
/// produced tensor is the signless integer type of the matching bit width.
/// \p vec must be non-empty (asserted).
template <typename IntType>
Value getConstTensorInt(OpBuilder &builder, Location loc,
                        ArrayRef<IntType> vec) {
  static_assert(
      std::is_same<IntType, int8_t>::value ||
          std::is_same<IntType, int16_t>::value ||
          std::is_same<IntType, int32_t>::value,
      "getConstTensorInt only supports int8_t, int16_t, and int32_t types.");

  const int64_t count = static_cast<int64_t>(vec.size());
  assert(count > 0 && "Vector must not be empty");
  // The tensor element width follows directly from the C++ element type.
  IntegerType elementType = builder.getIntegerType(sizeof(IntType) * 8);
  RankedTensorType constType = RankedTensorType::get({count}, elementType);
  DenseElementsAttr constAttr = DenseElementsAttr::get(constType, vec);
  auto constOp = builder.create<tosa::ConstOp>(loc, constType, constAttr);
  return constOp.getResult();
}

//// Builds ConvOpQuantizationAttr from input and weight.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);
Expand Down
22 changes: 19 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
ElementsAttr shift_elem;
DenseElementsAttr shift_elem;
if (!shift_val.getImpl() ||
!matchPattern(shift_val, m_Constant(&shift_elem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
Expand Down Expand Up @@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}

// The shift and multiplier values.
SmallVector<int32_t> multiplierValues(op.getMultiplier());
SmallVector<int8_t> shiftValues(op.getShift());
DenseElementsAttr shiftElems;
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant shift input values");

DenseElementsAttr multiplierElems;
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant multiplier input values");

llvm::SmallVector<int8_t> shiftValues =
llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
[](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
Expand Down
142 changes: 141 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2226,7 +2226,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
Expand Down Expand Up @@ -2676,6 +2675,147 @@ LogicalResult TransposeConv2DOp::verify() {
return success();
}

// Verifier for tosa.rescale. Checks, in order: input/output are shaped
// tensors with integer (storage) element types, zero-point attributes are
// legal for the element types, multiplier/shift are shaped tensors whose
// element width matches the scale32 mode, and (when ranks are known) that
// multiplier/shift have the {numChannels} shape implied by per_channel.
// Checks are ordered so the first failure emits its specific diagnostic.
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  // Quantized types are validated via their storage type (e.g. i8 for
  // !quant.uniform<i8:...>).
  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  auto input_zp = getInputZpAttr().getInt();
  if (input_zp != 0) {
    // only int8/uint8 and uint16 input can have non-zero input_zp
    if (!inputElementType.isInteger(8) &&
        !(inputElementType.isInteger(16) && getInputUnsigned())) {
      emitOpError("expect input_zp of 0, got ") << input_zp;
      return failure();
    }
    // input_zp must be either 0 or 32768 for uint16 input
    if (inputElementType.isInteger(16) && getInputUnsigned() &&
        input_zp != 32768) {
      emitOpError(
          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
          << input_zp;
      return failure();
    }
  }

  auto output_zp = getOutputZpAttr().getInt();
  if (output_zp != 0) {
    // only int8/uint8 and uint16 output can have non-zero output_zp
    if (!outputElementType.isInteger(8) &&
        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
      emitOpError("expect output_zp of 0, got ") << output_zp;
      return failure();
    }
    // output_zp must be either 0 or 32768 for uint16 output
    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
        output_zp != 32768) {
      emitOpError(
          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
          << output_zp;
      return failure();
    }
  }

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // multiplier element type must be i32 for scale32 = true
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // multiplier element type must be i16 for scale32 = false
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // Without a ranked input the channel count is unknown, so the remaining
  // shape checks cannot be performed.
  if (!inputType.hasRank())
    return success();

  // multiplier/shift must have shape = {numChannels},
  // where numChannel is 1 if per_channel = false
  // otherwise numChannel is dimension in input shape's last axis
  int64_t numChannels = 1;
  if (getPerChannel()) {
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  // multiplier input has rank 1 by dialect definition
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  // shift input has rank 1 by dialect definition
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
    return failure();
  }

  return success();
}

LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Rescale is elementwise over its input: the result shape mirrors the
  // input shape exactly.
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInput().getType())));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %

// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
  // Multiplier and shift are operands (not attributes), so they are
  // materialized as rank-1 tosa.const tensors feeding the rescale.
  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
  // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
  return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}

Expand Down
Loading