Skip to content

Commit 5b03efb

Browse files
authored
[Clang][OpenMP] Add permutation clause (#92030)
Add the permutation clause for the interchange directive, which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in [Technical Report 12](https://www.openmp.org/wp-content/uploads/openmp-TR12.pdf).
1 parent 32db6fb commit 5b03efb

File tree

23 files changed

+3946
-17
lines changed

23 files changed

+3946
-17
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,105 @@ class OMPSizesClause final
930930
}
931931
};
932932

933+
/// This class represents the 'permutation' clause in the
934+
/// '#pragma omp interchange' directive.
935+
///
936+
/// \code{.c}
937+
/// #pragma omp interchange permutation(2,1)
938+
/// for (int i = 0; i < 64; ++i)
939+
/// for (int j = 0; j < 64; ++j)
940+
/// \endcode
941+
class OMPPermutationClause final
942+
: public OMPClause,
943+
private llvm::TrailingObjects<OMPSizesClause, Expr *> {
944+
friend class OMPClauseReader;
945+
friend class llvm::TrailingObjects<OMPSizesClause, Expr *>;
946+
947+
/// Location of '('.
948+
SourceLocation LParenLoc;
949+
950+
/// Number of arguments in the clause, and hence also the number of loops to
951+
/// be permuted.
952+
unsigned NumLoops;
953+
954+
/// Sets the permutation index expressions.
955+
void setArgRefs(ArrayRef<Expr *> VL) {
956+
assert(VL.size() == NumLoops && "Expecting one expression per loop");
957+
llvm::copy(VL, static_cast<OMPPermutationClause *>(this)
958+
->template getTrailingObjects<Expr *>());
959+
}
960+
961+
/// Build an empty clause.
962+
explicit OMPPermutationClause(int NumLoops)
963+
: OMPClause(llvm::omp::OMPC_permutation, SourceLocation(),
964+
SourceLocation()),
965+
NumLoops(NumLoops) {}
966+
967+
public:
968+
/// Build a 'permutation' clause AST node.
969+
///
970+
/// \param C Context of the AST.
971+
/// \param StartLoc Location of the 'permutation' identifier.
972+
/// \param LParenLoc Location of '('.
973+
/// \param EndLoc Location of ')'.
974+
/// \param Args Content of the clause.
975+
static OMPPermutationClause *
976+
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
977+
SourceLocation EndLoc, ArrayRef<Expr *> Args);
978+
979+
/// Build an empty 'permutation' AST node for deserialization.
980+
///
981+
/// \param C Context of the AST.
982+
/// \param NumLoops Number of arguments in the clause.
983+
static OMPPermutationClause *CreateEmpty(const ASTContext &C,
984+
unsigned NumLoops);
985+
986+
/// Sets the location of '('.
987+
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
988+
989+
/// Returns the location of '('.
990+
SourceLocation getLParenLoc() const { return LParenLoc; }
991+
992+
/// Returns the number of list items.
993+
unsigned getNumLoops() const { return NumLoops; }
994+
995+
/// Returns the permutation index expressions.
996+
///@{
997+
MutableArrayRef<Expr *> getArgsRefs() {
998+
return MutableArrayRef<Expr *>(static_cast<OMPPermutationClause *>(this)
999+
->template getTrailingObjects<Expr *>(),
1000+
NumLoops);
1001+
}
1002+
ArrayRef<Expr *> getArgsRefs() const {
1003+
return ArrayRef<Expr *>(static_cast<const OMPPermutationClause *>(this)
1004+
->template getTrailingObjects<Expr *>(),
1005+
NumLoops);
1006+
}
1007+
///@}
1008+
1009+
child_range children() {
1010+
MutableArrayRef<Expr *> Args = getArgsRefs();
1011+
return child_range(reinterpret_cast<Stmt **>(Args.begin()),
1012+
reinterpret_cast<Stmt **>(Args.end()));
1013+
}
1014+
const_child_range children() const {
1015+
ArrayRef<Expr *> Args = getArgsRefs();
1016+
return const_child_range(reinterpret_cast<Stmt *const *>(Args.begin()),
1017+
reinterpret_cast<Stmt *const *>(Args.end()));
1018+
}
1019+
1020+
child_range used_children() {
1021+
return child_range(child_iterator(), child_iterator());
1022+
}
1023+
const_child_range used_children() const {
1024+
return const_child_range(const_child_iterator(), const_child_iterator());
1025+
}
1026+
1027+
static bool classof(const OMPClause *T) {
1028+
return T->getClauseKind() == llvm::omp::OMPC_permutation;
1029+
}
1030+
};
1031+
9331032
/// Representation of the 'full' clause of the '#pragma omp unroll' directive.
9341033
///
9351034
/// \code

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3348,6 +3348,14 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
33483348
return true;
33493349
}
33503350

3351+
template <typename Derived>
3352+
bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
3353+
OMPPermutationClause *C) {
3354+
for (Expr *E : C->getArgsRefs())
3355+
TRY_TO(TraverseStmt(E));
3356+
return true;
3357+
}
3358+
33513359
template <typename Derived>
33523360
bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
33533361
return true;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11702,6 +11702,10 @@ def err_omp_dispatch_statement_call
1170211702
" to a target function or an assignment to one">;
1170311703
def err_omp_unroll_full_variable_trip_count : Error<
1170411704
"loop to be fully unrolled must have a constant trip count">;
11705+
def err_omp_interchange_permutation_value_range : Error<
11706+
"permutation index must be at least 1 and at most %0">;
11707+
def err_omp_interchange_permutation_value_repeated : Error<
11708+
"index %0 must appear exactly once in the permutation clause">;
1170511709
def note_omp_directive_here : Note<"'%0' directive found here">;
1170611710
def err_omp_instantiation_not_supported
1170711711
: Error<"instantiation of '%0' not supported yet">;

clang/include/clang/Parse/Parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,6 +3595,9 @@ class Parser : public CodeCompletionHandler {
35953595
/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
35963596
OMPClause *ParseOpenMPSizesClause();
35973597

3598+
/// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
3599+
OMPClause *ParseOpenMPPermutationClause();
3600+
35983601
/// Parses clause without any additional arguments.
35993602
///
36003603
/// \param Kind Kind of current clause.

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,11 @@ class SemaOpenMP : public SemaBase {
891891
SourceLocation StartLoc,
892892
SourceLocation LParenLoc,
893893
SourceLocation EndLoc);
894+
/// Called on well-formed 'permutation' clause after parsing its arguments.
895+
OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
896+
SourceLocation StartLoc,
897+
SourceLocation LParenLoc,
898+
SourceLocation EndLoc);
894899
/// Called on well-formed 'full' clauses.
895900
OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc,
896901
SourceLocation EndLoc);

clang/lib/AST/OpenMPClause.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,25 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
971971
return new (Mem) OMPSizesClause(NumSizes);
972972
}
973973

974+
/// Allocates a 'permutation' clause with one slot per element of \p Args,
/// then fills in the argument expressions and the source locations.
OMPPermutationClause *
OMPPermutationClause::Create(const ASTContext &C, SourceLocation StartLoc,
                             SourceLocation LParenLoc, SourceLocation EndLoc,
                             ArrayRef<Expr *> Args) {
  OMPPermutationClause *Result = CreateEmpty(C, Args.size());

  // Argument expressions go into the trailing storage sized by CreateEmpty.
  Result->setArgRefs(Args);

  // Source locations: clause keyword, '(' and ')'.
  Result->setLocStart(StartLoc);
  Result->setLParenLoc(LParenLoc);
  Result->setLocEnd(EndLoc);

  return Result;
}
986+
987+
/// Allocates storage from \p C for the clause plus \p NumLoops trailing
/// expression pointers and constructs an empty clause in it.
OMPPermutationClause *
OMPPermutationClause::CreateEmpty(const ASTContext &C, unsigned NumLoops) {
  const size_t AllocSize = totalSizeToAlloc<Expr *>(NumLoops);
  void *Storage = C.Allocate(AllocSize);
  return new (Storage) OMPPermutationClause(NumLoops);
}
992+
974993
OMPFullClause *OMPFullClause::Create(const ASTContext &C,
975994
SourceLocation StartLoc,
976995
SourceLocation EndLoc) {
@@ -1841,6 +1860,14 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
18411860
OS << ")";
18421861
}
18431862

1863+
/// Prints the clause as 'permutation(<arg>, <arg>, ...)'.
///
/// An argument slot may be null: ActOnOpenMPPermutationClause replaces invalid
/// argument expressions with nullptr. Such a slot is printed as a placeholder
/// instead of being dereferenced.
void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
  OS << "permutation(";
  llvm::interleaveComma(Node->getArgsRefs(), OS, [&](const Expr *E) {
    if (!E) {
      // Keep the argument position visible so the list length still matches
      // the number of loops.
      OS << "<null expr>";
      return;
    }
    E->printPretty(OS, nullptr, Policy, 0);
  });
  OS << ")";
}
1870+
18441871
void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) { OS << "full"; }
18451872

18461873
void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,13 @@ void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) {
493493
Profiler->VisitExpr(E);
494494
}
495495

496+
void OMPClauseProfiler::VisitOMPPermutationClause(
497+
const OMPPermutationClause *C) {
498+
for (Expr *E : C->getArgsRefs())
499+
if (E)
500+
Profiler->VisitExpr(E);
501+
}
502+
496503
void OMPClauseProfiler::VisitOMPFullClause(const OMPFullClause *C) {}
497504

498505
void OMPClauseProfiler::VisitOMPPartialClause(const OMPPartialClause *C) {

clang/lib/Basic/OpenMPKinds.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
188188
case OMPC_safelen:
189189
case OMPC_simdlen:
190190
case OMPC_sizes:
191+
case OMPC_permutation:
191192
case OMPC_allocator:
192193
case OMPC_allocate:
193194
case OMPC_collapse:
@@ -512,6 +513,7 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
512513
case OMPC_safelen:
513514
case OMPC_simdlen:
514515
case OMPC_sizes:
516+
case OMPC_permutation:
515517
case OMPC_allocator:
516518
case OMPC_allocate:
517519
case OMPC_collapse:

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3080,6 +3080,18 @@ OMPClause *Parser::ParseOpenMPSizesClause() {
30803080
OpenLoc, CloseLoc);
30813081
}
30823082

3083+
/// Parses the argument list of a 'permutation' clause and hands the collected
/// expressions to Sema.
///
/// \returns The clause built by Sema, or nullptr when
///          ParseOpenMPExprListClause reports failure (returns true).
OMPClause *Parser::ParseOpenMPPermutationClause() {
  SmallVector<Expr *> PermArgs;
  SourceLocation NameLoc, LParenLoc, RParenLoc;

  const bool Failed = ParseOpenMPExprListClause(OMPC_permutation, NameLoc,
                                                LParenLoc, RParenLoc, PermArgs,
                                                /*ReqIntConst=*/true);
  if (Failed)
    return nullptr;

  return Actions.OpenMP().ActOnOpenMPPermutationClause(PermArgs, NameLoc,
                                                       LParenLoc, RParenLoc);
}
3094+
30833095
OMPClause *Parser::ParseOpenMPUsesAllocatorClause(OpenMPDirectiveKind DKind) {
30843096
SourceLocation Loc = Tok.getLocation();
30853097
ConsumeAnyToken();
@@ -3377,6 +3389,14 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
33773389

33783390
Clause = ParseOpenMPSizesClause();
33793391
break;
3392+
case OMPC_permutation:
3393+
if (!FirstClause) {
3394+
Diag(Tok, diag::err_omp_more_one_clause)
3395+
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
3396+
ErrorFound = true;
3397+
}
3398+
Clause = ParseOpenMPPermutationClause();
3399+
break;
33803400
case OMPC_uses_allocators:
33813401
Clause = ParseOpenMPUsesAllocatorClause(DKind);
33823402
break;

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14956,7 +14956,9 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
1495614956
return StmtError();
1495714957

1495814958
// interchange without permutation clause swaps two loops.
14959-
constexpr size_t NumLoops = 2;
14959+
const OMPPermutationClause *PermutationClause =
14960+
OMPExecutableDirective::getSingleClause<OMPPermutationClause>(Clauses);
14961+
size_t NumLoops = PermutationClause ? PermutationClause->getNumLoops() : 2;
1496014962

1496114963
// Verify and diagnose loop nest.
1496214964
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
@@ -14971,6 +14973,12 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
1497114973
return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
1497214974
NumLoops, AStmt, nullptr, nullptr);
1497314975

14976+
// An invalid expression in the permutation clause is set to nullptr in
14977+
// ActOnOpenMPPermutationClause.
14978+
if (PermutationClause &&
14979+
llvm::is_contained(PermutationClause->getArgsRefs(), nullptr))
14980+
return StmtError();
14981+
1497414982
assert(LoopHelpers.size() == NumLoops &&
1497514983
"Expecting loop iteration space dimensionaly to match number of "
1497614984
"affected loops");
@@ -14979,7 +14987,44 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
1497914987
"affected loops");
1498014988

1498114989
// Decode the permutation clause.
14982-
constexpr uint64_t Permutation[] = {1, 0};
14990+
SmallVector<uint64_t, 2> Permutation;
14991+
if (!PermutationClause) {
14992+
Permutation = {1, 0};
14993+
} else {
14994+
ArrayRef<Expr *> PermArgs = PermutationClause->getArgsRefs();
14995+
llvm::BitVector Flags(PermArgs.size());
14996+
for (Expr *PermArg : PermArgs) {
14997+
std::optional<llvm::APSInt> PermCstExpr =
14998+
PermArg->getIntegerConstantExpr(Context);
14999+
if (!PermCstExpr)
15000+
continue;
15001+
uint64_t PermInt = PermCstExpr->getZExtValue();
15002+
assert(1 <= PermInt && PermInt <= NumLoops &&
15003+
"Must be a permutation; diagnostic emitted in "
15004+
"ActOnOpenMPPermutationClause");
15005+
if (Flags[PermInt - 1]) {
15006+
SourceRange ExprRange(PermArg->getBeginLoc(), PermArg->getEndLoc());
15007+
Diag(PermArg->getExprLoc(),
15008+
diag::err_omp_interchange_permutation_value_repeated)
15009+
<< PermInt << ExprRange;
15010+
continue;
15011+
}
15012+
Flags[PermInt - 1] = true;
15013+
15014+
Permutation.push_back(PermInt - 1);
15015+
}
15016+
15017+
if (Permutation.size() != NumLoops)
15018+
return StmtError();
15019+
}
15020+
15021+
// Nothing to transform with trivial permutation.
15022+
if (NumLoops <= 1 || llvm::all_of(llvm::enumerate(Permutation), [](auto P) {
15023+
auto [Idx, Arg] = P;
15024+
return Idx == Arg;
15025+
}))
15026+
return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
15027+
NumLoops, AStmt, AStmt, nullptr);
1498315028

1498415029
// Find the affected loops.
1498515030
SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
@@ -16111,6 +16156,44 @@ OMPClause *SemaOpenMP::ActOnOpenMPSizesClause(ArrayRef<Expr *> SizeExprs,
1611116156
SanitizedSizeExprs);
1611216157
}
1611316158

16159+
/// Builds a 'permutation' clause from the parsed argument expressions.
///
/// Each argument that is neither null nor instantiation-dependent is verified
/// to be an integer constant expression in the range [1, number of arguments].
/// Out-of-range values are diagnosed here. Arguments that fail either check
/// are stored as nullptr in the resulting clause.
OMPClause *SemaOpenMP::ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
                                                    SourceLocation StartLoc,
                                                    SourceLocation LParenLoc,
                                                    SourceLocation EndLoc) {
  const size_t NumLoops = PermExprs.size();

  // Work on a private copy so individual entries can be replaced or nulled.
  SmallVector<Expr *> Sanitized(PermExprs.begin(), PermExprs.end());

  for (Expr *&Arg : Sanitized) {
    // Skip if template-dependent or already sanitized, e.g. during a partial
    // template instantiation.
    if (!Arg || Arg->isInstantiationDependent())
      continue;

    llvm::APSInt Value;
    ExprResult Checked =
        SemaRef.VerifyIntegerConstantExpression(Arg, &Value, Sema::AllowFold);

    bool Keep = Checked.isUsable();
    if (Keep) {
      Arg = Checked.get();

      // A permutation index is 1-based and cannot exceed the loop count.
      if (Value < 1 || NumLoops < Value) {
        Expr *Bad = Checked.get();
        SourceRange BadRange(Bad->getBeginLoc(), Bad->getEndLoc());
        Diag(Bad->getExprLoc(),
             diag::err_omp_interchange_permutation_value_range)
            << NumLoops << BadRange;
        Keep = false;
      }
    }

    // Mark rejected arguments with nullptr.
    if (!Keep && !Arg->isInstantiationDependent())
      Arg = nullptr;
  }

  return OMPPermutationClause::Create(getASTContext(), StartLoc, LParenLoc,
                                      EndLoc, Sanitized);
}
16196+
1611416197
OMPClause *SemaOpenMP::ActOnOpenMPFullClause(SourceLocation StartLoc,
1611516198
SourceLocation EndLoc) {
1611616199
return OMPFullClause::Create(getASTContext(), StartLoc, EndLoc);

0 commit comments

Comments
 (0)