Skip to content

[mlir][tosa] Change zero points of convolution ops to required inputs #127679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 0 additions & 106 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,112 +168,6 @@ namespace tosa {
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Type srcElemType, int64_t zp = 0);

// Get zero point value from the attribute argument.
//
// `zpAttr` is the attribute to read and `zp` is the out-parameter that
// receives the value. Callers in this header read `zp` only after checking
// that the returned LogicalResult is not a failure.
// NOTE(review): the definition is not visible here; presumably failure is
// returned when `zpAttr` does not hold a usable zero point - confirm.
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);

// Check that `zp` is an acceptable zero point for an operand whose element
// type is `zpElemType`.
//
// `T` must be one of Conv2DOp, Conv3DOp, DepthwiseConv2DOp or
// TransposeConv2DOp; any other instantiation reports failure unconditionally.
//
// Returns failure() when:
//   - `zpElemType` is neither an integer nor a float type,
//   - `zpElemType.isInteger(8)` is false and `zp` is non-zero,
//   - `zpElemType.isSignedInteger(8)` holds and `zp` is outside [-128, 127],
//   - `zpElemType.isUnsignedInteger(8)` holds and `zp` is outside [0, 255].
// Returns success() otherwise.
//
// NOTE(review): a signless 8-bit integer type presumably satisfies neither
// isSignedInteger(8) nor isUnsignedInteger(8), in which case no range check
// is applied to it - confirm against the Type API.
template <typename T>
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
  constexpr bool isConvOp =
      std::is_same_v<T, Conv2DOp> || std::is_same_v<T, Conv3DOp> ||
      std::is_same_v<T, DepthwiseConv2DOp> ||
      std::is_same_v<T, TransposeConv2DOp>;
  if constexpr (!isConvOp)
    return failure();

  if (!zpElemType.isIntOrFloat())
    return failure();

  // Only element types accepted by isInteger(8) may carry a non-zero zero
  // point.
  const bool isEightBitInt = zpElemType.isInteger(8);
  if (!isEightBitInt && zp != 0)
    return failure();

  // Range limits depend on the signedness reported by the element type.
  const bool fitsSigned = zp >= -128 && zp <= 127;
  const bool fitsUnsigned = zp >= 0 && zp <= 255;
  if ((zpElemType.isSignedInteger(8) && !fitsSigned) ||
      (zpElemType.isUnsignedInteger(8) && !fitsUnsigned))
    return failure();

  return success();
}

// Helper type trait to determine if an operation is a tosa convolution.
template <typename Op>
struct IsTosaConv : std::false_type {};

template <>
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};

template <typename Op>
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;

// Bundles the two zero points of a TOSA convolution operation so they can be
// passed around together as named 64-bit integer fields.
struct ConvZpPair {
  std::int64_t inputZp;  // Zero point of the convolution input.
  std::int64_t weightZp; // Zero point of the convolution weight.

  // Both values are required; there is no default-constructed state.
  ConvZpPair(std::int64_t inputZeroPoint, std::int64_t weightZeroPoint)
      : inputZp(inputZeroPoint), weightZp(weightZeroPoint) {}
};

// Tries to read the input and weight zero points of a TOSA convolution by
// matching each zero-point operand against a constant definition (expected to
// be a tosa.const operation).
//
// The result distinguishes three situations:
//   - failure: a zero point that should exist is absent, or it exists but
//     could not be resolved to a value. The reason is reported through
//     `rewriter.notifyMatchFailure`.
//   - engaged optional: both zero points were extracted successfully.
//   - empty optional: zero points are meaningless for this op, i.e. the
//     convolution is not quantized.
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
template <typename TosaConvOp>
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
  // Resolves a zero-point operand to its integer value. Yields nullopt when
  // the operand is not defined by a constant or when tosa::getZeroPoint
  // reports failure for the matched attribute.
  auto constantZeroPoint = [](Value operand) -> std::optional<int64_t> {
    ElementsAttr attr;
    int64_t value;
    if (!matchPattern(operand, m_Constant(&attr)) ||
        tosa::getZeroPoint(attr, value).failed())
      return std::nullopt;
    return value;
  };

  // The base TOSA spec requires zero points to be zero for non-int8 types,
  // but in the dialect these operands are optional and only required for
  // int8. For any other element type they carry no meaning, so report the
  // "empty" result without inspecting them.
  auto inputElemTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();
  if (!inputElemTy.isInteger(8))
    return FailOrMaybeZP(std::nullopt);

  // From here on both zero points are mandatory. The input is handled
  // completely before the weight so the first problem found is the one
  // reported.
  auto inputZpOperand = op.getInputZp();
  if (!inputZpOperand)
    return rewriter.notifyMatchFailure(op, "missing input zero point");

  const std::optional<int64_t> inputZp = constantZeroPoint(inputZpOperand);
  if (!inputZp.has_value())
    return rewriter.notifyMatchFailure(op, "unable to extract input zp");

  auto weightZpOperand = op.getWeightZp();
  if (!weightZpOperand)
    return rewriter.notifyMatchFailure(op, "missing weight zero point");

  const std::optional<int64_t> weightZp = constantZeroPoint(weightZpOperand);
  if (!weightZp.has_value())
    return rewriter.notifyMatchFailure(op, "unable to extract weight zp");

  return std::make_optional<ConvZpPair>(*inputZp, *weightZp);
}
} // namespace tosa
} // namespace mlir

Expand Down
48 changes: 40 additions & 8 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand All @@ -134,6 +135,13 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];

let builders = [Tosa_ConvOpQuantInfoBuilder];
let hasVerifier = 1;
}
Expand All @@ -153,8 +161,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,

Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
Expand All @@ -171,6 +180,13 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];

let builders = [Tosa_ConvOpQuantInfoBuilder];
let hasVerifier = 1;
}
Expand All @@ -191,8 +207,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand All @@ -209,6 +226,13 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];

let builders = [Tosa_ConvOpQuantInfoBuilder];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -379,8 +403,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,

Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
Expand All @@ -397,6 +422,13 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];

let builders = [Tosa_TransConvOpQuantInfoBuilder];
let hasVerifier = 1;
}
Expand Down
63 changes: 41 additions & 22 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();
// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");

if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");

auto maybeZps = failureOrMaybeZps.value();
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
Expand All @@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (maybeZps) {
if (hasZp) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
if (inputZpVal < intMin || inputZpVal > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");

zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
}

llvm::SmallVector<int64_t> pad;
Expand All @@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
Expand Down Expand Up @@ -372,9 +382,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

if (maybeZps) {
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
if (hasZp) {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Expand Down Expand Up @@ -437,31 +447,40 @@ class DepthwiseConvConverter
/*inputSizeDims=*/{1, 2},
/*kernelSizeDims=*/{0, 1}, rewriter);

auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();
// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");

auto maybeZps = failureOrMaybeZps.value();
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");

bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (maybeZps) {
if (hasZp) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
if (inputZpVal < intMin || inputZpVal > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv op quantization has zp outside of input "
"range");

zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
}

llvm::SmallVector<int64_t> pad;
Expand Down Expand Up @@ -501,7 +520,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

if (!maybeZps) {
if (!hasZp) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
Expand All @@ -528,8 +547,8 @@ class DepthwiseConvConverter
.getResult(0);
rewriter.replaceOp(op, result);
} else {
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
Value conv =
Expand Down
Loading