Skip to content

Commit 512739e

Browse files
authored
[clang][Interp] Three-way comparisons (#65901)
1 parent 7cc83c5 commit 512739e

File tree

9 files changed

+160
-0
lines changed

9 files changed

+160
-0
lines changed

clang/lib/AST/Interp/Boolean.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ class Boolean final {
8484
Boolean truncate(unsigned TruncBits) const { return *this; }
8585

8686
void print(llvm::raw_ostream &OS) const { OS << (V ? "true" : "false"); }
87+
std::string toDiagnosticString(const ASTContext &Ctx) const {
88+
std::string NameStr;
89+
llvm::raw_string_ostream OS(NameStr);
90+
print(OS);
91+
return NameStr;
92+
}
8793

8894
static Boolean min(unsigned NumBits) { return Boolean(false); }
8995
static Boolean max(unsigned NumBits) { return Boolean(true); }

clang/lib/AST/Interp/ByteCodeExprGen.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,29 @@ bool ByteCodeExprGen<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
253253
return this->delegate(RHS);
254254
}
255255

256+
// Special case for C++'s three-way/spaceship operator <=>, which
257+
// returns a std::{strong,weak,partial}_ordering (which is a class, so doesn't
258+
// have a PrimType).
259+
if (!T) {
260+
if (DiscardResult)
261+
return true;
262+
const ComparisonCategoryInfo *CmpInfo =
263+
Ctx.getASTContext().CompCategories.lookupInfoForType(BO->getType());
264+
assert(CmpInfo);
265+
266+
// We need a temporary variable holding our return value.
267+
if (!Initializing) {
268+
std::optional<unsigned> ResultIndex = this->allocateLocal(BO, false);
269+
if (!this->emitGetPtrLocal(*ResultIndex, BO))
270+
return false;
271+
}
272+
273+
if (!visit(LHS) || !visit(RHS))
274+
return false;
275+
276+
return this->emitCMP3(*LT, CmpInfo, BO);
277+
}
278+
256279
if (!LT || !RT || !T)
257280
return this->bail(BO);
258281

clang/lib/AST/Interp/Floating.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ class Floating final {
7676
F.toString(Buffer);
7777
OS << Buffer;
7878
}
79+
std::string toDiagnosticString(const ASTContext &Ctx) const {
80+
std::string NameStr;
81+
llvm::raw_string_ostream OS(NameStr);
82+
print(OS);
83+
return NameStr;
84+
}
7985

8086
unsigned bitWidth() const { return F.semanticsSizeInBits(F.getSemantics()); }
8187

clang/lib/AST/Interp/Integral.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ template <unsigned Bits, bool Signed> class Integral final {
128128
return Compare(V, RHS.V);
129129
}
130130

131+
std::string toDiagnosticString(const ASTContext &Ctx) const {
132+
std::string NameStr;
133+
llvm::raw_string_ostream OS(NameStr);
134+
OS << V;
135+
return NameStr;
136+
}
137+
131138
unsigned countLeadingZeros() const {
132139
if constexpr (!Signed)
133140
return llvm::countl_zero<ReprT>(V);

clang/lib/AST/Interp/Interp.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ bool CheckCtorCall(InterpState &S, CodePtr OpPC, const Pointer &This);
112112
bool CheckPotentialReinterpretCast(InterpState &S, CodePtr OpPC,
113113
const Pointer &Ptr);
114114

115+
/// Sets the given integral value to the pointer, which is of
116+
/// a std::{weak,partial,strong}_ordering type.
117+
bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
118+
const Pointer &Ptr, const APSInt &IntValue);
119+
115120
/// Checks if the shift operation is legal.
116121
template <typename LT, typename RT>
117122
bool CheckShift(InterpState &S, CodePtr OpPC, const LT &LHS, const RT &RHS,
@@ -781,6 +786,30 @@ bool EQ(InterpState &S, CodePtr OpPC) {
781786
});
782787
}
783788

789+
template <PrimType Name, class T = typename PrimConv<Name>::T>
790+
bool CMP3(InterpState &S, CodePtr OpPC, const ComparisonCategoryInfo *CmpInfo) {
791+
const T &RHS = S.Stk.pop<T>();
792+
const T &LHS = S.Stk.pop<T>();
793+
const Pointer &P = S.Stk.peek<Pointer>();
794+
795+
ComparisonCategoryResult CmpResult = LHS.compare(RHS);
796+
if (CmpResult == ComparisonCategoryResult::Unordered) {
797+
// This should only happen with pointers.
798+
const SourceInfo &Loc = S.Current->getSource(OpPC);
799+
S.FFDiag(Loc, diag::note_constexpr_pointer_comparison_unspecified)
800+
<< LHS.toDiagnosticString(S.getCtx())
801+
<< RHS.toDiagnosticString(S.getCtx());
802+
return false;
803+
}
804+
805+
assert(CmpInfo);
806+
const auto *CmpValueInfo = CmpInfo->getValueInfo(CmpResult);
807+
assert(CmpValueInfo);
808+
assert(CmpValueInfo->hasValidIntValue());
809+
APSInt IntValue = CmpValueInfo->getIntValue();
810+
return SetThreeWayComparisonField(S, OpPC, P, IntValue);
811+
}
812+
784813
template <PrimType Name, class T = typename PrimConv<Name>::T>
785814
bool NE(InterpState &S, CodePtr OpPC) {
786815
return CmpHelperEQ<T>(S, OpPC, [](ComparisonCategoryResult R) {

clang/lib/AST/Interp/InterpBuiltin.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,5 +594,22 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E,
594594
return true;
595595
}
596596

597+
bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
598+
const Pointer &Ptr, const APSInt &IntValue) {
599+
600+
const Record *R = Ptr.getRecord();
601+
assert(R);
602+
assert(R->getNumFields() == 1);
603+
604+
unsigned FieldOffset = R->getField(0u)->Offset;
605+
const Pointer &FieldPtr = Ptr.atField(FieldOffset);
606+
PrimType FieldT = *S.getContext().classify(FieldPtr.getType());
607+
608+
INT_TYPE_SWITCH(FieldT,
609+
FieldPtr.deref<T>() = T::from(IntValue.getSExtValue()));
610+
FieldPtr.initialize();
611+
return true;
612+
}
613+
597614
} // namespace interp
598615
} // namespace clang

clang/lib/AST/Interp/Opcodes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def ArgCastKind : ArgType { let Name = "CastKind"; }
5555
def ArgCallExpr : ArgType { let Name = "const CallExpr *"; }
5656
def ArgOffsetOfExpr : ArgType { let Name = "const OffsetOfExpr *"; }
5757
def ArgDeclRef : ArgType { let Name = "const DeclRefExpr *"; }
58+
def ArgCCI : ArgType { let Name = "const ComparisonCategoryInfo *"; }
5859

5960
//===----------------------------------------------------------------------===//
6061
// Classes of types instructions operate on.
@@ -607,6 +608,10 @@ class ComparisonOpcode : Opcode {
607608
let HasGroup = 1;
608609
}
609610

611+
def CMP3 : ComparisonOpcode {
612+
let Args = [ArgCCI];
613+
}
614+
610615
def LT : ComparisonOpcode;
611616
def LE : ComparisonOpcode;
612617
def GT : ComparisonOpcode;

clang/lib/AST/Interp/Pointer.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,19 @@ class Pointer {
362362
/// Deactivates an entire strurcutre.
363363
void deactivate() const;
364364

365+
/// Compare two pointers.
366+
ComparisonCategoryResult compare(const Pointer &Other) const {
367+
if (!hasSameBase(*this, Other))
368+
return ComparisonCategoryResult::Unordered;
369+
370+
if (Offset < Other.Offset)
371+
return ComparisonCategoryResult::Less;
372+
else if (Offset > Other.Offset)
373+
return ComparisonCategoryResult::Greater;
374+
375+
return ComparisonCategoryResult::Equal;
376+
}
377+
365378
/// Checks if two pointers are comparable.
366379
static bool hasSameBase(const Pointer &A, const Pointer &B);
367380
/// Checks if two pointers can be subtracted.

clang/test/AST/Interp/cxx20.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,3 +646,57 @@ namespace ImplicitFunction {
646646
// expected-error {{not an integral constant expression}} \
647647
// expected-note {{in call to 'callMe()'}}
648648
}
649+
650+
/// FIXME: Unfortunately, the similar tests in test/SemaCXX/{compare-cxx2a.cpp use member pointers,
651+
/// which we don't support yet.
652+
namespace std {
653+
class strong_ordering {
654+
public:
655+
int n;
656+
static const strong_ordering less, equal, greater;
657+
constexpr bool operator==(int n) const noexcept { return this->n == n;}
658+
constexpr bool operator!=(int n) const noexcept { return this->n != n;}
659+
};
660+
constexpr strong_ordering strong_ordering::less = {-1};
661+
constexpr strong_ordering strong_ordering::equal = {0};
662+
constexpr strong_ordering strong_ordering::greater = {1};
663+
664+
class partial_ordering {
665+
public:
666+
long n;
667+
static const partial_ordering less, equal, greater, equivalent, unordered;
668+
constexpr bool operator==(long n) const noexcept { return this->n == n;}
669+
constexpr bool operator!=(long n) const noexcept { return this->n != n;}
670+
};
671+
constexpr partial_ordering partial_ordering::less = {-1};
672+
constexpr partial_ordering partial_ordering::equal = {0};
673+
constexpr partial_ordering partial_ordering::greater = {1};
674+
constexpr partial_ordering partial_ordering::equivalent = {0};
675+
constexpr partial_ordering partial_ordering::unordered = {-127};
676+
} // namespace std
677+
678+
namespace ThreeWayCmp {
679+
static_assert(1 <=> 2 == -1, "");
680+
static_assert(1 <=> 1 == 0, "");
681+
static_assert(2 <=> 1 == 1, "");
682+
static_assert(1.0 <=> 2.f == -1, "");
683+
static_assert(1.0 <=> 1.0 == 0, "");
684+
static_assert(2.0 <=> 1.0 == 1, "");
685+
constexpr int k = (1 <=> 1, 0); // expected-warning {{comparison result unused}} \
686+
// ref-warning {{comparison result unused}}
687+
static_assert(k== 0, "");
688+
689+
/// Pointers.
690+
constexpr int a[] = {1,2,3};
691+
constexpr int b[] = {1,2,3};
692+
constexpr const int *pa1 = &a[1];
693+
constexpr const int *pa2 = &a[2];
694+
constexpr const int *pb1 = &b[1];
695+
static_assert(pa1 <=> pb1 != 0, ""); // expected-error {{not an integral constant expression}} \
696+
// expected-note {{has unspecified value}} \
697+
// ref-error {{not an integral constant expression}} \
698+
// ref-note {{has unspecified value}}
699+
static_assert(pa1 <=> pa1 == 0, "");
700+
static_assert(pa1 <=> pa2 == -1, "");
701+
static_assert(pa2 <=> pa1 == 1, "");
702+
}

0 commit comments

Comments
 (0)