
Commit 81d7eef

Sub-channel quantized type implementation (#120172)
This is an implementation of [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694). To make the review process easier, the PR has been divided into the following commit labels:

1. **Add implementation for sub-channel type:** the class design for `UniformQuantizedSubChannelType`, printer/parser, and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
2. **Add lowering for sub-channel type:** lowering of the `quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python APIs:** we first define the C APIs and build the Python APIs on top of those.
4. **Add pass to normalize generic ....:** this pass normalizes sub-channel quantized types to per-tensor or per-axis types, if possible.

A design note:

- **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensors.** While it is possible to infer the quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this can lead to ambiguity and issues with round-tripping. Consider the example:

  ```
  tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
  ```

  The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we explicitly store the quantized dimensions.

Suggestions welcome!

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.
1 parent 7bda9ca commit 81d7eef
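The round-tripping ambiguity described in the design note can be illustrated with a small sketch (plain Python, not the MLIR implementation; the function names are made up for illustration):

```python
def scales_shape(tensor_shape, quantized_dims, block_sizes):
    """Shape of the scales tensor: one block per axis. Axes without an
    explicit block size get a single block spanning the whole axis."""
    blocks = dict(zip(quantized_dims, block_sizes))
    return [dim // blocks.get(axis, dim)
            for axis, dim in enumerate(tensor_shape)]

def infer_quantized_dims(tensor_shape, scales):
    """Naive inference: an axis looks 'quantized' only if it has >1 block."""
    return [axis for axis, n in enumerate(scales) if n > 1]

shape = [2, 4]
# Two different specifications...
a = scales_shape(shape, [0, 1], [2, 2])  # axis 0 has a degenerate block (= dim size)
b = scales_shape(shape, [1], [2])        # axis 0 is not quantized at all
# ...produce the same scales shape, so the explicit axes cannot be recovered.
print(a, b)                              # both are [1, 2]
print(infer_quantized_dims(shape, a))    # [1] -- axis 0 is lost
```

This is why the type stores the quantized dimensions explicitly instead of re-deriving them from the scales shape.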

File tree

25 files changed

+2091
-176
lines changed

mlir/include/mlir-c/Dialect/Quant.h

Lines changed: 41 additions & 0 deletions
```diff
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
 MLIR_CAPI_EXPORTED bool
 mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedSubChannel.
+MLIR_CAPI_EXPORTED bool
+mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
+
+/// Creates a UniformQuantizedSubChannelType with the given parameters.
+///
+/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
+/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` must each point
+/// to `blockSizeInfoLength` elements, describing respectively the
+/// quantization axes and the corresponding block sizes.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
+    intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
+    int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of block sizes provided in type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
+
+/// Returns the quantized dimension at the given position.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+                                                        intptr_t pos);
+
+/// Returns the block size at the given position.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
+
+/// Returns the scales of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
+
+/// Returns the zero-points of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
+
 //===---------------------------------------------------------------------===//
 // CalibratedQuantizedType
 //===---------------------------------------------------------------------===//
```
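As a rough mental model of what the new C API surface carries, the sketch below mirrors the getters in plain Python (illustrative only; this is not the MLIR Python binding, and the class name is made up):

```python
from dataclasses import dataclass

@dataclass
class SubChannelQuantizedType:
    """Illustrative stand-in for UniformQuantizedSubChannelType: parallel
    lists of (quantized dimension, block size), plus the scales and
    zero-points payloads."""
    quantized_dimensions: list  # one entry per quantized axis (int32 in the C API)
    block_sizes: list           # same length; block size along each axis (int64)
    scales: object = None       # stands in for the scales DenseElementsAttr
    zero_points: object = None  # stands in for the zero-points DenseElementsAttr

    def num_block_sizes(self):           # ~ ...GetNumBlockSizes
        return len(self.block_sizes)

    def quantized_dimension(self, pos):  # ~ ...GetQuantizedDimension
        return self.quantized_dimensions[pos]

    def block_size(self, pos):           # ~ ...GetBlockSize
        return self.block_sizes[pos]

t = SubChannelQuantizedType(quantized_dimensions=[0, 1], block_sizes=[1, 2])
print(t.num_block_sizes(), t.quantized_dimension(0), t.block_size(1))  # 2 0 2
```

Note the pairing: position `pos` in both arrays describes one axis/block-size pair, which is why a single `blockSizeInfoLength` governs both in the C constructor.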

mlir/include/mlir/Dialect/Quant/IR/QuantBase.td

Lines changed: 183 additions & 9 deletions
````diff
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
     encodes the necessary information for (lossy) round-trip conversion between
     an expressed and a stored value.
 
-    The `quant.uniform` type has two variants: per-layer quantization and
-    per-channel (or per-axis) quantization. In per-layer quantization, the
-    quantization information affects an entire tensor uniformly. Conversely, in
-    per-channel quantization, the data type encodes the specific tensor axis
-    that serves as the channel and includes quantization information for each
-    individual channel within the tensor. Below are the specific syntactic and
-    semantic considerations for each modality.
+    The `quant.uniform` type has three variants: per-layer quantization,
+    per-channel (or per-axis) quantization, and sub-channel (or blockwise)
+    quantization. In per-layer quantization, the quantization information
+    affects an entire tensor uniformly. Conversely, in per-channel
+    quantization, the data type encodes the specific tensor axis that serves
+    as the channel and includes quantization information for each individual
+    channel within the tensor. Sub-channel quantization is a generalization
+    of per-tensor and per-channel quantization, where the quantization
+    parameters are defined for blocks of elements along one or more
+    dimensions of the tensor. Below are the specific syntactic and semantic
+    considerations for each modality.
 
 
     ### Per-layer quantization
````
````diff
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
     ```
     // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
     // floats. Dimension 1 of the tensor acts as the channel dimension. Its
-    // size 3 matches the number of provided scale values. Tensor elemenets at
+    // size 3 matches the number of provided scale values. Tensor elements at
     // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
     // 5.0, respectively.
     tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
````
````diff
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
     tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
     ```
 
+    ### Sub-channel quantization
+
+    Sub-channel quantization, also known as blockwise quantization, provides
+    finer-grained control than per-tensor or per-channel quantization. It
+    divides a tensor into blocks of elements, each with its own quantization
+    parameters (scale and zero point). This is particularly useful when
+    different regions of a tensor exhibit distinct value ranges.
+
+    The `!quant.uniform` type represents sub-channel quantization with the
+    following syntax:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:` blockSizeInfo
+      scaleZeroTensor `>`
+
+    blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
+    axisBlock ::= axis `:` blockSize
+    scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
+    scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
+    scaleZeroList ::= scaleZero (`,` scaleZero)*
+    scaleZero ::= scale (`:` zeroPoint)?
+    ```
+
+    The `blockSize` field specifies the size of the blocks along dimension
+    `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
+    quantization parameters for a particular block. Specifically, the tensor
+    element at position [i0...iN] uses
+    `scaleZeroTensor[i0/blockSize0 ... iN/blockSizeN].scale` and
+    `scaleZeroTensor[i0/blockSize0 ... iN/blockSizeN].zeroPoint` as its
+    scale and zero point, respectively.
+
+    Here are some examples:
+
+    ```
+    // A 3x4 tensor of i8 values representing f32 values, quantized
+    // along axis-0 and axis-1 with block sizes 1 and 2,
+    // respectively. As a result, the shape of the scales (or zero-points) will
+    // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
+    // blocks along each axis. Tensor elements at positions
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    // [2][0] and [2][1] use scale `s20` and zero point `z20`,
+    // [2][2] and [2][3] use scale `s21` and zero point `z21`.
+    tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10, s11:z11}, {s20:z20, s21:z21}}>>
+
+    // A 2D dynamically sized tensor contains u16 values
+    // representing f32 values. Since the shape of the quantization
+    // parameters (i.e. scales and zero-points) is given as [2,2] and
+    // the block sizes are given as [1,2], the shape of the tensor is expected
+    // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`.
+    tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10, s11:z11}}>>
+    ```
 
     ## Per-axis quantization integrity
 
````
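The indexing rule above, where the element at position [i0...iN] uses the scale and zero point at [i0/blockSize0 ... iN/blockSizeN], can be sketched in a few lines (plain Python, illustrative; `block_index` is a made-up name):

```python
def block_index(position, block_sizes):
    """Map an element position to the index of its (scale, zero point)
    block using floor division along each axis."""
    return tuple(i // b for i, b in zip(position, block_sizes))

# The 3x4 example above: axis 0 has block size 1, axis 1 has block size 2,
# so the scales (and zero-points) tensor has shape [3,4]/[1,2] = [3,2].
block_sizes = [1, 2]
assert block_index((0, 0), block_sizes) == (0, 0)  # uses s00:z00
assert block_index((0, 3), block_sizes) == (0, 1)  # uses s01:z01
assert block_index((2, 2), block_sizes) == (2, 1)  # uses s21:z21
```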
````diff
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
     respected in any context in which the `!quant.uniform` data type is used,
     such as the header of a `func.func` op, or the input of an arithmetic
     operation.
-
+
     - A quantized type with per-channel quantization information must be the
       element type of a tensor container type, and may not occur directly as
       the data type of a scalar value.
````
````diff
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
     // Correct. The quantized type now includes 3 scale values, matching the
     // size of dimension 1 of the result tensor.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+    ## Sub-channel quantization integrity
+
+    When type `!quant.uniform` contains sub-channel quantization information,
+    the following rules are enforced. For efficiency, these rules are actively
+    enforced by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+
+    - A quantized type with sub-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+    // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+    // in a `tensor` type.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+      tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The tensor containing the sub-channel quantized type must be ranked.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for an
+    // unranked tensor type.
+    %result = quant.qcast %input : tensor<*xf32> to
+      tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The axis for which a block size is specified must be valid for a tensor
+      of the given rank. Block sizes may be specified for a subset of axes;
+      any unspecified block size for an axis i defaults to the tensor
+      dimension size of that axis (shape(tensor)[i]).
+
+    ```
+    // Incorrect. The block size is specified for axis 2, which is greater
+    // than the rank of the tensor.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+      tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Incorrect. The block size is specified for a negative axis.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+      tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Correct. The block size for axis 1 is skipped and is assumed to be 2,
+    // the dimension size of the tensor at axis 1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+    // Correct. The block sizes for all axes are skipped, making the
+    // sub-channel type essentially a per-tensor type.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+    ```
+
+    - The block size for a particular axis must be a positive integer and
+      must not exceed the dimension size of the tensor along that axis.
+
+    ```
+    // Incorrect. The block size for axis 0 is -1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+    // Incorrect. The block size for axis 0 is 8, which is greater than the
+    // dimension size of the tensor at axis 0 (which is 6).
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(tensor) % blockSizes = 0, where blockSizes = [the block size for
+      axis i, for i in [0, 1, ..., rank(tensor)-1]].
+
+    ```
+    // Incorrect. The block size for axis 0 is 4, the corresponding
+    // dimension size is 6, and 6 % 4 != 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3, making 6 % 3 = 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+    ```
+    // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+    // shape(scales) is [1,2], which is not equal to [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+    // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+    // shape(scales) equals [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+      tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
     ```
   }];
   let cppNamespace = "::mlir::quant";
````
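The integrity rules above can be collected into a small checker (a plain-Python sketch, not the dialect verifier; `verify_sub_channel` is a made-up name, and unspecified axes default to a single block spanning the whole dimension):

```python
def verify_sub_channel(tensor_shape, block_size_info, scales_shape):
    """Check the sub-channel integrity rules for a ranked tensor.
    block_size_info: dict {axis: block_size}; axes without an entry default
    to the full dimension size (i.e. one block)."""
    rank = len(tensor_shape)
    for axis, block in block_size_info.items():
        if not 0 <= axis < rank:
            return False  # axis must be valid for the tensor's rank
        if block <= 0 or tensor_shape[axis] % block != 0:
            return False  # block size must be positive and divide the dim size
    # shape(scales) must equal shape(tensor) / blockSizes, elementwise.
    blocks = [tensor_shape[a] // block_size_info.get(a, tensor_shape[a])
              for a in range(rank)]
    return blocks == list(scales_shape)

# Mirrors the tensor<6x2x...> examples above:
print(verify_sub_channel([6, 2], {0: 3}, [2, 1]))  # True
print(verify_sub_channel([6, 2], {0: 4}, [1, 2]))  # False: 6 % 4 != 0
print(verify_sub_channel([6, 2], {2: 1}, [1, 2]))  # False: axis 2 out of range
```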

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 21 additions & 9 deletions
```diff
@@ -13,6 +13,7 @@
 #ifndef QUANT_BYTECODE
 #define QUANT_BYTECODE
 
+include "mlir/IR/BuiltinDialectBytecode.td"
 include "mlir/IR/BytecodeBase.td"
 
 def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
   }];
 }
 
+def UniformQuantizedSubChannelType
+    : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+          SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+          Array<SignedVarIntList>:$quantizedDimensions,
+          Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+          DenseElementsAttr:$zeroPoints)> {
+  // Note: builder order differs from bytecode.
+  let cBuilder = [{
+    get<$_resultType>(context, flags, storageType, expressedType, scales,
+        zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+        [](int64_t dim) { return static_cast<int32_t>(dim); })), blockSizes,
+        storageTypeMin, storageTypeMax)
+  }];
+}
+
 /// This enum contains marker codes used to indicate which attribute is
 /// currently being decoded, and how it should be decoded. The order of these
 /// codes should generally be unchanged, as any changes will inevitably break
 /// compatibility with older bytecode.
 
 def QuantDialectTypes : DialectTypes<"Quant"> {
-  let elems = [
-    ReservedOrDead,
-    AnyQuantizedType,
-    AnyQuantizedTypeWithExpressedType,
-    CalibratedQuantizedType,
-    UniformQuantizedType,
-    UniformQuantizedPerAxisType
-  ];
+  let elems = [ReservedOrDead, AnyQuantizedType,
+               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+               UniformQuantizedType, UniformQuantizedPerAxisType,
+               UniformQuantizedSubChannelType];
 }
 
-#endif // QUANT_BYTECODE
+#endif // QUANT_BYTECODE
```
