Skip to content

[Clang][OpenMP] Add permutation clause #92030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d6057c4
[Clang][OpenMP] Fix tile/unroll on iterator- and foreach-loops
Meinersbur May 21, 2024
b15caff
[Clang][OpenMP] Add reverse directive
Meinersbur May 21, 2024
c2bd6a5
Extract out appendFlattendedStmtList
Meinersbur May 21, 2024
0a38da3
Address review
Meinersbur May 21, 2024
4bc8e76
Merge branch 'main' into users/meinersbur/clang_openmp_unroll-tile_fo…
Meinersbur May 22, 2024
8eb4b90
Address review comments
Meinersbur May 22, 2024
a2cd085
Merge branch 'users/meinersbur/clang_openmp_unroll-tile_foreach' into…
Meinersbur May 22, 2024
76634ad
[Clang][OpenMP] Add interchange directive
Meinersbur May 22, 2024
fa15956
[Clang][OpenMP] Add permutation clause
Meinersbur May 22, 2024
7e2e69c
Merge branch 'main' into users/meinersbur/clang_openmp_reverse
Meinersbur May 24, 2024
6c88e43
Address review comment
Meinersbur May 24, 2024
3b8f0be
Merge branch 'users/meinersbur/clang_openmp_reverse' into users/meine…
Meinersbur May 24, 2024
7ef1d9e
Address review comment
Meinersbur May 24, 2024
e15dfd5
Merge branch 'users/meinersbur/clang_openmp_interchange' into clang_o…
Meinersbur May 24, 2024
0039dc4
Address reviewer comments
Meinersbur May 24, 2024
68e113b
Uppercase variable name
Meinersbur May 24, 2024
d68804f
Merge branch 'main' (early part) into users/meinersbur/clang_openmp_i…
Meinersbur Jul 18, 2024
d84db47
Compile fix
Meinersbur Jul 18, 2024
af01517
[Clang] Handle OMPInterchangeDirectiveClass in switch
Meinersbur Jul 18, 2024
9011adf
Merge branch 'main' into users/meinersbur/clang_openmp_interchange
Meinersbur Jul 18, 2024
f730b3d
Compile fix
Meinersbur Jul 18, 2024
8922a0f
Use SmallString
Meinersbur Jul 18, 2024
327432a
Merge remote-tracking branch 'official/users/meinersbur/clang_openmp_…
Meinersbur Jul 19, 2024
00363f8
Merge branch 'main' (early part) into clang_openmp_reverse+interchange
Meinersbur Jul 19, 2024
7b97f44
Merge branch 'main' into clang_openmp_reverse+interchange
Meinersbur Jul 19, 2024
8e3fa6d
Fix OMP.td
Meinersbur Jul 19, 2024
6ea3bf5
Build fix
Meinersbur Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,105 @@ class OMPSizesClause final
}
};

/// This class represents the 'permutation' clause in the
/// '#pragma omp interchange' directive.
///
/// \code{.c}
///   #pragma omp interchange permutation(2,1)
///   for (int i = 0; i < 64; ++i)
///     for (int j = 0; j < 64; ++j)
/// \endcode
///
/// The argument expressions are stored as trailing objects, i.e. in the
/// memory directly following the clause object itself. Instances therefore
/// have to be created through Create() or CreateEmpty(), which allocate the
/// additional storage.
class OMPPermutationClause final
    : public OMPClause,
      // TrailingObjects is a CRTP base: its first template argument must be
      // the class deriving from it, because it downcasts 'this' to that type
      // and uses that type's size to locate the trailing storage.
      private llvm::TrailingObjects<OMPPermutationClause, Expr *> {
  friend class OMPClauseReader;
  friend class llvm::TrailingObjects<OMPPermutationClause, Expr *>;

  /// Location of '('.
  SourceLocation LParenLoc;

  /// Number of arguments in the clause, and hence also the number of loops to
  /// be permuted.
  unsigned NumLoops;

  /// Sets the permutation index expressions.
  ///
  /// \param VL The argument expressions; must contain exactly NumLoops
  ///           elements.
  void setArgRefs(ArrayRef<Expr *> VL) {
    assert(VL.size() == NumLoops && "Expecting one expression per loop");
    llvm::copy(VL, getTrailingObjects<Expr *>());
  }

  /// Build an empty clause.
  ///
  /// Only the clause kind and the number of arguments are set; the source
  /// locations are default-constructed and the argument slots are not
  /// written.
  ///
  /// \param NumLoops Number of arguments in the clause.
  explicit OMPPermutationClause(unsigned NumLoops)
      : OMPClause(llvm::omp::OMPC_permutation, SourceLocation(),
                  SourceLocation()),
        NumLoops(NumLoops) {}

public:
  /// Build a 'permutation' clause AST node.
  ///
  /// \param C         Context of the AST.
  /// \param StartLoc  Location of the 'permutation' identifier.
  /// \param LParenLoc Location of '('.
  /// \param EndLoc    Location of ')'.
  /// \param Args      Content of the clause.
  static OMPPermutationClause *
  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
         SourceLocation EndLoc, ArrayRef<Expr *> Args);

  /// Build an empty 'permutation' AST node for deserialization.
  ///
  /// \param C        Context of the AST.
  /// \param NumLoops Number of arguments in the clause.
  static OMPPermutationClause *CreateEmpty(const ASTContext &C,
                                           unsigned NumLoops);

  /// Sets the location of '('.
  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

  /// Returns the location of '('.
  SourceLocation getLParenLoc() const { return LParenLoc; }

  /// Returns the number of list items.
  unsigned getNumLoops() const { return NumLoops; }

  /// Returns the permutation index expressions.
  ///@{
  MutableArrayRef<Expr *> getArgsRefs() {
    return MutableArrayRef<Expr *>(getTrailingObjects<Expr *>(), NumLoops);
  }
  ArrayRef<Expr *> getArgsRefs() const {
    return ArrayRef<Expr *>(getTrailingObjects<Expr *>(), NumLoops);
  }
  ///@}

  /// Returns the argument expressions as the children of this clause.
  ///@{
  child_range children() {
    MutableArrayRef<Expr *> Args = getArgsRefs();
    return child_range(reinterpret_cast<Stmt **>(Args.begin()),
                       reinterpret_cast<Stmt **>(Args.end()));
  }
  const_child_range children() const {
    ArrayRef<Expr *> Args = getArgsRefs();
    return const_child_range(reinterpret_cast<Stmt *const *>(Args.begin()),
                             reinterpret_cast<Stmt *const *>(Args.end()));
  }
  ///@}

  /// Returns an empty range.
  ///@{
  child_range used_children() {
    return child_range(child_iterator(), child_iterator());
  }
  const_child_range used_children() const {
    return const_child_range(const_child_iterator(), const_child_iterator());
  }
  ///@}

  /// Returns true if \p T is a 'permutation' clause.
  static bool classof(const OMPClause *T) {
    return T->getClauseKind() == llvm::omp::OMPC_permutation;
  }
};

/// Representation of the 'full' clause of the '#pragma omp unroll' directive.
///
/// \code
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3336,6 +3336,14 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
return true;
}

/// Visits every argument expression of a 'permutation' clause.
///
/// TRY_TO is defined outside this excerpt; it is expected to make this
/// function return early when traversing an argument fails.
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
    OMPPermutationClause *C) {
  MutableArrayRef<Expr *> PermArgs = C->getArgsRefs();
  for (unsigned Idx = 0, NumArgs = PermArgs.size(); Idx != NumArgs; ++Idx) {
    TRY_TO(TraverseStmt(PermArgs[Idx]));
  }
  return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
return true;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11613,6 +11613,10 @@ def err_omp_dispatch_statement_call
" to a target function or an assignment to one">;
def err_omp_unroll_full_variable_trip_count : Error<
"loop to be fully unrolled must have a constant trip count">;
// %0: the number of arguments of the permutation clause, which is the largest
// index an argument may have.
def err_omp_interchange_permutation_value_range : Error<
"permutation index must be at least 1 and at most %0">;
// %0: the index value that occurs more than once in the permutation clause.
def err_omp_interchange_permutation_value_repeated : Error<
"index %0 must appear exactly once in the permutation clause">;
def note_omp_directive_here : Note<"'%0' directive found here">;
def err_omp_instantiation_not_supported
: Error<"instantiation of '%0' not supported yet">;
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3574,6 +3574,9 @@ class Parser : public CodeCompletionHandler {
/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
OMPClause *ParseOpenMPSizesClause();

/// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
///
/// \returns the clause produced by SemaOpenMP::ActOnOpenMPPermutationClause,
/// or nullptr if parsing the argument list failed.
OMPClause *ParseOpenMPPermutationClause();

/// Parses clause without any additional arguments.
///
/// \param Kind Kind of current clause.
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,11 @@ class SemaOpenMP : public SemaBase {
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'permutation' clause after parsing its arguments.
///
/// \param PermExprs The argument expressions of the clause.
/// \param StartLoc  Location of the clause name.
/// \param LParenLoc Location of '('.
/// \param EndLoc    Location of ')'.
OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'full' clauses.
OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc,
SourceLocation EndLoc);
Expand Down
27 changes: 27 additions & 0 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,25 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPSizesClause(NumSizes);
}

/// Creates a 'permutation' clause with one argument slot per element of
/// \p Args, then fills in the arguments and the source locations.
OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C,
                                                   SourceLocation StartLoc,
                                                   SourceLocation LParenLoc,
                                                   SourceLocation EndLoc,
                                                   ArrayRef<Expr *> Args) {
  // The number of arguments determines how much trailing storage is needed.
  OMPPermutationClause *Result = CreateEmpty(C, Args.size());

  Result->setArgRefs(Args);

  Result->setLocStart(StartLoc);
  Result->setLParenLoc(LParenLoc);
  Result->setLocEnd(EndLoc);

  return Result;
}

/// Allocates storage from \p C for the clause object plus \p NumLoops
/// trailing expression pointers and constructs an empty clause in it.
/// The argument slots are not filled in here.
OMPPermutationClause *OMPPermutationClause::CreateEmpty(const ASTContext &C,
                                                        unsigned NumLoops) {
  const size_t NumBytes = totalSizeToAlloc<Expr *>(NumLoops);
  void *Storage = C.Allocate(NumBytes);
  return new (Storage) OMPPermutationClause(NumLoops);
}

OMPFullClause *OMPFullClause::Create(const ASTContext &C,
SourceLocation StartLoc,
SourceLocation EndLoc) {
Expand Down Expand Up @@ -1774,6 +1793,14 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
OS << ")";
}

/// Prints the clause as "permutation(" followed by its arguments, separated
/// by ", ", and a closing ")".
void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
  OS << "permutation(";
  bool NeedsSeparator = false;
  for (const Expr *Arg : Node->getArgsRefs()) {
    // A separator goes in front of every argument except the first one.
    if (NeedsSeparator)
      OS << ", ";
    NeedsSeparator = true;
    Arg->printPretty(OS, nullptr, Policy, 0);
  }
  OS << ")";
}

/// Prints the 'full' clause, which is emitted as the bare keyword.
void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) {
  OS << "full";
}

void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) {
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) {
Profiler->VisitExpr(E);
}

void OMPClauseProfiler::VisitOMPPermutationClause(
const OMPPermutationClause *C) {
for (Expr *E : C->getArgsRefs())
if (E)
Profiler->VisitExpr(E);
}

void OMPClauseProfiler::VisitOMPFullClause(const OMPFullClause *C) {
  // Nothing is added to the profile for a 'full' clause.
}

void OMPClauseProfiler::VisitOMPPartialClause(const OMPPartialClause *C) {
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Basic/OpenMPKinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
case OMPC_safelen:
case OMPC_simdlen:
case OMPC_sizes:
case OMPC_permutation:
case OMPC_allocator:
case OMPC_allocate:
case OMPC_collapse:
Expand Down Expand Up @@ -512,6 +513,7 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
case OMPC_safelen:
case OMPC_simdlen:
case OMPC_sizes:
case OMPC_permutation:
case OMPC_allocator:
case OMPC_allocate:
case OMPC_collapse:
Expand Down
20 changes: 20 additions & 0 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3005,6 +3005,18 @@ OMPClause *Parser::ParseOpenMPSizesClause() {
OpenLoc, CloseLoc);
}

/// Parses the argument list of a 'permutation' clause and hands the parsed
/// expressions to Sema.
///
/// \returns the clause built by ActOnOpenMPPermutationClause, or nullptr if
/// ParseOpenMPExprListClause reported a failure.
OMPClause *Parser::ParseOpenMPPermutationClause() {
  SourceLocation NameLoc;
  SourceLocation LParenLoc;
  SourceLocation RParenLoc;
  SmallVector<Expr *> PermArgs;

  const bool ParseFailed =
      ParseOpenMPExprListClause(OMPC_permutation, NameLoc, LParenLoc, RParenLoc,
                                PermArgs, /*ReqIntConst=*/true);
  if (ParseFailed)
    return nullptr;

  return Actions.OpenMP().ActOnOpenMPPermutationClause(PermArgs, NameLoc,
                                                       LParenLoc, RParenLoc);
}

OMPClause *Parser::ParseOpenMPUsesAllocatorClause(OpenMPDirectiveKind DKind) {
SourceLocation Loc = Tok.getLocation();
ConsumeAnyToken();
Expand Down Expand Up @@ -3293,6 +3305,14 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,

Clause = ParseOpenMPSizesClause();
break;
case OMPC_permutation:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
ErrorFound = true;
}
Clause = ParseOpenMPPermutationClause();
break;
case OMPC_uses_allocators:
Clause = ParseOpenMPUsesAllocatorClause(DKind);
break;
Expand Down
87 changes: 85 additions & 2 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14873,7 +14873,9 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
return StmtError();

// interchange without permutation clause swaps two loops.
constexpr size_t NumLoops = 2;
const OMPPermutationClause *PermutationClause =
OMPExecutableDirective::getSingleClause<OMPPermutationClause>(Clauses);
size_t NumLoops = PermutationClause ? PermutationClause->getNumLoops() : 2;

// Verify and diagnose loop nest.
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
Expand All @@ -14888,6 +14890,12 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
NumLoops, AStmt, nullptr, nullptr);

// An invalid expression in the permutation clause is set to nullptr in
// ActOnOpenMPPermutationClause.
if (PermutationClause &&
llvm::is_contained(PermutationClause->getArgsRefs(), nullptr))
return StmtError();

assert(LoopHelpers.size() == NumLoops &&
"Expecting loop iteration space dimensionaly to match number of "
"affected loops");
Expand All @@ -14896,7 +14904,44 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
"affected loops");

// Decode the permutation clause.
constexpr uint64_t Permutation[] = {1, 0};
SmallVector<uint64_t, 2> Permutation;
if (!PermutationClause) {
Permutation = {1, 0};
} else {
ArrayRef<Expr *> PermArgs = PermutationClause->getArgsRefs();
llvm::BitVector Flags(PermArgs.size());
for (Expr *PermArg : PermArgs) {
std::optional<llvm::APSInt> PermCstExpr =
PermArg->getIntegerConstantExpr(Context);
if (!PermCstExpr)
continue;
uint64_t PermInt = PermCstExpr->getZExtValue();
assert(1 <= PermInt && PermInt <= NumLoops &&
"Must be a permutation; diagnostic emitted in "
"ActOnOpenMPPermutationClause");
if (Flags[PermInt - 1]) {
SourceRange ExprRange(PermArg->getBeginLoc(), PermArg->getEndLoc());
Diag(PermArg->getExprLoc(),
diag::err_omp_interchange_permutation_value_repeated)
<< PermInt << ExprRange;
continue;
}
Flags[PermInt - 1] = true;

Permutation.push_back(PermInt - 1);
}

if (Permutation.size() != NumLoops)
return StmtError();
}

// Nothing to transform with trivial permutation.
if (NumLoops <= 1 || llvm::all_of(llvm::enumerate(Permutation), [](auto P) {
auto [Idx, Arg] = P;
return Idx == Arg;
}))
return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
NumLoops, AStmt, AStmt, nullptr);

// Find the affected loops.
SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
Expand Down Expand Up @@ -16029,6 +16074,44 @@ OMPClause *SemaOpenMP::ActOnOpenMPSizesClause(ArrayRef<Expr *> SizeExprs,
SanitizedSizeExprs);
}

OMPClause *SemaOpenMP::ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
size_t NumLoops = PermExprs.size();
SmallVector<Expr *> SanitizedPermExprs;
llvm::append_range(SanitizedPermExprs, PermExprs);

for (Expr *&PermExpr : SanitizedPermExprs) {
// Skip if template-dependent or already sanitized, e.g. during a partial
// template instantiation.
if (!PermExpr || PermExpr->isInstantiationDependent())
continue;

llvm::APSInt PermVal;
ExprResult PermEvalExpr = SemaRef.VerifyIntegerConstantExpression(
PermExpr, &PermVal, Sema::AllowFold);
bool IsValid = PermEvalExpr.isUsable();
if (IsValid)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (IsValid)
if (!IsValid) {
PermExpr = nullptr;
continue;
}

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In partial template instantiations the expression may not evaluate to a constant yet because it depends on another template parameter. We cannot set PermExpr to nullptr because we need to preserve the expression to be evaluated at another call to ActOnOpenMPPermutationClause when the template is fully instantiated.

PermExpr = PermEvalExpr.get();

if (IsValid && (PermVal < 1 || NumLoops < PermVal)) {
SourceRange ExprRange(PermEvalExpr.get()->getBeginLoc(),
PermEvalExpr.get()->getEndLoc());
Diag(PermEvalExpr.get()->getExprLoc(),
diag::err_omp_interchange_permutation_value_range)
<< NumLoops << ExprRange;
IsValid = false;
}

if (!PermExpr->isInstantiationDependent() && !IsValid)
PermExpr = nullptr;
}

return OMPPermutationClause::Create(getASTContext(), StartLoc, LParenLoc,
EndLoc, SanitizedPermExprs);
}

OMPClause *SemaOpenMP::ActOnOpenMPFullClause(SourceLocation StartLoc,
SourceLocation EndLoc) {
return OMPFullClause::Create(getASTContext(), StartLoc, EndLoc);
Expand Down
Loading