Skip to content

Commit bb63d24

Browse files
committed
[NFC][mlir] Add support for llvm style casting for mlir types
Note: when operating on a Type hierarchy with LeafType inheriting from MiddleType which inherits from mlir::Type. calling LeafType::classof(MiddleType) will always return false. because classof call the static getTypeID from its parent instead of the dynamic Type::getTypeID so classof in this context will check if the TypeID of LeafType is the same as the TypeID of MiddleType which is always false. It is bypassed in this commit inside CastInfo<To, From>::isPossible by calling classof with an mlir::Type. but other unsuspecting users of LeafType::classof(MiddleType) would still get an incorrect result.
1 parent b8055c5 commit bb63d24

File tree

3 files changed

+108
-18
lines changed

3 files changed

+108
-18
lines changed

mlir/include/mlir/IR/Types.h

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,9 @@ class Type {
9494

9595
bool operator!() const { return impl == nullptr; }
9696

97-
template <typename U>
98-
bool isa() const;
99-
template <typename First, typename Second, typename... Rest>
97+
template <typename... Tys>
10098
bool isa() const;
101-
template <typename First, typename... Rest>
99+
template <typename... Tys>
102100
bool isa_and_nonnull() const;
103101
template <typename U>
104102
U dyn_cast() const;
@@ -185,6 +183,9 @@ class Type {
185183
/// Return the abstract type descriptor for this type.
186184
const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
187185

186+
/// Return the Type implementation.
187+
ImplType *getImpl() const { return impl; }
188+
188189
protected:
189190
ImplType *impl{nullptr};
190191
};
@@ -250,34 +251,29 @@ inline ::llvm::hash_code hash_value(Type arg) {
250251
return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
251252
}
252253

253-
template <typename U>
254-
bool Type::isa() const {
255-
assert(impl && "isa<> used on a null type.");
256-
return U::classof(*this);
257-
}
258-
259-
template <typename First, typename Second, typename... Rest>
254+
template <typename... Tys>
260255
bool Type::isa() const {
261-
return isa<First>() || isa<Second, Rest...>();
256+
return llvm::isa<Tys...>(*this);
262257
}
263258

264-
template <typename First, typename... Rest>
259+
template <typename... Tys>
265260
bool Type::isa_and_nonnull() const {
266-
return impl && isa<First, Rest...>();
261+
return llvm::isa_and_present<Tys...>(*this);
267262
}
268263

269264
template <typename U>
270265
U Type::dyn_cast() const {
271-
return isa<U>() ? U(impl) : U(nullptr);
266+
return llvm::dyn_cast<U>(*this);
272267
}
268+
273269
template <typename U>
274270
U Type::dyn_cast_or_null() const {
275-
return (impl && isa<U>()) ? U(impl) : U(nullptr);
271+
return llvm::dyn_cast_or_null<U>(*this);
276272
}
273+
277274
template <typename U>
278275
U Type::cast() const {
279-
assert(isa<U>());
280-
return U(impl);
276+
return llvm::cast<U>(*this);
281277
}
282278

283279
} // namespace mlir
@@ -325,6 +321,32 @@ struct PointerLikeTypeTraits<mlir::Type> {
325321
static constexpr int NumLowBitsAvailable = 3;
326322
};
327323

324+
/// Add support for llvm style casts.
325+
/// We provide a cast between To and From if From is mlir::Type or derives from
326+
/// it
327+
template <typename To, typename From>
328+
struct CastInfo<To, From,
329+
typename std::enable_if<
330+
std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
331+
std::is_base_of_v<mlir::Type, From>>::type>
332+
: NullableValueCastFailed<To>,
333+
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
334+
/// Arguments are taken as mlir::Type here and not as From.
335+
/// Because when casting from an intermediate type of the hierarchy to one of
336+
/// its children, the val.getTypeID() inside T::classof will use the static
337+
/// getTypeID of the parent instead of the non-static Type::getTypeID return
338+
/// the dynamic ID. so T::classof would end up comparing the static TypeID of
339+
/// The children to the static TypeID of its parent making it impossible to
340+
/// downcast from the parent to the child
341+
static inline bool isPossible(mlir::Type ty) {
342+
/// Return a constant true instead of a dynamic true when casting to self or
343+
/// up the hierarchy
344+
return std::is_same_v<To, std::remove_const_t<From>> ||
345+
std::is_base_of_v<To, From> || To::classof(ty);
346+
}
347+
static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
348+
};
349+
328350
} // namespace llvm
329351

330352
#endif // MLIR_IR_TYPES_H

mlir/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_unittest(MLIRIRTests
77
PatternMatchTest.cpp
88
ShapedTypeTest.cpp
99
SubElementInterfaceTest.cpp
10+
TypeTest.cpp
1011

1112
DEPENDS
1213
MLIRTestInterfaceIncGen

mlir/unittests/IR/TypeTest.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//===- TypeTest.cpp - Type API unit tests ---------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/Dialect.h"
10+
#include "mlir/IR/Types.h"
11+
#include "mlir/IR/BuiltinTypes.h"
12+
#include "mlir/IR/Value.h"
13+
#include "gtest/gtest.h"
14+
15+
using namespace mlir;
16+
17+
/// Mock implementations of a Type hierarchy
18+
struct LeafType;
19+
20+
struct MiddleType : Type::TypeBase<MiddleType, Type, TypeStorage> {
21+
using Base::Base;
22+
static bool classof(Type ty) {
23+
return ty.getTypeID() == TypeID::get<LeafType>() || Base::classof(ty);
24+
}
25+
};
26+
27+
struct LeafType : Type::TypeBase<LeafType, MiddleType, TypeStorage> {
28+
using Base::Base;
29+
};
30+
31+
struct FakeDialect : Dialect {
32+
FakeDialect(MLIRContext *context)
33+
: Dialect(getDialectNamespace(), context, TypeID::get<FakeDialect>()) {
34+
addTypes<MiddleType, LeafType>();
35+
}
36+
static constexpr ::llvm::StringLiteral getDialectNamespace() {
37+
return ::llvm::StringLiteral("fake");
38+
}
39+
};
40+
41+
TEST(Type, Casting) {
42+
MLIRContext ctx;
43+
ctx.loadDialect<FakeDialect>();
44+
45+
Type intTy = IntegerType::get(&ctx, 8);
46+
Type nullTy;
47+
MiddleType middleTy = MiddleType::get(&ctx);
48+
MiddleType leafTy = LeafType::get(&ctx);
49+
Type leaf2Ty = LeafType::get(&ctx);
50+
51+
EXPECT_TRUE(isa<IntegerType>(intTy));
52+
EXPECT_FALSE(isa<FunctionType>(intTy));
53+
EXPECT_FALSE(llvm::isa_and_present<IntegerType>(nullTy));
54+
EXPECT_TRUE(isa<MiddleType>(middleTy));
55+
EXPECT_FALSE(isa<LeafType>(middleTy));
56+
EXPECT_TRUE(isa<MiddleType>(leafTy));
57+
EXPECT_TRUE(isa<LeafType>(leaf2Ty));
58+
EXPECT_TRUE(isa<LeafType>(leafTy));
59+
60+
EXPECT_TRUE(static_cast<bool>(dyn_cast<IntegerType>(intTy)));
61+
EXPECT_FALSE(static_cast<bool>(dyn_cast<FunctionType>(intTy)));
62+
EXPECT_FALSE(static_cast<bool>(llvm::cast_if_present<FunctionType>(nullTy)));
63+
EXPECT_FALSE(
64+
static_cast<bool>(llvm::dyn_cast_if_present<IntegerType>(nullTy)));
65+
66+
EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
67+
}

0 commit comments

Comments
 (0)