Skip to content

[MLIR][Python] Added a base class to all builtin floating point types #81720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
// Floating-point types.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);

/// Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);

/// Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);

Expand Down
38 changes: 28 additions & 10 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
}
};

class PyFloatType : public PyConcreteType<PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
static constexpr const char *pyClassName = "FloatType";
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_property_readonly(
"width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
"Returns the width of the floating-point type");
}
};

/// Floating Point Type subclass - Float8E4M3FNType.
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
class PyFloat8E4M3FNType
: public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
};

/// Floating Point Type subclass - Float8M5E2Type.
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
};

/// Floating Point Type subclass - Float8E4M3FNUZ.
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
class PyFloat8E4M3FNUZType
: public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
};

/// Floating Point Type subclass - Float8E4M3B11FNUZ.
class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
class PyFloat8E4M3B11FNUZType
: public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
};

/// Floating Point Type subclass - Float8E5M2FNUZ.
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
class PyFloat8E5M2FNUZType
: public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
};

/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
};

/// Floating Point Type subclass - F16Type.
class PyF16Type : public PyConcreteType<PyF16Type> {
class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
};

/// Floating Point Type subclass - TF32Type.
class PyTF32Type : public PyConcreteType<PyTF32Type> {
class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
};

/// Floating Point Type subclass - F32Type.
class PyF32Type : public PyConcreteType<PyF32Type> {
class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand All @@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
};

/// Floating Point Type subclass - F64Type.
class PyF64Type : public PyConcreteType<PyF64Type> {
class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
Expand Down Expand Up @@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {

void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
// Floating-point types.
//===----------------------------------------------------------------------===//

bool mlirTypeIsAFloat(MlirType type) {
return llvm::isa<FloatType>(unwrap(type));
}

unsigned mlirFloatTypeGetWidth(MlirType type) {
return llvm::cast<FloatType>(unwrap(type)).getWidth();
}

MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
return wrap(Float8E5M2Type::getTypeID());
}
Expand Down
28 changes: 19 additions & 9 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
@property
def typeid(self) -> TypeID: ...

class F16Type(Type):
class FloatType(Type):
@staticmethod
def isinstance(other: Type) -> bool: ...
def __init__(self, cast_from_type: Type) -> None: ...
@property
def width(self) -> int:
"""
Returns the width of the floating-point type.
"""

class F16Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F16Type:
Expand All @@ -1455,7 +1465,7 @@ class F16Type(Type):
@property
def typeid(self) -> TypeID: ...

class F32Type(Type):
class F32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F32Type:
Expand All @@ -1468,7 +1478,7 @@ class F32Type(Type):
@property
def typeid(self) -> TypeID: ...

class F64Type(Type):
class F64Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F64Type:
Expand Down Expand Up @@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""

class Float8E4M3B11FNUZType(Type):
class Float8E4M3B11FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
Expand All @@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
@property
def typeid(self) -> TypeID: ...

class Float8E4M3FNType(Type):
class Float8E4M3FNType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNType:
Expand All @@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
@property
def typeid(self) -> TypeID: ...

class Float8E4M3FNUZType(Type):
class Float8E4M3FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
Expand All @@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
@property
def typeid(self) -> TypeID: ...

class Float8E5M2FNUZType(Type):
class Float8E5M2FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
Expand All @@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
@property
def typeid(self) -> TypeID: ...

class Float8E5M2Type(Type):
class Float8E5M2Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2Type:
Expand Down Expand Up @@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
Returns the value of the float attribute
"""

class FloatTF32Type(Type):
class FloatTF32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> FloatTF32Type:
Expand Down
35 changes: 34 additions & 1 deletion mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,38 @@ def testTypeIsInstance():
print(IntegerType.isinstance(t1))
# CHECK: False
print(F32Type.isinstance(t1))
# CHECK: False
print(FloatType.isinstance(t1))
# CHECK: True
print(F32Type.isinstance(t2))
# CHECK: True
print(FloatType.isinstance(t2))


# CHECK-LABEL: TEST: testFloatTypeSubclasses
@run
def testFloatTypeSubclasses():
ctx = Context()
# CHECK: True
print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f16", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("bf16", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f32", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("tf32", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f64", ctx), FloatType))


# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
Expand Down Expand Up @@ -218,7 +248,10 @@ def testFloatType():
# CHECK: float: f32
print("float:", F32Type.get())
# CHECK: float: f64
print("float:", F64Type.get())
f64 = F64Type.get()
print("float:", f64)
# CHECK: f64 width: 64
print("f64 width:", f64.width)


# CHECK-LABEL: TEST: testNoneType
Expand Down