Skip to content

Commit 5ac2385

Browse files
committed
[Clang][OpenMP] Add interchange directive
1 parent b15caff commit 5ac2385

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+7403
-5
lines changed

clang/include/clang-c/Index.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,6 +2150,10 @@ enum CXCursorKind {
21502150
*/
21512151
CXCursor_OMPReverseDirective = 307,
21522152

2153+
/** OpenMP interchange directive.
2154+
*/
2155+
CXCursor_OMPInterchangeDirective = 308,
2156+
21532157
/** OpenACC Compute Construct.
21542158
*/
21552159
CXCursor_OpenACCComputeConstruct = 320,

clang/include/clang/AST/OpenMPClause.h

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,106 @@ class OMPSizesClause final
870870
}
871871
};
872872

873+
/// This class represents the 'permutation' clause in the
874+
/// '#pragma omp interchange' directive.
875+
///
876+
/// \code{c}
877+
/// #pragma omp interchange permutation(2,1)
878+
/// for (int i = 0; i < 64; ++i)
879+
/// for (int j = 0; j < 64; ++j)
880+
/// \endcode
881+
class OMPPermutationClause final
882+
: public OMPClause,
883+
private llvm::TrailingObjects<OMPSizesClause, Expr *> {
884+
friend class OMPClauseReader;
885+
friend class llvm::TrailingObjects<OMPSizesClause, Expr *>;
886+
887+
/// Location of '('.
888+
SourceLocation LParenLoc;
889+
890+
/// Number of arguments in the clause, and hence also the number of loops to
891+
/// be permuted.
892+
unsigned NumLoops;
893+
894+
/// Sets the permutation index expressions.
895+
void setArgRefs(ArrayRef<Expr *> VL) {
896+
assert(VL.size() == NumLoops);
897+
std::copy(VL.begin(), VL.end(),
898+
static_cast<OMPPermutationClause *>(this)
899+
->template getTrailingObjects<Expr *>());
900+
}
901+
902+
/// Build an empty clause.
903+
explicit OMPPermutationClause(int NumLoops)
904+
: OMPClause(llvm::omp::OMPC_permutation, SourceLocation(),
905+
SourceLocation()),
906+
NumLoops(NumLoops) {}
907+
908+
public:
909+
/// Build a 'permutation' clause AST node.
910+
///
911+
/// \param C Context of the AST.
912+
/// \param StartLoc Location of the 'permutation' identifier.
913+
/// \param LParenLoc Location of '('.
914+
/// \param EndLoc Location of ')'.
915+
/// \param Args Content of the clause.
916+
static OMPPermutationClause *
917+
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
918+
SourceLocation EndLoc, ArrayRef<Expr *> Args);
919+
920+
/// Build an empty 'permutation' AST node for deserialization.
921+
///
922+
/// \param C Context of the AST.
923+
/// \param NumLoops Number of arguments in the clause.
924+
static OMPPermutationClause *CreateEmpty(const ASTContext &C,
925+
unsigned NumLoops);
926+
927+
/// Sets the location of '('.
928+
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
929+
930+
/// Returns the location of '('.
931+
SourceLocation getLParenLoc() const { return LParenLoc; }
932+
933+
/// Returns the number of list items.
934+
unsigned getNumLoops() const { return NumLoops; }
935+
936+
/// Returns the permutation index expressions.
937+
///@{
938+
MutableArrayRef<Expr *> getArgsRefs() {
939+
return MutableArrayRef<Expr *>(static_cast<OMPPermutationClause *>(this)
940+
->template getTrailingObjects<Expr *>(),
941+
NumLoops);
942+
}
943+
ArrayRef<Expr *> getArgsRefs() const {
944+
return ArrayRef<Expr *>(static_cast<const OMPPermutationClause *>(this)
945+
->template getTrailingObjects<Expr *>(),
946+
NumLoops);
947+
}
948+
///@}
949+
950+
child_range children() {
951+
MutableArrayRef<Expr *> Args = getArgsRefs();
952+
return child_range(reinterpret_cast<Stmt **>(Args.begin()),
953+
reinterpret_cast<Stmt **>(Args.end()));
954+
}
955+
const_child_range children() const {
956+
ArrayRef<Expr *> Args = getArgsRefs();
957+
return const_child_range(reinterpret_cast<Stmt *const *>(Args.begin()),
958+
reinterpret_cast<Stmt *const *>(Args.end()));
959+
}
960+
961+
child_range used_children() {
962+
return child_range(child_iterator(), child_iterator());
963+
}
964+
const_child_range used_children() const {
965+
return const_child_range(const_child_iterator(), const_child_iterator());
966+
}
967+
968+
static bool classof(const OMPClause *T) {
969+
return T->getClauseKind() == llvm::omp::OMPC_permutation;
970+
}
971+
};
972+
873973
/// Representation of the 'full' clause of the '#pragma omp unroll' directive.
874974
///
875975
/// \code

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,6 +3024,9 @@ DEF_TRAVERSE_STMT(OMPUnrollDirective,
30243024
DEF_TRAVERSE_STMT(OMPReverseDirective,
30253025
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
30263026

3027+
DEF_TRAVERSE_STMT(OMPInterchangeDirective,
3028+
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
3029+
30273030
DEF_TRAVERSE_STMT(OMPForDirective,
30283031
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
30293032

@@ -3322,6 +3325,14 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
33223325
return true;
33233326
}
33243327

3328+
template <typename Derived>
3329+
bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
3330+
OMPPermutationClause *C) {
3331+
for (Expr *E : C->getArgsRefs())
3332+
TRY_TO(TraverseStmt(E));
3333+
return true;
3334+
}
3335+
33253336
template <typename Derived>
33263337
bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
33273338
return true;

clang/include/clang/AST/StmtOpenMP.h

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
10091009
static bool classof(const Stmt *T) {
10101010
Stmt::StmtClass C = T->getStmtClass();
10111011
return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
1012-
C == OMPReverseDirectiveClass;
1012+
C == OMPReverseDirectiveClass || C == OMPInterchangeDirectiveClass;
10131013
}
10141014
};
10151015

@@ -5779,6 +5779,80 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective {
57795779
}
57805780
};
57815781

5782+
/// Represents the '#pragma omp interchange' loop transformation directive.
5783+
///
5784+
/// \code{c}
5785+
/// #pragma omp interchange
5786+
/// for (int i = 0; i < m; ++i)
5787+
/// for (int j = 0; j < n; ++j)
5788+
/// ..
5789+
/// \endcode
5790+
class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
5791+
friend class ASTStmtReader;
5792+
friend class OMPExecutableDirective;
5793+
5794+
/// Offsets of child members.
5795+
enum {
5796+
PreInitsOffset = 0,
5797+
TransformedStmtOffset,
5798+
};
5799+
5800+
explicit OMPInterchangeDirective(SourceLocation StartLoc,
5801+
SourceLocation EndLoc, unsigned NumLoops)
5802+
: OMPLoopTransformationDirective(OMPInterchangeDirectiveClass,
5803+
llvm::omp::OMPD_interchange, StartLoc,
5804+
EndLoc, NumLoops) {
5805+
setNumGeneratedLoops(3 * NumLoops);
5806+
}
5807+
5808+
void setPreInits(Stmt *PreInits) {
5809+
Data->getChildren()[PreInitsOffset] = PreInits;
5810+
}
5811+
5812+
void setTransformedStmt(Stmt *S) {
5813+
Data->getChildren()[TransformedStmtOffset] = S;
5814+
}
5815+
5816+
public:
5817+
/// Create a new AST node representation for '#pragma omp interchange'.
5818+
///
5819+
/// \param C Context of the AST.
5820+
/// \param StartLoc Location of the introducer (e.g. the 'omp' token).
5821+
/// \param EndLoc Location of the directive's end (e.g. the tok::eod).
5822+
/// \param Clauses The directive's clauses.
5823+
/// \param NumLoops Number of affected loops
5824+
/// (number of items in the 'permutation' clause if present).
5825+
/// \param AssociatedStmt The outermost associated loop.
5826+
/// \param TransformedStmt The loop nest after tiling, or nullptr in
5827+
/// dependent contexts.
5828+
/// \param PreInits Helper preinits statements for the loop nest.
5829+
static OMPInterchangeDirective *
5830+
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
5831+
ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
5832+
Stmt *TransformedStmt, Stmt *PreInits);
5833+
5834+
/// Build an empty '#pragma omp interchange' AST node for deserialization.
5835+
///
5836+
/// \param C Context of the AST.
5837+
/// \param NumClauses Number of clauses to allocate.
5838+
/// \param NumLoops Number of associated loops to allocate.
5839+
static OMPInterchangeDirective *
5840+
CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops);
5841+
5842+
/// Gets the associated loops after the transformation. This is the de-sugared
5843+
/// replacement or nullptr in dependent contexts.
5844+
Stmt *getTransformedStmt() const {
5845+
return Data->getChildren()[TransformedStmtOffset];
5846+
}
5847+
5848+
/// Return preinits statement.
5849+
Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
5850+
5851+
static bool classof(const Stmt *T) {
5852+
return T->getStmtClass() == OMPInterchangeDirectiveClass;
5853+
}
5854+
};
5855+
57825856
/// This represents '#pragma omp scan' directive.
57835857
///
57845858
/// \code

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11521,6 +11521,10 @@ def err_omp_dispatch_statement_call
1152111521
" to a target function or an assignment to one">;
1152211522
def err_omp_unroll_full_variable_trip_count : Error<
1152311523
"loop to be fully unrolled must have a constant trip count">;
11524+
def err_omp_interchange_permutation_value_range : Error<
11525+
"permutation index must be at least 1 and at most %0">;
11526+
def err_omp_interchange_permutation_value_repeated : Error<
11527+
"index %0 must appear exactly once in the permutation clause">;
1152411528
def note_omp_directive_here : Note<"'%0' directive found here">;
1152511529
def err_omp_instantiation_not_supported
1152611530
: Error<"instantiation of '%0' not supported yet">;

clang/include/clang/Basic/StmtNodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
230230
def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
231231
def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
232232
def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
233+
def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>;
233234
def OMPForDirective : StmtNode<OMPLoopDirective>;
234235
def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
235236
def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;

clang/include/clang/Parse/Parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3542,6 +3542,9 @@ class Parser : public CodeCompletionHandler {
35423542
/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
35433543
OMPClause *ParseOpenMPSizesClause();
35443544

3545+
/// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
3546+
OMPClause *ParseOpenMPPermutationClause();
3547+
35453548
/// Parses clause without any additional arguments.
35463549
///
35473550
/// \param Kind Kind of current clause.

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,12 @@ class SemaOpenMP : public SemaBase {
427427
StmtResult ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
428428
Stmt *AStmt, SourceLocation StartLoc,
429429
SourceLocation EndLoc);
430+
/// Called on well-formed '#pragma omp interchange' after parsing of its
431+
/// clauses and the associated statement.
432+
StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
433+
Stmt *AStmt,
434+
SourceLocation StartLoc,
435+
SourceLocation EndLoc);
430436
/// Called on well-formed '\#pragma omp for' after parsing
431437
/// of the associated statement.
432438
StmtResult
@@ -864,6 +870,11 @@ class SemaOpenMP : public SemaBase {
864870
SourceLocation StartLoc,
865871
SourceLocation LParenLoc,
866872
SourceLocation EndLoc);
873+
/// Called on well-form 'permutation' clause after parsing its arguments.
874+
OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
875+
SourceLocation StartLoc,
876+
SourceLocation LParenLoc,
877+
SourceLocation EndLoc);
867878
/// Called on well-form 'full' clauses.
868879
OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc,
869880
SourceLocation EndLoc);

clang/include/clang/Serialization/ASTBitCodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,7 @@ enum StmtCode {
18571857
STMT_OMP_TILE_DIRECTIVE,
18581858
STMT_OMP_UNROLL_DIRECTIVE,
18591859
STMT_OMP_REVERSE_DIRECTIVE,
1860+
STMT_OMP_INTERCHANGE_DIRECTIVE,
18601861
STMT_OMP_FOR_DIRECTIVE,
18611862
STMT_OMP_FOR_SIMD_DIRECTIVE,
18621863
STMT_OMP_SECTIONS_DIRECTIVE,

clang/lib/AST/OpenMPClause.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,25 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
971971
return new (Mem) OMPSizesClause(NumSizes);
972972
}
973973

974+
OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C,
975+
SourceLocation StartLoc,
976+
SourceLocation LParenLoc,
977+
SourceLocation EndLoc,
978+
ArrayRef<Expr *> Args) {
979+
OMPPermutationClause *Clause = CreateEmpty(C, Args.size());
980+
Clause->setLocStart(StartLoc);
981+
Clause->setLParenLoc(LParenLoc);
982+
Clause->setLocEnd(EndLoc);
983+
Clause->setArgRefs(Args);
984+
return Clause;
985+
}
986+
987+
OMPPermutationClause *OMPPermutationClause::CreateEmpty(const ASTContext &C,
988+
unsigned NumLoops) {
989+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(NumLoops));
990+
return new (Mem) OMPPermutationClause(NumLoops);
991+
}
992+
974993
OMPFullClause *OMPFullClause::Create(const ASTContext &C,
975994
SourceLocation StartLoc,
976995
SourceLocation EndLoc) {
@@ -1774,6 +1793,18 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
17741793
OS << ")";
17751794
}
17761795

1796+
void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
1797+
OS << "permutation(";
1798+
bool First = true;
1799+
for (Expr *Size : Node->getArgsRefs()) {
1800+
if (!First)
1801+
OS << ", ";
1802+
Size->printPretty(OS, nullptr, Policy, 0);
1803+
First = false;
1804+
}
1805+
OS << ")";
1806+
}
1807+
17771808
void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) { OS << "full"; }
17781809

17791810
void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) {

clang/lib/AST/StmtOpenMP.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,26 @@ OMPReverseDirective *OMPReverseDirective::CreateEmpty(const ASTContext &C,
468468
SourceLocation(), SourceLocation());
469469
}
470470

471+
OMPInterchangeDirective *OMPInterchangeDirective::Create(
472+
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
473+
ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
474+
Stmt *TransformedStmt, Stmt *PreInits) {
475+
OMPInterchangeDirective *Dir = createDirective<OMPInterchangeDirective>(
476+
C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc,
477+
NumLoops);
478+
Dir->setTransformedStmt(TransformedStmt);
479+
Dir->setPreInits(PreInits);
480+
return Dir;
481+
}
482+
483+
OMPInterchangeDirective *
484+
OMPInterchangeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses,
485+
unsigned NumLoops) {
486+
return createEmptyDirective<OMPInterchangeDirective>(
487+
C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
488+
SourceLocation(), SourceLocation(), NumLoops);
489+
}
490+
471491
OMPForSimdDirective *
472492
OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
473493
SourceLocation EndLoc, unsigned CollapsedNum,

clang/lib/AST/StmtPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,11 @@ void StmtPrinter::VisitOMPReverseDirective(OMPReverseDirective *Node) {
768768
PrintOMPExecutableDirective(Node);
769769
}
770770

771+
void StmtPrinter::VisitOMPInterchangeDirective(OMPInterchangeDirective *Node) {
772+
Indent() << "#pragma omp interchange";
773+
PrintOMPExecutableDirective(Node);
774+
}
775+
771776
void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
772777
Indent() << "#pragma omp for";
773778
PrintOMPExecutableDirective(Node);

0 commit comments

Comments
 (0)