Skip to content

Commit 6d33d7c

Browse files
committed
Adding C API and Python APIs
1 parent 5de1147 commit 6d33d7c

File tree

5 files changed

+71
-1
lines changed

5 files changed

+71
-1
lines changed

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir-c/BuiltinAttributes.h"
910
#include "mlir-c/Dialect/Quant.h"
1011
#include "mlir-c/IR.h"
1112
#include "mlir/Bindings/Python/PybindAdaptors.h"

mlir/lib/CAPI/Dialect/Quant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Dialect/Quant.h"
10+
#include "mlir-c/BuiltinAttributes.h"
1011
#include "mlir/CAPI/Registration.h"
1112
#include "mlir/Dialect/Quant/IR/Quant.h"
1213
#include "mlir/Dialect/Quant/IR/QuantTypes.h"

mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55

6-
from mlir.ir import Type
6+
from mlir.ir import DenseElementsAttr, Type
77

88
__all__ = [
99
"QuantizedType",
@@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType):
109109
@property
110110
def is_fixed_point(self) -> bool: ...
111111

112+
class UniformQuantizedSubChannelType(QuantizedType):
# Type stub for the sub-channel uniform quantized type; a QuantizedType
# subclass exposing a `get` factory and four read-only properties.
113+
114+
@classmethod
115+
# Factory. Takes the flags, storage and expressed types, then the same four
# quantities surfaced by the properties below (scales, zero points,
# quantized dimensions, block sizes), then the storage min/max bounds.
# No return annotation is declared in this stub.
# NOTE(review): quantized_dimensions and block_sizes are presumably paired
# element-wise (one block size per quantized dimension) -- confirm against
# the binding implementation.
def get(cls, flags: int, storage_type: Type, expressed_type: Type,
116+
scales: DenseElementsAttr, zero_points: DenseElementsAttr,
117+
quantized_dimensions: list[int], block_sizes: list[int],
118+
storage_type_min: int, storage_type_max: int):
119+
...
120+
121+
@property
122+
# Quantized dimensions, as a list of ints.
def quantized_dimensions(self) -> list[int]: ...
123+
124+
@property
125+
# Block sizes, as a list of ints.
def block_sizes(self) -> list[int]: ...
126+
127+
@property
128+
# Scales, returned as a DenseElementsAttr rather than a Python list.
def scales(self) -> DenseElementsAttr: ...
129+
130+
@property
131+
# Zero points, returned as a DenseElementsAttr rather than a Python list.
def zero_points(self) -> DenseElementsAttr: ...
112132

113133
def CalibratedQuantizedType(QuantizedType):
114134

mlir/test/CAPI/quant.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
1111

1212
#include "mlir-c/Dialect/Quant.h"
13+
#include "mlir-c/BuiltinAttributes.h"
1314
#include "mlir-c/BuiltinTypes.h"
1415
#include "mlir-c/IR.h"
1516

@@ -357,6 +358,7 @@ int main(void) {
357358
testAnyQuantizedType(ctx);
358359
testUniformType(ctx);
359360
testUniformPerAxisType(ctx);
361+
testUniformSubChannelType(ctx);
360362
testCalibratedType(ctx);
361363
mlirContextDestroy(ctx);
362364
return EXIT_SUCCESS;

mlir/test/python/dialects/quant.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

3+
import numpy as np
34
from mlir.ir import *
45
from mlir.dialects import quant
56

@@ -18,21 +19,27 @@ def test_type_hierarchy():
1819
any = Type.parse("!quant.any<i8<-8:7>:f32>")
1920
uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
2021
per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
22+
sub_channel = Type.parse(
23+
"!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>")
2124
calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
2225

2326
assert not quant.QuantizedType.isinstance(i8)
2427
assert quant.QuantizedType.isinstance(any)
2528
assert quant.QuantizedType.isinstance(uniform)
2629
assert quant.QuantizedType.isinstance(per_axis)
30+
assert quant.QuantizedType.isinstance(sub_channel)
2731
assert quant.QuantizedType.isinstance(calibrated)
2832

2933
assert quant.AnyQuantizedType.isinstance(any)
3034
assert quant.UniformQuantizedType.isinstance(uniform)
3135
assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
36+
assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
3237
assert quant.CalibratedQuantizedType.isinstance(calibrated)
3338

3439
assert not quant.AnyQuantizedType.isinstance(uniform)
3540
assert not quant.UniformQuantizedType.isinstance(per_axis)
41+
assert not quant.UniformQuantizedType.isinstance(sub_channel)
42+
assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
3643

3744

3845
# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +128,45 @@ def test_uniform_per_axis_type():
121128
assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
122129

123130

131+
# CHECK-LABEL: TEST: test_uniform_sub_channel_type
132+
@run
133+
# Builds a sub-channel quantized type through the Python `get` factory, prints
# each accessor for the expected-output directives below, and finally asserts
# equality with the same type parsed from its textual form.
def test_uniform_sub_channel_type():
134+
with Context():
135+
i8 = IntegerType.get_signless(8)
136+
f32 = F32Type.get()
137+
sub_channel = quant.UniformQuantizedSubChannelType.get(
138+
quant.QuantizedType.FLAG_SIGNED,
139+
i8,
140+
f32,
141+
# Scales: a 2x2 float32 array wrapped in a DenseElementsAttr.
DenseElementsAttr.get(np.asarray(
142+
[2.0, 3.0, 4.0, 5.0], np.float32).reshape(2, 2)),
143+
# Zero points: a 2x2 int8 array wrapped in a DenseElementsAttr.
DenseElementsAttr.get(np.asarray(
144+
[10, 20, 30, 40], np.int8).reshape(2, 2)),
145+
# Echoed back below as quantized_dimensions and block_sizes respectively.
[0, 1], [1, 2],
146+
# Storage bounds: the defaults for a signed 8-bit integer.
storage_type_min=quant.QuantizedType.default_minimum_for_integer(
147+
is_signed=True, integral_width=8
148+
),
149+
storage_type_max=quant.QuantizedType.default_maximum_for_integer(
150+
is_signed=True, integral_width=8
151+
),
152+
)
153+
154+
# CHECK: quantized dimensions: [0, 1]
155+
print(f"quantized dimensions: {sub_channel.quantized_dimensions}")
156+
# CHECK: block sizes: [1, 2]
157+
print(f"block sizes: {sub_channel.block_sizes}")
158+
# CHECK: scales: {{\[}}[2. 3.]
159+
# CHECK: [4. 5.]]
160+
# The attribute is converted with np.asarray so the 2x2 contents are printed.
print(f"scales: {np.asarray(sub_channel.scales)}")
161+
# CHECK: zero-points: {{\[}}[10 20]
162+
# CHECK: [30 40]]
163+
print(f"zero-points: {np.asarray(sub_channel.zero_points)}")
164+
# CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
165+
print(sub_channel)
166+
# The constructed type must compare equal to the one parsed from text.
assert sub_channel == Type.parse(
167+
"!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>")
168+
169+
124170
# CHECK-LABEL: TEST: test_calibrated_type
125171
@run
126172
def test_calibrated_type():

0 commit comments

Comments
 (0)