Skip to content

Commit 47ef5c4

Browse files
annuasd牛奕博
andauthored
[mlir][Bindings] Fix missing return value of functions and incorrect type hint in pyi. (#116731)
The zero points of UniformQuantizedPerAxisType should be List[int]. And there are two methods missing return value. Co-authored-by: 牛奕博 <[email protected]>
1 parent 27046ba commit 47ef5c4

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
250250
double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
251251
scales.push_back(scale);
252252
}
253+
return scales;
253254
},
254255
"The scales designate the difference between the real values "
255256
"corresponding to consecutive quantized values differing by 1. The ith "
@@ -265,6 +266,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
265266
mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
266267
zeroPoints.push_back(zeroPoint);
267268
}
269+
return zeroPoints;
268270
},
269271
"the storage values corresponding to the real value 0 in the affine "
270272
"equation. The ith zero point corresponds to the ith slice in the "

mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class UniformQuantizedPerAxisType(QuantizedType):
101101
def scales(self) -> list[float]: ...
102102

103103
@property
104-
def zero_points(self) -> list[float]: ...
104+
def zero_points(self) -> list[int]: ...
105105

106106
@property
107107
def quantized_dimension(self) -> int: ...

mlir/test/python/dialects/quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def test_uniform_per_axis_type():
108108
),
109109
)
110110

111-
# CHECK: scales: None
111+
# CHECK: scales: [200.0, 0.99872]
112112
print(f"scales: {per_axis.scales}")
113-
# CHECK: zero_points: None
113+
# CHECK: zero_points: [0, 120]
114114
print(f"zero_points: {per_axis.zero_points}")
115115
# CHECK: quantized dim: 1
116116
print(f"quantized dim: {per_axis.quantized_dimension}")

0 commit comments

Comments
 (0)