Skip to content

Commit cc72042

Browse files
[mlir][tosa] Make Convolution Zero Points Inputs (#122939)
The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately. Rename the "filter" argument of `tosa.transpose_conv2d` to weight to align with the TOSA specification. Remove the quantization_info attribute on the convolution operations. Co-authored-by: TatWai Chong <[email protected]>
1 parent 9ad4ebd commit cc72042

19 files changed

+576
-198
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
264264
"operands attr-dict `:` functional-type(operands, results)";
265265
}
266266

267+
// Base class for the TOSA convolution ops: a Tosa_InferShapedTypeOp whose
// trait list is the caller-supplied `traits` with "SameVariadicOperandSize"
// appended.
// The "SameVariadicOperandSize" trait allows us to pass optional arguments
268+
// for multiple zero points in convolution ops.
// NOTE(review): this trait presumably requires every optional operand of the
// op to be present or absent together (i.e. input_zp and weight_zp come as a
// pair) — confirm against the ODS documentation.
269+
class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
270+
: Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
271+
[SameVariadicOperandSize])> {
272+
}
273+
267274
#endif // TOSA_OP_BASE

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
1717
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1818
#include "mlir/Dialect/Traits.h"
19+
#include "mlir/IR/Matchers.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/IR/OpImplementation.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -29,6 +30,7 @@
2930
//===----------------------------------------------------------------------===//
3031

3132
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
33+
#include "mlir/Transforms/DialectConversion.h"
3234

3335
namespace mlir {
3436
class PatternRewriter;
@@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t);
152154
#define GET_OP_CLASSES
153155
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
154156

157+
namespace mlir {
158+
namespace tosa {
159+
160+
// Create a rank-1 const tensor for zero point of the source tensor.
// `srcElemType` is the element type of the source tensor and `zp` is the
// zero point value to materialize (defaults to 0).
// NOTE(review): the std::optional return presumably yields std::nullopt when
// no constant can be built for `srcElemType` — confirm against the definition.
161+
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
162+
Type srcElemType, int64_t zp = 0);
163+
164+
// Get zero point value from the attribute argument.
// The value is written to the out-parameter `zp`; callers must check the
// returned LogicalResult before reading it.
// NOTE(review): presumably `zp` is left untouched on failure — confirm against
// the definition.
165+
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
166+
167+
// Verify if zero point falls into valid range.
168+
template <typename T>
169+
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
170+
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
171+
!std::is_same_v<T, DepthwiseConv2DOp> &&
172+
!std::is_same_v<T, TransposeConv2DOp>) {
173+
return failure();
174+
}
175+
176+
if (!zpElemType.isIntOrFloat())
177+
return failure();
178+
179+
if (!zpElemType.isInteger(8) && zp != 0)
180+
return failure();
181+
182+
if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
183+
return failure();
184+
185+
if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
186+
return failure();
187+
188+
return success();
189+
}
190+
191+
// Helper type trait to determine if an operation is a tosa convolution.
192+
template <typename Op>
193+
struct IsTosaConv : std::false_type {};
194+
195+
template <>
196+
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
197+
template <>
198+
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
199+
template <>
200+
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
201+
template <>
202+
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
203+
204+
template <typename Op>
205+
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
206+
207+
// Bundles the two zero points of a TOSA convolution operation (one for the
// input, one for the weight) as named 64-bit integer fields.
struct ConvZpPair {
  std::int64_t inputZp;
  std::int64_t weightZp;

  // Stores the given input and weight zero points verbatim.
  ConvZpPair(std::int64_t inZp, std::int64_t wtZp)
      : inputZp(inZp), weightZp(wtZp) {}
};
215+
216+
// Helper function which attempts to extract the zero points from a TOSA
217+
// convolution by matching them against defining ops which should be tosa.const
218+
// operations.
219+
//
220+
// There are three possible results:
221+
// 1. Failed to extract the zero-points i.e. they should exist and don't or they
222+
// do exist but are invalid.
223+
// 2. Succeeded in extracting zero-points.
224+
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
225+
// convolution.
226+
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
227+
template <typename TosaConvOp>
228+
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
229+
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
230+
// Strictly speaking the base TOSA spec requires that for non int8 types
231+
// zero points must be zero. However, in the dialect these operands are
232+
// optional and only required for int8. They have no semantic meaning for
233+
// non-quantized types and can therefore be safely ignored. This is case 3.
234+
if (auto opElementTY =
235+
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
236+
!opElementTY.isInteger(8))
237+
return FailOrMaybeZP(std::nullopt);
238+
239+
// Now we know we should have a zero point check it is valid.
240+
if (!op.getInputZp())
241+
return rewriter.notifyMatchFailure(op, "missing input zero point");
242+
243+
// Helper to extract the zero point by matching its definition against a
244+
// constant.
245+
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
246+
ElementsAttr zpAttr;
247+
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
248+
return std::nullopt;
249+
250+
int64_t zp;
251+
if (tosa::getZeroPoint(zpAttr, zp).failed())
252+
return std::nullopt;
253+
254+
return std::make_optional(zp);
255+
};
256+
257+
auto maybeInputZp = extractZeroPoint(op.getInputZp());
258+
if (!maybeInputZp)
259+
return rewriter.notifyMatchFailure(op, "unable to extract input zp");
260+
261+
if (!op.getWeightZp())
262+
return rewriter.notifyMatchFailure(op, "missing weight zero point");
263+
264+
auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
265+
if (!maybeWeightZp)
266+
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
267+
268+
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
269+
}
270+
} // namespace tosa
271+
} // namespace mlir
272+
155273
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
9292
//===----------------------------------------------------------------------===//
9393
// Operator: conv2d
9494
//===----------------------------------------------------------------------===//
95-
def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
95+
def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
9696
let summary = "2D Convolution Operator";
9797

9898
let description = [{
@@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
104104
Tosa_Tensor4D:$input,
105105
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
106106
Tosa_Tensor1D:$bias,
107+
Optional<Tosa_ZeroPointTensor>:$input_zp,
108+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
107109
Tosa_IntArrayAttr4:$pad,
108110
Tosa_IntArrayAttr2:$stride,
109111
Tosa_IntArrayAttr2:$dilation,
110112
TypeAttrOf<Tosa_AccType>:$acc_type,
111-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
112113
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
113114
);
114115

@@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
123124
//===----------------------------------------------------------------------===//
124125
// Operator: conv3d
125126
//===----------------------------------------------------------------------===//
126-
def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
127+
def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
127128
let summary = "3D Convolution operator";
128129

129130
let description = [{
@@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
134135
Tosa_Tensor5D:$input,
135136
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
136137
Tosa_Tensor1D:$bias,
138+
Optional<Tosa_ZeroPointTensor>:$input_zp,
139+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
137140
Tosa_IntArrayAttr6:$pad,
138141
Tosa_IntArrayAttr3:$stride,
139142
Tosa_IntArrayAttr3:$dilation,
140143
TypeAttrOf<Tosa_AccType>:$acc_type,
141-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
142144
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
143145
);
144146

@@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
153155
//===----------------------------------------------------------------------===//
154156
// Operator: depthwise_conv2d
155157
//===----------------------------------------------------------------------===//
156-
def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
158+
def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
157159
let summary = "Depthwise 2D Convolution operator";
158160

159161
let description = [{
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165167
Tosa_Tensor4D:$input,
166168
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
167169
Tosa_Tensor1D:$bias,
170+
Optional<Tosa_ZeroPointTensor>:$input_zp,
171+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
168172
Tosa_IntArrayAttr4:$pad,
169173
Tosa_IntArrayAttr2:$stride,
170174
Tosa_IntArrayAttr2:$dilation,
171175
TypeAttrOf<Tosa_AccType>:$acc_type,
172-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
173176
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
174177
);
175178

@@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
338341
//===----------------------------------------------------------------------===//
339342
// Operator: transpose_conv2d
340343
//===----------------------------------------------------------------------===//
341-
def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
344+
def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
342345
let summary = "Transpose 2D Convolution operator.";
343346

344347
let description = [{
@@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348351

349352
let arguments = (ins
350353
Tosa_Tensor4D:$input,
351-
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
354+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
352355
Tosa_Tensor1D:$bias,
356+
Optional<Tosa_ZeroPointTensor>:$input_zp,
357+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
353358
Tosa_IntArrayAttr4:$out_pad,
354359
Tosa_IntArrayAttr2:$stride,
355360
Tosa_IntArrayAttr4:$out_shape,
356361
TypeAttrOf<Tosa_AccType>:$acc_type,
357-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
358362
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
359363
);
360364

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,9 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
288288
def Rank2TosaShape : TosaShapeOfRank<2>;
289289
def Rank4TosaShape : TosaShapeOfRank<4>;
290290

291+
// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
292+
// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
293+
// following def can be removed.
294+
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
295+
291296
#endif // TOSA_TYPES_BASE

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
3535
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3636
Value input, Value weight);
3737

38+
// NOTE(review): presumably builds the zero points for `input` and `weight` as
// constant values and returns them as an (input zero point, weight zero point)
// pair — confirm against the definition, which is not visible here.
std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
39+
Value weight);
40+
3841
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
3942
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
4043
Value a, Value b);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258258
DenseI64ArrayAttr padAttr = op.getPadAttr();
259259
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261-
bool isQuantized = op.getQuantizationInfo().has_value();
261+
262+
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
263+
if (llvm::failed(failureOrMaybeZps))
264+
return failure();
265+
266+
auto maybeZps = failureOrMaybeZps.value();
262267

263268
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
264269
return rewriter.notifyMatchFailure(
@@ -284,22 +289,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284289

285290
// Apply padding as necessary.
286291
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
287-
if (isQuantized) {
288-
auto quantizationInfo = *op.getQuantizationInfo();
289-
int64_t iZp = quantizationInfo.getInputZp();
290-
292+
if (maybeZps) {
291293
int64_t intMin =
292294
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
293295
.getSExtValue();
294296
int64_t intMax =
295297
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
296298
.getSExtValue();
297299

298-
if (iZp < intMin || iZp > intMax)
300+
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
299301
return rewriter.notifyMatchFailure(
300302
op, "tosa.conv op quantization has zp outside of input range");
301303

302-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
304+
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
303305
}
304306

305307
llvm::SmallVector<int64_t> pad;
@@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312314
// For 2D convolutions, we need to check if the target convolution op
313315
// wants a HWCF kernel layout.
314316
bool wantHwcf =
315-
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317+
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317319
if (wantHwcf) {
318320
// Transpose the kernel to match dimension ordering of the linalg
319321
// convolution operation.
@@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374376
Value broadcastBias =
375377
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
376378

377-
if (isQuantized) {
378-
auto quantizationInfo = *op.getQuantizationInfo();
379-
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
380-
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
379+
if (maybeZps) {
380+
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
381+
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
381382

382383
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
383384
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,39 +441,31 @@ class DepthwiseConvConverter
440441
/*inputSizeDims=*/{1, 2},
441442
/*kernelSizeDims=*/{0, 1}, rewriter);
442443

443-
bool isQuantized = op->hasAttr("quantization_info");
444-
IntegerAttr iZp;
445-
IntegerAttr kZp;
446-
if (isQuantized) {
447-
auto quantizationInfo =
448-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
449-
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
450-
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
451-
}
444+
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
445+
if (llvm::failed(failureOrMaybeZps))
446+
return failure();
447+
448+
auto maybeZps = failureOrMaybeZps.value();
452449

453450
auto weightShape = weightTy.getShape();
454451
auto resultShape = resultTy.getShape();
455452

456453
// Apply padding as necessary.
457454
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
458-
if (isQuantized) {
459-
auto quantizationInfo =
460-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
461-
int64_t iZp = quantizationInfo.getInputZp();
462-
455+
if (maybeZps) {
463456
int64_t intMin =
464457
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
465458
.getSExtValue();
466459
int64_t intMax =
467460
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
468461
.getSExtValue();
469462

470-
if (iZp < intMin || iZp > intMax)
463+
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
471464
return rewriter.notifyMatchFailure(
472465
op, "tosa.depthwise_conv op quantization has zp outside of input "
473466
"range");
474467

475-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
468+
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
476469
}
477470

478471
llvm::SmallVector<int64_t> pad;
@@ -512,7 +505,7 @@ class DepthwiseConvConverter
512505
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
513506
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
514507

515-
if (!isQuantized) {
508+
if (!maybeZps) {
516509
Value conv = rewriter
517510
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
518511
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,8 +532,10 @@ class DepthwiseConvConverter
539532
.getResult(0);
540533
rewriter.replaceOp(op, result);
541534
} else {
535+
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
536+
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
542537
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
543-
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
538+
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
544539
Value conv =
545540
rewriter
546541
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(

0 commit comments

Comments
 (0)