Skip to content

Commit 107a35d

Browse files
committed
[ADT] Make null PointerUnion with different active members compare equal
Currently, two instances of `PointerUnion` with different active members and null value compare unequal. In some cases, this results in counterintuitive behavior when using functions from `Casting.h`, e.g.: ``` PointerUnion<int *, float *> U; // U = (int *)nullptr; dyn_cast<int *>(U); // Aborts dyn_cast<float *>(U); // Aborts U = (float *)nullptr; dyn_cast<int *>(U); // OK dyn_cast<float *>(U); // OK ``` `dyn_cast` should abort in all cases because the argument is null. Currently, it aborts only if the first member is active. This happens because the partial template specialization of `ValueIsPresent` for nullable types compares the union with a union constructed from nullptr, and the two unions compare equal only if their active members are the same. This patch makes two instances of a union compare equal if they are both null regardless of their active members, and fixes two places where the old behavior was exploited.
1 parent be21bd9 commit 107a35d

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

llvm/include/llvm/ADT/PointerUnion.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,14 @@ class PointerUnion
198198
}
199199
};
200200

201-
template <typename ...PTs>
202-
bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
203-
return lhs.getOpaqueValue() == rhs.getOpaqueValue();
201+
template <typename... PTs>
202+
bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
203+
return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
204204
}
205205

206-
template <typename ...PTs>
207-
bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
208-
return lhs.getOpaqueValue() != rhs.getOpaqueValue();
206+
template <typename... PTs>
207+
bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
208+
return !operator==(LHS, RHS);
209209
}
210210

211211
template <typename ...PTs>

llvm/lib/CodeGen/RegisterBankInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
134134

135135
// If the register already has a class, fallback to MRI::constrainRegClass.
136136
auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
137-
if (isa<const TargetRegisterClass *>(RegClassOrBank))
137+
if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
138138
return MRI.constrainRegClass(Reg, &RC);
139139

140-
const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
140+
const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
141141
// Otherwise, all we can do is ensure the bank covers the class, and set it.
142142
if (RB && !RB->covers(RC))
143143
return nullptr;

llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
37083708
SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
37093709
const MachineRegisterInfo &MRI) const {
37103710
const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
3711-
if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
3711+
if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
37123712
return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
37133713

3714-
if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
3714+
if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
37153715
return getAllocatableClass(RC);
37163716

37173717
return nullptr;

llvm/unittests/ADT/PointerUnionTest.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,16 @@ TEST_F(PointerUnionTest, Comparison) {
5353
EXPECT_TRUE(i4 != l4);
5454
EXPECT_TRUE(f4 != l4);
5555
EXPECT_TRUE(l4 != d4);
56-
EXPECT_TRUE(i4null != f4null);
57-
EXPECT_TRUE(i4null != l4null);
58-
EXPECT_TRUE(i4null != d4null);
56+
EXPECT_TRUE(i4null == f4null);
57+
EXPECT_FALSE(i4null != f4null);
58+
EXPECT_TRUE(i4null == l4null);
59+
EXPECT_FALSE(i4null != l4null);
60+
EXPECT_TRUE(i4null == d4null);
61+
EXPECT_FALSE(i4null != d4null);
62+
EXPECT_FALSE(i4null == i4);
63+
EXPECT_TRUE(i4null != i4);
64+
EXPECT_FALSE(i4null == f4);
65+
EXPECT_TRUE(i4null != f4);
5966
}
6067

6168
TEST_F(PointerUnionTest, Null) {

0 commit comments

Comments
 (0)