Skip to content

Commit de0df63

Browse files
committed
[CUDA][HIP] Fix overloading resolution in global variable initializer
Currently, clang does not resolve certain overloaded functions correctly in the initializer of global variables, e.g.

    template<typename T1, typename U> T1 mypow(T1, U);
    __attribute__((device)) double mypow(double, int);
    double t_extent = mypow(1.0, 2);

In the above example, mypow is supposed to resolve to the host version, but clang resolves it to the device version instead and emits an error (https://godbolt.org/z/17xxzaa67). However, if the variable is assigned in a host function, there is no error.

The discrepancy in overload resolution inside and outside of a function is due to clang not accounting for the host/device target when resolving functions called in the initializer of a global variable.

This patch introduces a global host/device target context for CUDA/HIP for functions called outside of functions. For global variable initialization, it is determined by the host/device attribute of the variable. For other situations, a default value of host_device is sufficient.

Reviewed by: Artem Belevich
Differential Revision: https://reviews.llvm.org/D158247
Fixes: SWDEV-416731
1 parent 1c5fd15 commit de0df63

File tree

11 files changed

+219
-68
lines changed

11 files changed

+219
-68
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,14 @@ class Sema final {
10121012
}
10131013
} DelayedDiagnostics;
10141014

1015+
enum CUDAFunctionTarget {
1016+
CFT_Device,
1017+
CFT_Global,
1018+
CFT_Host,
1019+
CFT_HostDevice,
1020+
CFT_InvalidTarget
1021+
};
1022+
10151023
/// A RAII object to temporarily push a declaration context.
10161024
class ContextRAII {
10171025
private:
@@ -4751,8 +4759,13 @@ class Sema final {
47514759
bool isValidPointerAttrType(QualType T, bool RefOkay = false);
47524760

47534761
bool CheckRegparmAttr(const ParsedAttr &attr, unsigned &value);
4762+
4763+
/// Check validity of calling convention attribute \p attr. If \p FD
4764+
/// is not a null pointer, use \p FD to determine the CUDA/HIP host/device
4765+
/// target. Otherwise, it is specified by \p CFT.
47544766
bool CheckCallingConvAttr(const ParsedAttr &attr, CallingConv &CC,
4755-
const FunctionDecl *FD = nullptr);
4767+
const FunctionDecl *FD = nullptr,
4768+
CUDAFunctionTarget CFT = CFT_InvalidTarget);
47564769
bool CheckAttrTarget(const ParsedAttr &CurrAttr);
47574770
bool CheckAttrNoArgs(const ParsedAttr &CurrAttr);
47584771
bool checkStringLiteralArgumentAttr(const AttributeCommonInfo &CI,
@@ -13259,14 +13272,6 @@ class Sema final {
1325913272
void checkTypeSupport(QualType Ty, SourceLocation Loc,
1326013273
ValueDecl *D = nullptr);
1326113274

13262-
enum CUDAFunctionTarget {
13263-
CFT_Device,
13264-
CFT_Global,
13265-
CFT_Host,
13266-
CFT_HostDevice,
13267-
CFT_InvalidTarget
13268-
};
13269-
1327013275
/// Determines whether the given function is a CUDA device/host/kernel/etc.
1327113276
/// function.
1327213277
///
@@ -13285,6 +13290,29 @@ class Sema final {
1328513290
/// Determines whether the given variable is emitted on host or device side.
1328613291
CUDAVariableTarget IdentifyCUDATarget(const VarDecl *D);
1328713292

13293+
/// Defines kinds of CUDA global host/device context where a function may be
13294+
/// called.
13295+
enum CUDATargetContextKind {
13296+
CTCK_Unknown, /// Unknown context
13297+
CTCK_InitGlobalVar, /// Function called during global variable
13298+
/// initialization
13299+
};
13300+
13301+
/// Define the current global CUDA host/device context where a function may be
13302+
/// called. Only used when a function is called outside of any functions.
13303+
struct CUDATargetContext {
13304+
CUDAFunctionTarget Target = CFT_HostDevice;
13305+
CUDATargetContextKind Kind = CTCK_Unknown;
13306+
Decl *D = nullptr;
13307+
} CurCUDATargetCtx;
13308+
13309+
struct CUDATargetContextRAII {
13310+
Sema &S;
13311+
CUDATargetContext SavedCtx;
13312+
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
13313+
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
13314+
};
13315+
1328813316
/// Gets the CUDA target for the current context.
1328913317
CUDAFunctionTarget CurrentCUDATarget() {
1329013318
return IdentifyCUDATarget(dyn_cast<FunctionDecl>(CurContext));

clang/lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,6 +2583,7 @@ Decl *Parser::ParseDeclarationAfterDeclaratorAndAttributes(
25832583
}
25842584
}
25852585

2586+
Sema::CUDATargetContextRAII X(Actions, Sema::CTCK_InitGlobalVar, ThisDecl);
25862587
switch (TheInitKind) {
25872588
// Parse declarator '=' initializer.
25882589
case InitKind::Equal: {

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,37 @@ Sema::IdentifyCUDATarget(const ParsedAttributesView &Attrs) {
105105
}
106106

107107
template <typename A>
108-
static bool hasAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
108+
static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
109109
return D->hasAttrs() && llvm::any_of(D->getAttrs(), [&](Attr *Attribute) {
110110
return isa<A>(Attribute) &&
111111
!(IgnoreImplicitAttr && Attribute->isImplicit());
112112
});
113113
}
114114

115+
Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
116+
CUDATargetContextKind K,
117+
Decl *D)
118+
: S(S_) {
119+
SavedCtx = S.CurCUDATargetCtx;
120+
assert(K == CTCK_InitGlobalVar);
121+
auto *VD = dyn_cast_or_null<VarDecl>(D);
122+
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
123+
auto Target = CFT_Host;
124+
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
125+
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
126+
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
127+
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
128+
Target = CFT_Device;
129+
S.CurCUDATargetCtx = {Target, K, VD};
130+
}
131+
}
132+
115133
/// IdentifyCUDATarget - Determine the CUDA compilation target for this function
116134
Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
117135
bool IgnoreImplicitHDAttr) {
118-
// Code that lives outside a function is run on the host.
136+
// Code that lives outside a function gets the target from CurCUDATargetCtx.
119137
if (D == nullptr)
120-
return CFT_Host;
138+
return CurCUDATargetCtx.Target;
121139

122140
if (D->hasAttr<CUDAInvalidTargetAttr>())
123141
return CFT_InvalidTarget;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5317,7 +5317,8 @@ static void handleNoRandomizeLayoutAttr(Sema &S, Decl *D,
53175317
}
53185318

53195319
bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
5320-
const FunctionDecl *FD) {
5320+
const FunctionDecl *FD,
5321+
CUDAFunctionTarget CFT) {
53215322
if (Attrs.isInvalid())
53225323
return true;
53235324

@@ -5416,7 +5417,8 @@ bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
54165417
// on their host/device attributes.
54175418
if (LangOpts.CUDA) {
54185419
auto *Aux = Context.getAuxTargetInfo();
5419-
auto CudaTarget = IdentifyCUDATarget(FD);
5420+
assert(FD || CFT != CFT_InvalidTarget);
5421+
auto CudaTarget = FD ? IdentifyCUDATarget(FD) : CFT;
54205422
bool CheckHost = false, CheckDevice = false;
54215423
switch (CudaTarget) {
54225424
case CFT_HostDevice:

clang/lib/Sema/SemaOverload.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6699,17 +6699,19 @@ void Sema::AddOverloadCandidate(
66996699
}
67006700

67016701
// (CUDA B.1): Check for invalid calls between targets.
6702-
if (getLangOpts().CUDA)
6703-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
6704-
// Skip the check for callers that are implicit members, because in this
6705-
// case we may not yet know what the member's target is; the target is
6706-
// inferred for the member automatically, based on the bases and fields of
6707-
// the class.
6708-
if (!Caller->isImplicit() && !IsAllowedCUDACall(Caller, Function)) {
6709-
Candidate.Viable = false;
6710-
Candidate.FailureKind = ovl_fail_bad_target;
6711-
return;
6712-
}
6702+
if (getLangOpts().CUDA) {
6703+
const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
6704+
// Skip the check for callers that are implicit members, because in this
6705+
// case we may not yet know what the member's target is; the target is
6706+
// inferred for the member automatically, based on the bases and fields of
6707+
// the class.
6708+
if (!(Caller && Caller->isImplicit()) &&
6709+
!IsAllowedCUDACall(Caller, Function)) {
6710+
Candidate.Viable = false;
6711+
Candidate.FailureKind = ovl_fail_bad_target;
6712+
return;
6713+
}
6714+
}
67136715

67146716
if (Function->getTrailingRequiresClause()) {
67156717
ConstraintSatisfaction Satisfaction;
@@ -7221,12 +7223,11 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,
72217223

72227224
// (CUDA B.1): Check for invalid calls between targets.
72237225
if (getLangOpts().CUDA)
7224-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
7225-
if (!IsAllowedCUDACall(Caller, Method)) {
7226-
Candidate.Viable = false;
7227-
Candidate.FailureKind = ovl_fail_bad_target;
7228-
return;
7229-
}
7226+
if (!IsAllowedCUDACall(getCurFunctionDecl(/*AllowLambda=*/true), Method)) {
7227+
Candidate.Viable = false;
7228+
Candidate.FailureKind = ovl_fail_bad_target;
7229+
return;
7230+
}
72307231

72317232
if (Method->getTrailingRequiresClause()) {
72327233
ConstraintSatisfaction Satisfaction;
@@ -12497,10 +12498,12 @@ class AddressOfFunctionResolver {
1249712498
return false;
1249812499

1249912500
if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
12500-
if (S.getLangOpts().CUDA)
12501-
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
12502-
if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
12503-
return false;
12501+
if (S.getLangOpts().CUDA) {
12502+
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
12503+
if (!(Caller && Caller->isImplicit()) &&
12504+
!S.IsAllowedCUDACall(Caller, FunDecl))
12505+
return false;
12506+
}
1250412507
if (FunDecl->isMultiVersion()) {
1250512508
const auto *TA = FunDecl->getAttr<TargetAttr>();
1250612509
if (TA && !TA->isDefaultVersion())

clang/lib/Sema/SemaType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4055,7 +4055,8 @@ static CallingConv getCCForDeclaratorChunk(
40554055
// function type. We'll diagnose the failure to apply them in
40564056
// handleFunctionTypeAttr.
40574057
CallingConv CC;
4058-
if (!S.CheckCallingConvAttr(AL, CC) &&
4058+
if (!S.CheckCallingConvAttr(AL, CC, /*FunctionDecl=*/nullptr,
4059+
S.IdentifyCUDATarget(D.getAttributes())) &&
40594060
(!FTI.isVariadic || supportsVariadicCall(CC))) {
40604061
return CC;
40614062
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %clang_cc1 %s -triple x86_64-linux-unknown -emit-llvm -o - \
2+
// RUN: | FileCheck -check-prefix=HOST %s
3+
// RUN: %clang_cc1 %s -fcuda-is-device \
4+
// RUN: -emit-llvm -o - -triple nvptx64 \
5+
// RUN: -aux-triple x86_64-unknown-linux-gnu | FileCheck \
6+
// RUN: -check-prefix=DEV %s
7+
8+
#include "Inputs/cuda.h"
9+
10+
// Check host/device-based overloading resolution in global variable initializer.
11+
double pow(double, double) { return 1.0; }
12+
13+
__device__ double pow(double, int) { return 2.0; }
14+
15+
// HOST-DAG: call {{.*}}double @_Z3powdd(double noundef 1.000000e+00, double noundef 1.000000e+00)
16+
double X = pow(1.0, 1);
17+
18+
constexpr double cpow(double, double) { return 11.0; }
19+
20+
constexpr __device__ double cpow(double, int) { return 12.0; }
21+
22+
// HOST-DAG: @CX = global double 1.100000e+01
23+
double CX = cpow(11.0, 1);
24+
25+
// DEV-DAG: @CY = addrspace(1) externally_initialized global double 1.200000e+01
26+
__device__ double CY = cpow(12.0, 1);
27+
28+
struct A {
29+
double pow(double, double) { return 3.0; }
30+
31+
__device__ double pow(double, int) { return 4.0; }
32+
};
33+
34+
A a;
35+
36+
// HOST-DAG: call {{.*}}double @_ZN1A3powEdd(ptr {{.*}}@a, double noundef 3.000000e+00, double noundef 1.000000e+00)
37+
double AX = a.pow(3.0, 1);
38+
39+
struct CA {
40+
constexpr double cpow(double, double) const { return 13.0; }
41+
42+
constexpr __device__ double cpow(double, int) const { return 14.0; }
43+
};
44+
45+
const CA ca;
46+
47+
// HOST-DAG: @CAX = global double 1.300000e+01
48+
double CAX = ca.cpow(13.0, 1);
49+
50+
// DEV-DAG: @CAY = addrspace(1) externally_initialized global double 1.400000e+01
51+
__device__ double CAY = ca.cpow(14.0, 1);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -aux-triple x86_64-pc-windows-msvc -fms-compatibility -fcuda-is-device -fsyntax-only -verify %s
2+
// RUN: %clang_cc1 -triple x86_64-pc-windows-msvc -fms-compatibility -fsyntax-only -verify %s
23

34
__cdecl void hostf1();
45
__vectorcall void (*hostf2)() = hostf1; // expected-error {{cannot initialize a variable of type 'void ((*))() __attribute__((vectorcall))' with an lvalue of type 'void () __attribute__((cdecl))'}}

clang/test/SemaCUDA/function-overload.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,13 @@ __host__ __device__ void hostdevicef() {
222222
// Test for address of overloaded function resolution in the global context.
223223
HostFnPtr fp_h = h;
224224
HostFnPtr fp_ch = ch;
225+
#if defined (__CUDA_ARCH__)
226+
__device__
227+
#endif
225228
CurrentFnPtr fp_dh = dh;
229+
#if defined (__CUDA_ARCH__)
230+
__device__
231+
#endif
226232
CurrentFnPtr fp_cdh = cdh;
227233
GlobalFnPtr fp_g = g;
228234

clang/test/SemaCUDA/global-initializers-host.cu

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: %clang_cc1 %s -triple x86_64-linux-unknown -fsyntax-only -o - -verify
2+
// RUN: %clang_cc1 %s -fcuda-is-device -triple nvptx -fsyntax-only -o - -verify
3+
4+
#include "Inputs/cuda.h"
5+
6+
// Check that we get an error if we try to call a __device__ function from a
7+
// module initializer.
8+
9+
struct S {
10+
// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: requires 1 argument, but 0 were provided}}
11+
// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: requires 1 argument, but 0 were provided}}
12+
__device__ S() {}
13+
// expected-note@-1 {{candidate constructor not viable: call to __device__ function from __host__ function}}
14+
};
15+
16+
S s;
17+
// expected-error@-1 {{no matching constructor for initialization of 'S'}}
18+
19+
struct T {
20+
__host__ __device__ T() {}
21+
};
22+
T t; // No error, this is OK.
23+
24+
struct U {
25+
// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const U' for 1st argument}}
26+
// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'U' for 1st argument}}
27+
__host__ U() {}
28+
// expected-note@-1 {{candidate constructor not viable: requires 0 arguments, but 1 was provided}}
29+
__device__ U(int) {}
30+
// expected-note@-1 {{candidate constructor not viable: call to __device__ function from __host__ function}}
31+
};
32+
U u(42);
33+
// expected-error@-1 {{no matching constructor for initialization of 'U'}}
34+
35+
__device__ int device_fn() { return 42; }
36+
// expected-note@-1 {{candidate function not viable: call to __device__ function from __host__ function}}
37+
int n = device_fn();
38+
// expected-error@-1 {{no matching function for call to 'device_fn'}}
39+
40+
// Check host/device-based overloading resolution in global variable initializer.
41+
double pow(double, double);
42+
43+
__device__ double pow(double, int);
44+
45+
double X = pow(1.0, 1);
46+
__device__ double Y = pow(2.0, 2); // expected-error{{dynamic initialization is not supported for __device__, __constant__, __shared__, and __managed__ variables}}
47+
48+
constexpr double cpow(double, double) { return 1.0; }
49+
50+
constexpr __device__ double cpow(double, int) { return 2.0; }
51+
52+
const double CX = cpow(1.0, 1);
53+
const __device__ double CY = cpow(2.0, 2);
54+
55+
struct A {
56+
double pow(double, double);
57+
58+
__device__ double pow(double, int);
59+
60+
constexpr double cpow(double, double) const { return 1.0; }
61+
62+
constexpr __device__ double cpow(double, int) const { return 1.0; }
63+
64+
};
65+
66+
A a;
67+
double AX = a.pow(1.0, 1);
68+
__device__ double AY = a.pow(2.0, 2); // expected-error{{dynamic initialization is not supported for __device__, __constant__, __shared__, and __managed__ variables}}
69+
70+
const A ca;
71+
const double CAX = ca.cpow(1.0, 1);
72+
const __device__ double CAY = ca.cpow(2.0, 2);

0 commit comments

Comments
 (0)