Skip to content

Commit 82f3cbc

Browse files
authored
[MLIR][Python] Added a base class to all builtin floating point types (#81720)
This allows to * check if a given ir.Type is a floating point type via isinstance() or issubclass() * get the bitwidth of a floating point type See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.
1 parent 0c8b594 commit 82f3cbc

File tree

5 files changed

+95
-20
lines changed

5 files changed

+95
-20
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
7373
// Floating-point types.
7474
//===----------------------------------------------------------------------===//
7575

76+
/// Checks whether the given type is a floating-point type.
77+
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
78+
79+
/// Returns the bitwidth of a floating-point type.
80+
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
81+
7682
/// Returns the typeID of an Float8E5M2 type.
7783
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
7884

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
109109
}
110110
};
111111

112+
class PyFloatType : public PyConcreteType<PyFloatType> {
113+
public:
114+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
115+
static constexpr const char *pyClassName = "FloatType";
116+
using PyConcreteType::PyConcreteType;
117+
118+
static void bindDerived(ClassTy &c) {
119+
c.def_property_readonly(
120+
"width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
121+
"Returns the width of the floating-point type");
122+
}
123+
};
124+
112125
/// Floating Point Type subclass - Float8E4M3FNType.
113-
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
126+
class PyFloat8E4M3FNType
127+
: public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
114128
public:
115129
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
116130
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
130144
};
131145

132146
/// Floating Point Type subclass - Float8M5E2Type.
133-
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
147+
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
134148
public:
135149
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
136150
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
150164
};
151165

152166
/// Floating Point Type subclass - Float8E4M3FNUZ.
153-
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
167+
class PyFloat8E4M3FNUZType
168+
: public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
154169
public:
155170
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
156171
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
170185
};
171186

172187
/// Floating Point Type subclass - Float8E4M3B11FNUZ.
173-
class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
188+
class PyFloat8E4M3B11FNUZType
189+
: public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
174190
public:
175191
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
176192
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
190206
};
191207

192208
/// Floating Point Type subclass - Float8E5M2FNUZ.
193-
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
209+
class PyFloat8E5M2FNUZType
210+
: public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
194211
public:
195212
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
196213
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
210227
};
211228

212229
/// Floating Point Type subclass - BF16Type.
213-
class PyBF16Type : public PyConcreteType<PyBF16Type> {
230+
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
214231
public:
215232
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
216233
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
230247
};
231248

232249
/// Floating Point Type subclass - F16Type.
233-
class PyF16Type : public PyConcreteType<PyF16Type> {
250+
class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
234251
public:
235252
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
236253
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
250267
};
251268

252269
/// Floating Point Type subclass - TF32Type.
253-
class PyTF32Type : public PyConcreteType<PyTF32Type> {
270+
class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
254271
public:
255272
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
256273
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
270287
};
271288

272289
/// Floating Point Type subclass - F32Type.
273-
class PyF32Type : public PyConcreteType<PyF32Type> {
290+
class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
274291
public:
275292
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
276293
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
290307
};
291308

292309
/// Floating Point Type subclass - F64Type.
293-
class PyF64Type : public PyConcreteType<PyF64Type> {
310+
class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
294311
public:
295312
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
296313
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
819836

820837
void mlir::python::populateIRTypes(py::module &m) {
821838
PyIntegerType::bind(m);
839+
PyFloatType::bind(m);
822840
PyIndexType::bind(m);
823841
PyFloat8E4M3FNType::bind(m);
824842
PyFloat8E5M2Type::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
7878
// Floating-point types.
7979
//===----------------------------------------------------------------------===//
8080

81+
bool mlirTypeIsAFloat(MlirType type) {
82+
return llvm::isa<FloatType>(unwrap(type));
83+
}
84+
85+
unsigned mlirFloatTypeGetWidth(MlirType type) {
86+
return llvm::cast<FloatType>(unwrap(type)).getWidth();
87+
}
88+
8189
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
8290
return wrap(Float8E5M2Type::getTypeID());
8391
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
14421442
@property
14431443
def typeid(self) -> TypeID: ...
14441444

1445-
class F16Type(Type):
1445+
class FloatType(Type):
1446+
@staticmethod
1447+
def isinstance(other: Type) -> bool: ...
1448+
def __init__(self, cast_from_type: Type) -> None: ...
1449+
@property
1450+
def width(self) -> int:
1451+
"""
1452+
Returns the width of the floating-point type.
1453+
"""
1454+
1455+
class F16Type(FloatType):
14461456
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
14471457
@staticmethod
14481458
def get(context: Optional[Context] = None) -> F16Type:
@@ -1455,7 +1465,7 @@ class F16Type(Type):
14551465
@property
14561466
def typeid(self) -> TypeID: ...
14571467

1458-
class F32Type(Type):
1468+
class F32Type(FloatType):
14591469
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
14601470
@staticmethod
14611471
def get(context: Optional[Context] = None) -> F32Type:
@@ -1468,7 +1478,7 @@ class F32Type(Type):
14681478
@property
14691479
def typeid(self) -> TypeID: ...
14701480

1471-
class F64Type(Type):
1481+
class F64Type(FloatType):
14721482
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
14731483
@staticmethod
14741484
def get(context: Optional[Context] = None) -> F64Type:
@@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
15021512
Returns the value of the FlatSymbolRef attribute as a string
15031513
"""
15041514

1505-
class Float8E4M3B11FNUZType(Type):
1515+
class Float8E4M3B11FNUZType(FloatType):
15061516
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
15071517
@staticmethod
15081518
def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
@@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
15151525
@property
15161526
def typeid(self) -> TypeID: ...
15171527

1518-
class Float8E4M3FNType(Type):
1528+
class Float8E4M3FNType(FloatType):
15191529
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
15201530
@staticmethod
15211531
def get(context: Optional[Context] = None) -> Float8E4M3FNType:
@@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
15281538
@property
15291539
def typeid(self) -> TypeID: ...
15301540

1531-
class Float8E4M3FNUZType(Type):
1541+
class Float8E4M3FNUZType(FloatType):
15321542
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
15331543
@staticmethod
15341544
def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
@@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
15411551
@property
15421552
def typeid(self) -> TypeID: ...
15431553

1544-
class Float8E5M2FNUZType(Type):
1554+
class Float8E5M2FNUZType(FloatType):
15451555
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
15461556
@staticmethod
15471557
def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
@@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
15541564
@property
15551565
def typeid(self) -> TypeID: ...
15561566

1557-
class Float8E5M2Type(Type):
1567+
class Float8E5M2Type(FloatType):
15581568
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
15591569
@staticmethod
15601570
def get(context: Optional[Context] = None) -> Float8E5M2Type:
@@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
16011611
Returns the value of the float attribute
16021612
"""
16031613

1604-
class FloatTF32Type(Type):
1614+
class FloatTF32Type(FloatType):
16051615
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
16061616
@staticmethod
16071617
def get(context: Optional[Context] = None) -> FloatTF32Type:

mlir/test/python/ir/builtin_types.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,38 @@ def testTypeIsInstance():
100100
print(IntegerType.isinstance(t1))
101101
# CHECK: False
102102
print(F32Type.isinstance(t1))
103+
# CHECK: False
104+
print(FloatType.isinstance(t1))
103105
# CHECK: True
104106
print(F32Type.isinstance(t2))
107+
# CHECK: True
108+
print(FloatType.isinstance(t2))
109+
110+
111+
# CHECK-LABEL: TEST: testFloatTypeSubclasses
112+
@run
113+
def testFloatTypeSubclasses():
114+
ctx = Context()
115+
# CHECK: True
116+
print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
117+
# CHECK: True
118+
print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
119+
# CHECK: True
120+
print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
121+
# CHECK: True
122+
print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
123+
# CHECK: True
124+
print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
125+
# CHECK: True
126+
print(isinstance(Type.parse("f16", ctx), FloatType))
127+
# CHECK: True
128+
print(isinstance(Type.parse("bf16", ctx), FloatType))
129+
# CHECK: True
130+
print(isinstance(Type.parse("f32", ctx), FloatType))
131+
# CHECK: True
132+
print(isinstance(Type.parse("tf32", ctx), FloatType))
133+
# CHECK: True
134+
print(isinstance(Type.parse("f64", ctx), FloatType))
105135

106136

107137
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@@ -218,7 +248,10 @@ def testFloatType():
218248
# CHECK: float: f32
219249
print("float:", F32Type.get())
220250
# CHECK: float: f64
221-
print("float:", F64Type.get())
251+
f64 = F64Type.get()
252+
print("float:", f64)
253+
# CHECK: f64 width: 64
254+
print("f64 width:", f64.width)
222255

223256

224257
# CHECK-LABEL: TEST: testNoneType

0 commit comments

Comments
 (0)