Skip to content

[OpenMP 60] Initial parsing/sema for need_device_addr modifier on adjust_args clause #143442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4630,6 +4630,7 @@ def OMPDeclareVariant : InheritableAttr {
OMPTraitInfoArgument<"TraitInfos">,
VariadicExprArgument<"AdjustArgsNothing">,
VariadicExprArgument<"AdjustArgsNeedDevicePtr">,
VariadicExprArgument<"AdjustArgsNeedDeviceAddr">,
VariadicOMPInteropInfoArgument<"AppendArgs">,
];
let AdditionalMembers = [{
Expand Down
6 changes: 4 additions & 2 deletions clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1581,8 +1581,10 @@ def err_omp_unexpected_append_op : Error<
"unexpected operation specified in 'append_args' clause, expected 'interop'">;
def err_omp_unexpected_execution_modifier : Error<
"unexpected 'execution' modifier in non-executable context">;
def err_omp_unknown_adjust_args_op : Error<
"incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'">;
def err_omp_unknown_adjust_args_op
: Error<
"incorrect 'adjust_args' type, expected 'need_device_ptr'%select{|, "
"'need_device_addr',}0 or 'nothing'">;
def err_omp_declare_variant_wrong_clause : Error<
"expected %select{'match'|'match', 'adjust_args', or 'append_args'}0 clause "
"on 'omp declare variant' directive">;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenMPKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ OPENMP_ORIGINAL_SHARING_MODIFIER(default)
// Adjust-op kinds for the 'adjust_args' clause.
OPENMP_ADJUST_ARGS_KIND(nothing)
OPENMP_ADJUST_ARGS_KIND(need_device_ptr)
OPENMP_ADJUST_ARGS_KIND(need_device_addr)

// Binding kinds for the 'bind' clause.
OPENMP_BIND_KIND(teams)
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ class SemaOpenMP : public SemaBase {
FunctionDecl *FD, Expr *VariantRef, OMPTraitInfo &TI,
ArrayRef<Expr *> AdjustArgsNothing,
ArrayRef<Expr *> AdjustArgsNeedDevicePtr,
ArrayRef<Expr *> AdjustArgsNeedDeviceAddr,
ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
SourceLocation AppendArgsLoc, SourceRange SR);

Expand Down
6 changes: 6 additions & 0 deletions clang/lib/AST/AttrImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ void OMPDeclareVariantAttr::printPrettyPragma(
PrintExprs(adjustArgsNeedDevicePtr_begin(), adjustArgsNeedDevicePtr_end());
OS << ")";
}
if (adjustArgsNeedDeviceAddr_size()) {
OS << " adjust_args(need_device_addr:";
PrintExprs(adjustArgsNeedDeviceAddr_begin(),
adjustArgsNeedDeviceAddr_end());
OS << ")";
}

auto PrintInteropInfo = [&OS](OMPInteropInfo *Begin, OMPInteropInfo *End) {
for (OMPInteropInfo *I = Begin; I != End; ++I) {
Expand Down
28 changes: 20 additions & 8 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,7 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
OMPTraitInfo &TI = ASTCtx.getNewOMPTraitInfo();
SmallVector<Expr *, 6> AdjustNothing;
SmallVector<Expr *, 6> AdjustNeedDevicePtr;
SmallVector<Expr *, 6> AdjustNeedDeviceAddr;
SmallVector<OMPInteropInfo, 3> AppendArgs;
SourceLocation AdjustArgsLoc, AppendArgsLoc;

Expand Down Expand Up @@ -1515,11 +1516,21 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
SmallVector<Expr *> Vars;
IsError = ParseOpenMPVarList(OMPD_declare_variant, OMPC_adjust_args,
Vars, Data);
if (!IsError)
llvm::append_range(Data.ExtraModifier == OMPC_ADJUST_ARGS_nothing
? AdjustNothing
: AdjustNeedDevicePtr,
Vars);
if (!IsError) {
switch (Data.ExtraModifier) {
case OMPC_ADJUST_ARGS_nothing:
llvm::append_range(AdjustNothing, Vars);
break;
case OMPC_ADJUST_ARGS_need_device_ptr:
llvm::append_range(AdjustNeedDevicePtr, Vars);
break;
case OMPC_ADJUST_ARGS_need_device_addr:
llvm::append_range(AdjustNeedDeviceAddr, Vars);
break;
default:
llvm_unreachable("Unexpected 'adjust_args' clause modifier.");
}
}
break;
}
case OMPC_append_args:
Expand Down Expand Up @@ -1559,8 +1570,8 @@ void Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
if (DeclVarData && !TI.Sets.empty())
Actions.OpenMP().ActOnOpenMPDeclareVariantDirective(
DeclVarData->first, DeclVarData->second, TI, AdjustNothing,
AdjustNeedDevicePtr, AppendArgs, AdjustArgsLoc, AppendArgsLoc,
SourceRange(Loc, Tok.getLocation()));
AdjustNeedDevicePtr, AdjustNeedDeviceAddr, AppendArgs, AdjustArgsLoc,
AppendArgsLoc, SourceRange(Loc, Tok.getLocation()));

// Skip the last annot_pragma_openmp_end.
(void)ConsumeAnnotationToken();
Expand Down Expand Up @@ -4818,7 +4829,8 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind,
getLangOpts());
Data.ExtraModifierLoc = Tok.getLocation();
if (Data.ExtraModifier == OMPC_ADJUST_ARGS_unknown) {
Diag(Tok, diag::err_omp_unknown_adjust_args_op);
Diag(Tok, diag::err_omp_unknown_adjust_args_op)
<< (getLangOpts().OpenMP >= 60 ? 1 : 0);
SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch);
} else {
ConsumeToken();
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7122,6 +7122,7 @@ void SemaOpenMP::ActOnFinishedFunctionDefinitionInOpenMPDeclareVariantScope(
getASTContext(), VariantFuncRef, DVScope.TI,
/*NothingArgs=*/nullptr, /*NothingArgsSize=*/0,
/*NeedDevicePtrArgs=*/nullptr, /*NeedDevicePtrArgsSize=*/0,
/*NeedDeviceAddrArgs=*/nullptr, /*NeedDeviceAddrArgsSize=*/0,
/*AppendArgs=*/nullptr, /*AppendArgsSize=*/0);
for (FunctionDecl *BaseFD : Bases)
BaseFD->addAttr(OMPDeclareVariantA);
Expand Down Expand Up @@ -7553,6 +7554,7 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
FunctionDecl *FD, Expr *VariantRef, OMPTraitInfo &TI,
ArrayRef<Expr *> AdjustArgsNothing,
ArrayRef<Expr *> AdjustArgsNeedDevicePtr,
ArrayRef<Expr *> AdjustArgsNeedDeviceAddr,
ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
SourceLocation AppendArgsLoc, SourceRange SR) {

Expand All @@ -7564,6 +7566,7 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
SmallVector<Expr *, 8> AllAdjustArgs;
llvm::append_range(AllAdjustArgs, AdjustArgsNothing);
llvm::append_range(AllAdjustArgs, AdjustArgsNeedDevicePtr);
llvm::append_range(AllAdjustArgs, AdjustArgsNeedDeviceAddr);

if (!AllAdjustArgs.empty() || !AppendArgs.empty()) {
VariantMatchInfo VMI;
Expand Down Expand Up @@ -7614,6 +7617,8 @@ void SemaOpenMP::ActOnOpenMPDeclareVariantDirective(
const_cast<Expr **>(AdjustArgsNothing.data()), AdjustArgsNothing.size(),
const_cast<Expr **>(AdjustArgsNeedDevicePtr.data()),
AdjustArgsNeedDevicePtr.size(),
const_cast<Expr **>(AdjustArgsNeedDeviceAddr.data()),
AdjustArgsNeedDeviceAddr.size(),
const_cast<OMPInteropInfo *>(AppendArgs.data()), AppendArgs.size(), SR);
FD->addAttr(NewAttr);
}
Expand Down
11 changes: 9 additions & 2 deletions clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ static void instantiateOMPDeclareVariantAttr(

SmallVector<Expr *, 8> NothingExprs;
SmallVector<Expr *, 8> NeedDevicePtrExprs;
SmallVector<Expr *, 8> NeedDeviceAddrExprs;
SmallVector<OMPInteropInfo, 4> AppendArgs;

for (Expr *E : Attr.adjustArgsNothing()) {
Expand All @@ -541,14 +542,20 @@ static void instantiateOMPDeclareVariantAttr(
continue;
NeedDevicePtrExprs.push_back(ER.get());
}
for (Expr *E : Attr.adjustArgsNeedDeviceAddr()) {
ExprResult ER = Subst(E);
if (ER.isInvalid())
continue;
NeedDeviceAddrExprs.push_back(ER.get());
}
for (OMPInteropInfo &II : Attr.appendArgs()) {
// When prefer_type is implemented for append_args handle them here too.
AppendArgs.emplace_back(II.IsTarget, II.IsTargetSync);
}

S.OpenMP().ActOnOpenMPDeclareVariantDirective(
FD, E, TI, NothingExprs, NeedDevicePtrExprs, AppendArgs, SourceLocation(),
SourceLocation(), Attr.getRange());
FD, E, TI, NothingExprs, NeedDevicePtrExprs, NeedDeviceAddrExprs,
AppendArgs, SourceLocation(), SourceLocation(), Attr.getRange());
}

static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
Expand Down
26 changes: 16 additions & 10 deletions clang/test/OpenMP/declare_variant_clauses_ast_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ void foo_v3(float *AAA, float *BBB, int *I) {return;}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}foo_v1
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//PRINT: #pragma omp declare variant(foo_v3) match(construct={dispatch}, device={arch(x86, x86_64)}) adjust_args(nothing:I) adjust_args(need_device_ptr:BBB)
//PRINT: #pragma omp declare variant(foo_v3) match(construct={dispatch}, device={arch(x86, x86_64)}) adjust_args(nothing:I) adjust_args(need_device_ptr:BBB) adjust_args(need_device_addr:AAA)

//PRINT: #pragma omp declare variant(foo_v2) match(construct={dispatch}, device={arch(ppc)}) adjust_args(need_device_ptr:AAA)
//PRINT: #pragma omp declare variant(foo_v2) match(construct={dispatch}, device={arch(ppc)}) adjust_args(need_device_ptr:AAA) adjust_args(need_device_addr:BBB)

//PRINT: omp declare variant(foo_v1) match(construct={dispatch}, device={arch(arm)}) adjust_args(need_device_ptr:AAA,BBB)

Expand All @@ -66,42 +66,48 @@ void foo_v3(float *AAA, float *BBB, int *I) {return;}

#pragma omp declare variant(foo_v2) \
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_ptr:AAA)
adjust_args(need_device_ptr:AAA) \
adjust_args(need_device_addr:BBB)

#pragma omp declare variant(foo_v3) \
adjust_args(need_device_ptr:BBB) adjust_args(nothing:I) \
adjust_args(need_device_addr:AAA) \
match(construct={dispatch}, device={arch(x86,x86_64)})

void foo(float *AAA, float *BBB, int *I) {return;}

void Foo_Var(float *AAA, float *BBB) {return;}
void Foo_Var(float *AAA, float *BBB, float *CCC) {return;}

#pragma omp declare variant(Foo_Var) \
match(construct={dispatch}, device={arch(x86_64)}) \
adjust_args(need_device_ptr:AAA) adjust_args(nothing:BBB)
adjust_args(need_device_ptr:AAA) adjust_args(nothing:BBB) \
adjust_args(need_device_addr:CCC)
template<typename T>
void Foo(T *AAA, T *BBB) {return;}
void Foo(T *AAA, T *BBB, T *CCC) {return;}

//PRINT: #pragma omp declare variant(Foo_Var) match(construct={dispatch}, device={arch(x86_64)}) adjust_args(nothing:BBB) adjust_args(need_device_ptr:AAA)
//DUMP: FunctionDecl{{.*}} Foo 'void (T *, T *)'
//PRINT: #pragma omp declare variant(Foo_Var) match(construct={dispatch}, device={arch(x86_64)}) adjust_args(nothing:BBB) adjust_args(need_device_ptr:AAA) adjust_args(need_device_addr:CCC)
//DUMP: FunctionDecl{{.*}} Foo 'void (T *, T *, T *)'
//DUMP: OMPDeclareVariantAttr{{.*}}device={arch(x86_64)}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}Foo_Var
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'CCC'
//
//DUMP: FunctionDecl{{.*}} Foo 'void (float *, float *)'
//DUMP: FunctionDecl{{.*}} Foo 'void (float *, float *, float *)'
//DUMP: OMPDeclareVariantAttr{{.*}}device={arch(x86_64)}
//DUMP: DeclRefExpr{{.*}}Function{{.*}}Foo_Var
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'BBB'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'AAA'
//DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'CCC'

void func()
{
float *A;
float *B;
float *C;

//#pragma omp dispatch
Foo(A, B);
Foo(A, B, C);
}

typedef void *omp_interop_t;
Expand Down
24 changes: 17 additions & 7 deletions clang/test/OpenMP/declare_variant_clauses_messages.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 \
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 \
// RUN: -DNO_INTEROP_T_DEF -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -Wno-strict-prototypes -DC -x c -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -std=c++11 -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-unknown-linux -fopenmp -fopenmp-version=60 -Wno-strict-prototypes -DC -x c -o - %s
// RUN: %clang_cc1 -verify -triple x86_64-pc-windows-msvc -fms-compatibility \
// RUN: -fopenmp -Wno-strict-prototypes -DC -DWIN -x c -o - %s
// RUN: -fopenmp -fopenmp-version=60 -Wno-strict-prototypes -DC -DWIN -x c -o - %s

#ifdef NO_INTEROP_T_DEF
void foo_v1(float *, void *);
Expand Down Expand Up @@ -114,6 +114,16 @@ void vararg_bar2(const char *fmt) { return; }
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_ptr:AAA) adjust_args(nothing:AAA)

// expected-error@+3 {{'adjust_arg' argument 'AAA' used in multiple clauses}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(need_device_ptr:AAA,BBB) adjust_args(need_device_addr:AAA)

// expected-error@+3 {{'adjust_arg' argument 'AAA' used in multiple clauses}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(ppc)}), \
adjust_args(need_device_addr:AAA) adjust_args(nothing:AAA)

// expected-error@+2 {{use of undeclared identifier 'J'}}
#pragma omp declare variant(foo_v1) \
adjust_args(nothing:J) \
Expand Down Expand Up @@ -186,12 +196,12 @@ void vararg_bar2(const char *fmt) { return; }
// expected-error@+1 {{variant in '#pragma omp declare variant' with type 'void (float *, float *, int *, omp_interop_t)' (aka 'void (float *, float *, int *, void *)') is incompatible with type 'void (float *, float *, int *)'}}
#pragma omp declare variant(foo_v4) match(construct={dispatch})

// expected-error@+3 {{incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'}}
// expected-error@+3 {{incorrect 'adjust_args' type, expected 'need_device_ptr', 'need_device_addr', or 'nothing'}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(badaaop:AAA,BBB)

// expected-error@+3 {{incorrect adjust_args type, expected 'need_device_ptr' or 'nothing'}}
// expected-error@+3 {{incorrect 'adjust_args' type, expected 'need_device_ptr', 'need_device_addr', or 'nothing'}}
#pragma omp declare variant(foo_v1) \
match(construct={dispatch}, device={arch(arm)}) \
adjust_args(badaaop AAA,BBB)
Expand Down
Loading