1
1
# RUN: %PYTHON %s | FileCheck %s
2
2
3
+ import numpy as np
3
4
from mlir .ir import *
4
5
from mlir .dialects import quant
5
6
@@ -18,21 +19,27 @@ def test_type_hierarchy():
18
19
any = Type .parse ("!quant.any<i8<-8:7>:f32>" )
19
20
uniform = Type .parse ("!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )
20
21
per_axis = Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
22
+ sub_channel = Type .parse (
23
+ "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>" )
21
24
calibrated = Type .parse ("!quant.calibrated<f32<-0.998:1.2321>>" )
22
25
23
26
assert not quant .QuantizedType .isinstance (i8 )
24
27
assert quant .QuantizedType .isinstance (any )
25
28
assert quant .QuantizedType .isinstance (uniform )
26
29
assert quant .QuantizedType .isinstance (per_axis )
30
+ assert quant .QuantizedType .isinstance (sub_channel )
27
31
assert quant .QuantizedType .isinstance (calibrated )
28
32
29
33
assert quant .AnyQuantizedType .isinstance (any )
30
34
assert quant .UniformQuantizedType .isinstance (uniform )
31
35
assert quant .UniformQuantizedPerAxisType .isinstance (per_axis )
36
+ assert quant .UniformQuantizedSubChannelType .isinstance (sub_channel )
32
37
assert quant .CalibratedQuantizedType .isinstance (calibrated )
33
38
34
39
assert not quant .AnyQuantizedType .isinstance (uniform )
35
40
assert not quant .UniformQuantizedType .isinstance (per_axis )
41
+ assert not quant .UniformQuantizedType .isinstance (sub_channel )
42
+ assert not quant .UniformQuantizedPerAxisType .isinstance (sub_channel )
36
43
37
44
38
45
# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +128,45 @@ def test_uniform_per_axis_type():
121
128
assert per_axis == Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
122
129
123
130
131
+ # CHECK-LABEL: TEST: test_uniform_sub_channel_type
132
+ @run
133
+ def test_uniform_sub_channel_type ():
134
+ with Context ():
135
+ i8 = IntegerType .get_signless (8 )
136
+ f32 = F32Type .get ()
137
+ sub_channel = quant .UniformQuantizedSubChannelType .get (
138
+ quant .QuantizedType .FLAG_SIGNED ,
139
+ i8 ,
140
+ f32 ,
141
+ DenseElementsAttr .get (np .asarray (
142
+ [2.0 , 3.0 , 4.0 , 5.0 ], np .float32 ).reshape (2 , 2 )),
143
+ DenseElementsAttr .get (np .asarray (
144
+ [10 , 20 , 30 , 40 ], np .int8 ).reshape (2 , 2 )),
145
+ [0 , 1 ], [1 , 2 ],
146
+ storage_type_min = quant .QuantizedType .default_minimum_for_integer (
147
+ is_signed = True , integral_width = 8
148
+ ),
149
+ storage_type_max = quant .QuantizedType .default_maximum_for_integer (
150
+ is_signed = True , integral_width = 8
151
+ ),
152
+ )
153
+
154
+ # CHECK: quantized dimensions: [0, 1]
155
+ print (f"quantized dimensions: { sub_channel .quantized_dimensions } " )
156
+ # CHECK: block sizes: [1, 2]
157
+ print (f"block sizes: { sub_channel .block_sizes } " )
158
+ # CHECK: scales: {{\[}}[2. 3.]
159
+ # CHECK: [4. 5.]]
160
+ print (f"scales: { np .asarray (sub_channel .scales )} " )
161
+ # CHECK: zero-points: {{\[}}[10 20]
162
+ # CHECK: [30 40]]
163
+ print (f"zero-points: { np .asarray (sub_channel .zero_points )} " )
164
+ # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
165
+ print (sub_channel )
166
+ assert sub_channel == Type .parse (
167
+ "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>" )
168
+
169
+
124
170
# CHECK-LABEL: TEST: test_calibrated_type
125
171
@run
126
172
def test_calibrated_type ():
0 commit comments