Skip to content

Commit 7138397

Browse files
authored
[Clang] Add __builtin_invoke and use it in libc++ (#116709)
`std::invoke` is currently quite heavy compared to a plain function call, since its implementation relies on expensive SFINAE. The compiler can do this significantly more efficiently, since most calls to `std::invoke` are simple function calls and six of the seven overloads for `std::invoke` exist only to support member pointers. Even these boil down to a few relatively simple checks. Real-world testing with this patch showed significant improvements. For example, the time to instantiate `std::format("Banane")` (and its callees) went down from ~125ms to ~104ms on my system.
1 parent 43ab5bb commit 7138397

File tree

10 files changed

+624
-82
lines changed

10 files changed

+624
-82
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3798,6 +3798,17 @@ Trivially relocates ``count`` objects of relocatable, complete type ``T``
37983798
from ``src`` to ``dest`` and returns ``dest``.
37993799
This builtin is used to implement ``std::trivially_relocate``.
38003800
3801+
``__builtin_invoke``
3802+
--------------------
3803+
3804+
**Syntax**:
3805+
3806+
.. code-block:: c++
3807+
3808+
template <class Callee, class... Args>
3809+
decltype(auto) __builtin_invoke(Callee&& callee, Args&&... args);
3810+
3811+
``__builtin_invoke`` is equivalent to ``std::invoke``.
38013812
38023813
``__builtin_preserve_access_index``
38033814
-----------------------------------

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ Non-comprehensive list of changes in this release
327327
different than before.
328328
- Fixed a crash when a VLA with an invalid size expression was used within a
329329
``sizeof`` or ``typeof`` expression. (#GH138444)
330+
- ``__builtin_invoke`` has been added to improve the compile time of ``std::invoke``.
330331
- Deprecation warning is emitted for the deprecated ``__reference_binds_to_temporary`` intrinsic.
331332
``__reference_constructs_from_temporary`` should be used instead. (#GH44056)
332333
- Added `__builtin_get_vtable_pointer` to directly load the primary vtable pointer from a
@@ -656,7 +657,7 @@ Improvements to Clang's diagnostics
656657
false positives in exception-heavy code, though only simple patterns
657658
are currently recognized.
658659

659-
660+
660661
Improvements to Clang's time-trace
661662
----------------------------------
662663

@@ -734,7 +735,7 @@ Bug Fixes in This Version
734735
- Fixed incorrect token location when emitting diagnostics for tokens expanded from macros. (#GH143216)
735736
- Fixed an infinite recursion when checking constexpr destructors. (#GH141789)
736737
- Fixed a crash when a malformed using declaration appears in a ``constexpr`` function. (#GH144264)
737-
- Fixed a bug when using a Unicode character name in macro concatenation. (#GH145240)
738+
- Fixed a bug when using a Unicode character name in macro concatenation. (#GH145240)
738739

739740
Bug Fixes to Compiler Builtins
740741
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4314,6 +4314,12 @@ def MoveIfNoexcept : CxxLibBuiltin<"utility"> {
43144314
let Namespace = "std";
43154315
}
43164316

4317+
// Declares the ``__builtin_invoke`` builtin. It is marked CustomTypeChecking
// and Constexpr; the variadic ``void(...)`` prototype is presumably only a
// placeholder, since argument and result types are determined by the custom
// checking in Sema rather than by this signature.
def Invoke : Builtin {
4318+
let Spellings = ["__builtin_invoke"];
4319+
let Attributes = [CustomTypeChecking, Constexpr];
4320+
let Prototype = "void(...)";
4321+
}
4322+
43174323
def Annotation : Builtin {
43184324
let Spellings = ["__builtin_annotation"];
43194325
let Attributes = [NoThrow, CustomTypeChecking];

clang/include/clang/Sema/Sema.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15192,11 +15192,18 @@ class Sema final : public SemaBase {
1519215192
SourceLocation Loc);
1519315193
QualType BuiltinRemoveReference(QualType BaseType, UTTKind UKind,
1519415194
SourceLocation Loc);
15195+
15196+
/// Convenience wrapper that forwards to BuiltinRemoveReference with
/// UTTKind::RemoveCVRef, so callers can obtain the remove_cvref form of
/// \p BaseType without spelling out the trait kind. \p Loc is passed through
/// unchanged.
QualType BuiltinRemoveCVRef(QualType BaseType, SourceLocation Loc) {
15197+
return BuiltinRemoveReference(BaseType, UTTKind::RemoveCVRef, Loc);
15198+
}
15199+
1519515200
QualType BuiltinChangeCVRQualifiers(QualType BaseType, UTTKind UKind,
1519615201
SourceLocation Loc);
1519715202
QualType BuiltinChangeSignedness(QualType BaseType, UTTKind UKind,
1519815203
SourceLocation Loc);
1519915204

15205+
bool BuiltinIsBaseOf(SourceLocation RhsTLoc, QualType LhsT, QualType RhsT);
15206+
1520015207
/// Ensure that the type T is a literal type.
1520115208
///
1520215209
/// This routine checks whether the type @p T is a literal type. If @p T is an

clang/lib/Sema/SemaChecking.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,99 @@ static bool BuiltinCountZeroBitsGeneric(Sema &S, CallExpr *TheCall) {
22612261
return false;
22622262
}
22632263

2264+
/// Semantic checking for a call to __builtin_invoke.
///
/// The builtin call is rewritten into the expression that std::invoke would
/// evaluate:
///  - If the first argument is a pointer to member, the second argument is
///    turned into the object expression (used directly, via `.get()` for
///    std::reference_wrapper, or dereferenced with unary `*`), combined with
///    the member pointer through `.*`, and, for member function pointers,
///    called with the remaining arguments.
///  - Otherwise the first argument is called with the remaining arguments.
///
/// Returns ExprError() if there are too few arguments or if building any of
/// the intermediate expressions fails. None of the arguments may be
/// type-dependent (asserted below).
static ExprResult BuiltinInvoke(Sema &S, CallExpr *TheCall) {
2265+
SourceLocation Loc = TheCall->getBeginLoc();
2266+
MutableArrayRef Args(TheCall->getArgs(), TheCall->getNumArgs());
2267+
assert(llvm::none_of(Args, [](Expr *Arg) { return Arg->isTypeDependent(); }));
2268+
2269+
// A callee is always required.
if (Args.size() == 0) {
2270+
S.Diag(TheCall->getBeginLoc(),
2271+
diag::err_typecheck_call_too_few_args_at_least)
2272+
<< /*callee_type=*/0 << /*min_arg_count=*/1 << /*actual_arg_count=*/0
2273+
<< /*is_non_object=*/0 << TheCall->getSourceRange();
2274+
return ExprError();
2275+
}
2276+
2277+
QualType FuncT = Args[0]->getType();
2278+
2279+
if (const auto *MPT = FuncT->getAs<MemberPointerType>()) {
2280+
// A member pointer additionally needs an object argument to be applied to.
if (Args.size() < 2) {
2281+
S.Diag(TheCall->getBeginLoc(),
2282+
diag::err_typecheck_call_too_few_args_at_least)
2283+
<< /*callee_type=*/0 << /*min_arg_count=*/2 << /*actual_arg_count=*/1
2284+
<< /*is_non_object=*/0 << TheCall->getSourceRange();
2285+
return ExprError();
2286+
}
2287+
2288+
const Type *MemPtrClass = MPT->getQualifier()->getAsType();
2289+
QualType ObjectT = Args[1]->getType();
2290+
2291+
// A pointer to data member is invoked with the object argument only, so
// the call must consist of exactly the member pointer and the object.
if (MPT->isMemberDataPointer() && S.checkArgCount(TheCall, 2))
2292+
return ExprError();
2293+
2294+
// Build the expression that goes on the left-hand side of `.*`.
ExprResult ObjectArg = [&]() -> ExprResult {
2295+
// (1.1): (t1.*f)(t2, ..., tN) when f is a pointer to a member function of
2296+
// a class T and is_same_v<T, remove_cvref_t<decltype(t1)>> ||
2297+
// is_base_of_v<T, remove_cvref_t<decltype(t1)>> is true;
2298+
// (1.4): t1.*f when N=1 and f is a pointer to data member of a class T
2299+
// and is_same_v<T, remove_cvref_t<decltype(t1)>> ||
2300+
// is_base_of_v<T, remove_cvref_t<decltype(t1)>> is true;
2301+
if (S.Context.hasSameType(QualType(MemPtrClass, 0),
2302+
S.BuiltinRemoveCVRef(ObjectT, Loc)) ||
2303+
S.BuiltinIsBaseOf(Args[1]->getBeginLoc(), QualType(MemPtrClass, 0),
2304+
S.BuiltinRemoveCVRef(ObjectT, Loc))) {
2305+
return Args[1];
2306+
}
2307+
2308+
// (t1.get().*f)(t2, ..., tN) when f is a pointer to a member function of
2309+
// a class T and remove_cvref_t<decltype(t1)> is a specialization of
2310+
// reference_wrapper;
2311+
if (const auto *RD = ObjectT->getAsCXXRecordDecl()) {
2312+
if (RD->isInStdNamespace() &&
2313+
RD->getDeclName().getAsString() == "reference_wrapper") {
2314+
// Synthesize the member access `t1.get` followed by a call with no
// arguments, i.e. `t1.get()`.
CXXScopeSpec SS;
2315+
IdentifierInfo *GetName = &S.Context.Idents.get("get");
2316+
UnqualifiedId GetID;
2317+
GetID.setIdentifier(GetName, Loc);
2318+
2319+
ExprResult MemExpr = S.ActOnMemberAccessExpr(
2320+
S.getCurScope(), Args[1], Loc, tok::period, SS,
2321+
/*TemplateKWLoc=*/SourceLocation(), GetID, nullptr);
2322+
2323+
if (MemExpr.isInvalid())
2324+
return ExprError();
2325+
2326+
return S.ActOnCallExpr(S.getCurScope(), MemExpr.get(), Loc, {}, Loc);
2327+
}
2328+
}
2329+
2330+
// ((*t1).*f)(t2, ..., tN) when f is a pointer to a member function of a
2331+
// class T and t1 does not satisfy the previous two items;
2332+
2333+
return S.ActOnUnaryOp(S.getCurScope(), Loc, tok::star, Args[1]);
2334+
}();
2335+
2336+
if (ObjectArg.isInvalid())
2337+
return ExprError();
2338+
2339+
// Form `object .* member-pointer`.
ExprResult BinOp = S.ActOnBinOp(S.getCurScope(), TheCall->getBeginLoc(),
2340+
tok::periodstar, ObjectArg.get(), Args[0]);
2341+
if (BinOp.isInvalid())
2342+
return ExprError();
2343+
2344+
// For a data member the `.*` expression is already the result.
if (MPT->isMemberDataPointer())
2345+
return BinOp;
2346+
2347+
// For a member function, parenthesize the `.*` expression and call it with
// the arguments that follow the object, i.e. `(object.*f)(t2, ..., tN)`.
auto *MemCall = new (S.Context)
2348+
ParenExpr(SourceLocation(), SourceLocation(), BinOp.get());
2349+
2350+
return S.ActOnCallExpr(S.getCurScope(), MemCall, TheCall->getBeginLoc(),
2351+
Args.drop_front(2), TheCall->getRParenLoc());
2352+
}
2353+
// Not a member pointer: call the first argument with the remaining ones.
return S.ActOnCallExpr(S.getCurScope(), Args.front(), TheCall->getBeginLoc(),
2354+
Args.drop_front(), TheCall->getRParenLoc());
2355+
}
2356+
22642357
ExprResult
22652358
Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
22662359
CallExpr *TheCall) {
@@ -2420,6 +2513,8 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
24202513
return BuiltinShuffleVector(TheCall);
24212514
// TheCall will be freed by the smart pointer here, but that's fine, since
24222515
// BuiltinShuffleVector guts it, but then doesn't release it.
2516+
case Builtin::BI__builtin_invoke:
2517+
return BuiltinInvoke(*this, TheCall);
24232518
case Builtin::BI__builtin_prefetch:
24242519
if (BuiltinPrefetch(TheCall))
24252520
return ExprError();

clang/lib/Sema/SemaTypeTraits.cpp

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,58 @@ ExprResult Sema::ActOnTypeTrait(TypeTrait Kind, SourceLocation KWLoc,
15791579
return BuildTypeTrait(Kind, KWLoc, ConvertedArgs, RParenLoc);
15801580
}
15811581

1582+
/// Evaluates the is_base_of relation: returns true if \p LhsT (the base) is
/// the same class as, or a base class of, \p RhsT (the derived type).
///
/// \param RhsTLoc location used when diagnosing an incomplete \p RhsT.
/// \param LhsT    the candidate base type.
/// \param RhsT    the candidate derived type.
///
/// Returns false if either type is a union, if the types are neither both
/// record types nor both Objective-C object types with interfaces, or if
/// \p RhsT must be complete for the check but RequireCompleteType reports a
/// failure.
bool Sema::BuiltinIsBaseOf(SourceLocation RhsTLoc, QualType LhsT,
1583+
QualType RhsT) {
1584+
// C++0x [meta.rel]p2
1585+
// Base is a base class of Derived without regard to cv-qualifiers or
1586+
// Base and Derived are not unions and name the same class type without
1587+
// regard to cv-qualifiers.
1588+
1589+
const RecordType *lhsRecord = LhsT->getAs<RecordType>();
1590+
const RecordType *rhsRecord = RhsT->getAs<RecordType>();
1591+
if (!rhsRecord || !lhsRecord) {
1592+
// At least one side is not a record type. The only remaining case that can
// yield true is a pair of Objective-C object types that both have an
// interface, which is answered by the superclass relation.
const ObjCObjectType *LHSObjTy = LhsT->getAs<ObjCObjectType>();
1593+
const ObjCObjectType *RHSObjTy = RhsT->getAs<ObjCObjectType>();
1594+
if (!LHSObjTy || !RHSObjTy)
1595+
return false;
1596+
1597+
ObjCInterfaceDecl *BaseInterface = LHSObjTy->getInterface();
1598+
ObjCInterfaceDecl *DerivedInterface = RHSObjTy->getInterface();
1599+
if (!BaseInterface || !DerivedInterface)
1600+
return false;
1601+
1602+
if (RequireCompleteType(RhsTLoc, RhsT,
1603+
diag::err_incomplete_type_used_in_type_trait_expr))
1604+
return false;
1605+
1606+
return BaseInterface->isSuperClassOf(DerivedInterface);
1607+
}
1608+
1609+
assert(Context.hasSameUnqualifiedType(LhsT, RhsT) ==
1610+
(lhsRecord == rhsRecord));
1611+
1612+
// Unions are never base classes, and never have base classes.
1613+
// It doesn't matter if they are complete or not. See PR#41843
1614+
if (lhsRecord && lhsRecord->getDecl()->isUnion())
1615+
return false;
1616+
if (rhsRecord && rhsRecord->getDecl()->isUnion())
1617+
return false;
1618+
1619+
// The same (non-union) class counts as a base of itself; no completeness
// check is performed in this case.
if (lhsRecord == rhsRecord)
1620+
return true;
1621+
1622+
// C++0x [meta.rel]p2:
1623+
// If Base and Derived are class types and are different types
1624+
// (ignoring possible cv-qualifiers) then Derived shall be a
1625+
// complete type.
1626+
if (RequireCompleteType(RhsTLoc, RhsT,
1627+
diag::err_incomplete_type_used_in_type_trait_expr))
1628+
return false;
1629+
1630+
return cast<CXXRecordDecl>(rhsRecord->getDecl())
1631+
->isDerivedFrom(cast<CXXRecordDecl>(lhsRecord->getDecl()));
1632+
}
1633+
15821634
static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT,
15831635
const TypeSourceInfo *Lhs,
15841636
const TypeSourceInfo *Rhs,
@@ -1590,58 +1642,9 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT,
15901642
"Cannot evaluate traits of dependent types");
15911643

15921644
switch (BTT) {
1593-
case BTT_IsBaseOf: {
1594-
// C++0x [meta.rel]p2
1595-
// Base is a base class of Derived without regard to cv-qualifiers or
1596-
// Base and Derived are not unions and name the same class type without
1597-
// regard to cv-qualifiers.
1598-
1599-
const RecordType *lhsRecord = LhsT->getAs<RecordType>();
1600-
const RecordType *rhsRecord = RhsT->getAs<RecordType>();
1601-
if (!rhsRecord || !lhsRecord) {
1602-
const ObjCObjectType *LHSObjTy = LhsT->getAs<ObjCObjectType>();
1603-
const ObjCObjectType *RHSObjTy = RhsT->getAs<ObjCObjectType>();
1604-
if (!LHSObjTy || !RHSObjTy)
1605-
return false;
1606-
1607-
ObjCInterfaceDecl *BaseInterface = LHSObjTy->getInterface();
1608-
ObjCInterfaceDecl *DerivedInterface = RHSObjTy->getInterface();
1609-
if (!BaseInterface || !DerivedInterface)
1610-
return false;
1611-
1612-
if (Self.RequireCompleteType(
1613-
Rhs->getTypeLoc().getBeginLoc(), RhsT,
1614-
diag::err_incomplete_type_used_in_type_trait_expr))
1615-
return false;
1616-
1617-
return BaseInterface->isSuperClassOf(DerivedInterface);
1618-
}
1619-
1620-
assert(Self.Context.hasSameUnqualifiedType(LhsT, RhsT) ==
1621-
(lhsRecord == rhsRecord));
1622-
1623-
// Unions are never base classes, and never have base classes.
1624-
// It doesn't matter if they are complete or not. See PR#41843
1625-
if (lhsRecord && lhsRecord->getDecl()->isUnion())
1626-
return false;
1627-
if (rhsRecord && rhsRecord->getDecl()->isUnion())
1628-
return false;
1645+
case BTT_IsBaseOf:
1646+
return Self.BuiltinIsBaseOf(Rhs->getTypeLoc().getBeginLoc(), LhsT, RhsT);
16291647

1630-
if (lhsRecord == rhsRecord)
1631-
return true;
1632-
1633-
// C++0x [meta.rel]p2:
1634-
// If Base and Derived are class types and are different types
1635-
// (ignoring possible cv-qualifiers) then Derived shall be a
1636-
// complete type.
1637-
if (Self.RequireCompleteType(
1638-
Rhs->getTypeLoc().getBeginLoc(), RhsT,
1639-
diag::err_incomplete_type_used_in_type_trait_expr))
1640-
return false;
1641-
1642-
return cast<CXXRecordDecl>(rhsRecord->getDecl())
1643-
->isDerivedFrom(cast<CXXRecordDecl>(lhsRecord->getDecl()));
1644-
}
16451648
case BTT_IsVirtualBaseOf: {
16461649
const RecordType *BaseRecord = LhsT->getAs<RecordType>();
16471650
const RecordType *DerivedRecord = RhsT->getAs<RecordType>();
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %clang_cc1 -triple=x86_64-linux-gnu -emit-llvm -o - %s | FileCheck %s
2+
3+
extern "C" void* memcpy(void*, const void*, decltype(sizeof(int)));
4+
void func();
5+
6+
// Minimal stand-in for std::reference_wrapper declared directly in this test
// (no standard library headers are included). Only get() is provided, which
// returns a reference to the pointed-to object.
namespace std {
7+
template <class T>
8+
class reference_wrapper {
9+
T* ptr;
10+
11+
public:
12+
T& get() { return *ptr; }
13+
};
14+
} // namespace std
15+
16+
// Test callee type: usable as a function object through operator(), and as
// the class of the member function pointer &Callable::func (func is only
// declared here).
struct Callable {
17+
void operator()() {}
18+
19+
void func();
20+
};
21+
22+
// Invokes a free function, a temporary function object and memcpy (with
// arguments) through __builtin_invoke. The expected IR below consists of
// plain calls in that order, with memcpy emitted as the llvm.memcpy
// intrinsic.
extern "C" void call1() {
23+
__builtin_invoke(func);
24+
__builtin_invoke(Callable{});
25+
__builtin_invoke(memcpy, nullptr, nullptr, 0);
26+
27+
// CHECK: define dso_local void @call1
28+
// CHECK-NEXT: entry:
29+
// CHECK-NEXT: %ref.tmp = alloca %struct.Callable, align 1
30+
// CHECK-NEXT: call void @_Z4funcv()
31+
// CHECK-NEXT: call void @_ZN8CallableclEv(ptr noundef nonnull align 1 dereferenceable(1) %ref.tmp)
32+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 1 null, ptr align 1 null, i64 0, i1 false)
33+
// CHECK-NEXT: ret void
34+
}
35+
36+
// Invokes a pointer to member function with a std::reference_wrapper as the
// object argument. The expected IR below first calls reference_wrapper::get()
// on the wrapper and then dispatches the member function pointer on the
// returned object.
extern "C" void call_memptr(std::reference_wrapper<Callable> wrapper) {
37+
__builtin_invoke(&Callable::func, wrapper);
38+
39+
// CHECK: define dso_local void @call_memptr
40+
// CHECK-NEXT: entry:
41+
// CHECK-NEXT: %wrapper = alloca %"class.std::reference_wrapper", align 8
42+
// CHECK-NEXT: %coerce.dive = getelementptr inbounds nuw %"class.std::reference_wrapper", ptr %wrapper, i32 0, i32 0
43+
// CHECK-NEXT: store ptr %wrapper.coerce, ptr %coerce.dive, align 8
44+
// CHECK-NEXT: %call = call noundef nonnull align 1 dereferenceable(1) ptr @_ZNSt17reference_wrapperI8CallableE3getEv(ptr noundef nonnull align 8 dereferenceable(8) %wrapper)
45+
// CHECK-NEXT: %0 = getelementptr inbounds i8, ptr %call, i64 0
46+
// CHECK-NEXT: br i1 false, label %memptr.virtual, label %memptr.nonvirtual
47+
// CHECK-EMPTY:
48+
// CHECK-NEXT: memptr.virtual:
49+
// CHECK-NEXT: %vtable = load ptr, ptr %0, align 8
50+
// CHECK-NEXT: %1 = getelementptr i8, ptr %vtable, i64 sub (i64 ptrtoint (ptr @_ZN8Callable4funcEv to i64), i64 1), !nosanitize !2
51+
// CHECK-NEXT: %memptr.virtualfn = load ptr, ptr %1, align 8, !nosanitize !2
52+
// CHECK-NEXT: br label %memptr.end
53+
// CHECK-EMPTY:
54+
// CHECK-NEXT: memptr.nonvirtual:
55+
// CHECK-NEXT: br label %memptr.end
56+
// CHECK-EMPTY:
57+
// CHECK-NEXT: memptr.end:
58+
// CHECK-NEXT: %2 = phi ptr [ %memptr.virtualfn, %memptr.virtual ], [ @_ZN8Callable4funcEv, %memptr.nonvirtual ]
59+
// CHECK-NEXT: call void %2(ptr noundef nonnull align 1 dereferenceable(1) %0)
60+
// CHECK-NEXT: ret void
61+
}

0 commit comments

Comments
 (0)