Skip to content

[OpenACC] Implement Default clause for Compute Constructs #88135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,36 @@ class OpenACCClauseWithParams : public OpenACCClause {
SourceLocation getLParenLoc() const { return LParenLoc; }
};

/// A 'default' clause, has the optional 'none' or 'present' argument.
class OpenACCDefaultClause : public OpenACCClauseWithParams {
friend class ASTReaderStmt;
friend class ASTWriterStmt;

OpenACCDefaultClauseKind DefaultClauseKind;

protected:
OpenACCDefaultClause(OpenACCDefaultClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
: OpenACCClauseWithParams(OpenACCClauseKind::Default, BeginLoc, LParenLoc,
EndLoc),
DefaultClauseKind(K) {
assert((DefaultClauseKind == OpenACCDefaultClauseKind::None ||
DefaultClauseKind == OpenACCDefaultClauseKind::Present) &&
"Invalid Clause Kind");
}

public:
OpenACCDefaultClauseKind getDefaultClauseKind() const {
return DefaultClauseKind;
}

static OpenACCDefaultClause *Create(const ASTContext &C,
OpenACCDefaultClauseKind K,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand All @@ -66,6 +96,8 @@ template <class Impl> class OpenACCClauseVisitor {

switch (C->getClauseKind()) {
case OpenACCClauseKind::Default:
VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
return;
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
Expand Down Expand Up @@ -112,6 +144,10 @@ template <class Impl> class OpenACCClauseVisitor {
}
llvm_unreachable("Invalid Clause kind");
}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
return getDerived().VisitOpenACCDefaultClause(Clause);
}
};

class OpenACCClausePrinter final
Expand All @@ -128,6 +164,8 @@ class OpenACCClausePrinter final
}
}
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
};

} // namespace clang
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12254,6 +12254,10 @@ def err_acc_construct_appertainment
"be used in a statement context">;
def err_acc_clause_appertainment
: Error<"OpenACC '%1' clause is not valid on '%0' directive">;
def err_acc_duplicate_clause_disallowed
: Error<"OpenACC '%1' clause cannot appear more than once on a '%0' "
"directive">;
def note_acc_previous_clause_here : Note<"previous clause is here">;
def err_acc_branch_in_out_compute_construct
: Error<"invalid %select{branch|return|throw}0 %select{out of|into}1 "
"OpenACC Compute Construct">;
Expand Down
23 changes: 23 additions & 0 deletions clang/include/clang/Basic/OpenACCKinds.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,29 @@ enum class OpenACCDefaultClauseKind {
Invalid,
};

template <typename StreamTy>
inline StreamTy &printOpenACCDefaultClauseKind(StreamTy &Out,
OpenACCDefaultClauseKind K) {
switch (K) {
case OpenACCDefaultClauseKind::None:
return Out << "none";
case OpenACCDefaultClauseKind::Present:
return Out << "present";
case OpenACCDefaultClauseKind::Invalid:
return Out << "<invalid>";
}
}

inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
OpenACCDefaultClauseKind K) {
return printOpenACCDefaultClauseKind(Out, K);
}

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
OpenACCDefaultClauseKind K) {
return printOpenACCDefaultClauseKind(Out, K);
}

enum class OpenACCReductionOperator {
/// '+'.
Addition,
Expand Down
19 changes: 18 additions & 1 deletion clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Ownership.h"
#include "clang/Sema/SemaBase.h"
#include <variant>

namespace clang {
class OpenACCClause;
Expand All @@ -35,7 +36,11 @@ class SemaOpenACC : public SemaBase {
SourceRange ClauseRange;
SourceLocation LParenLoc;

// TODO OpenACC: Add variant here to store details of individual clauses.
struct DefaultDetails {
OpenACCDefaultClauseKind DefaultClauseKind;
};

std::variant<DefaultDetails> Details;

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand All @@ -52,8 +57,20 @@ class SemaOpenACC : public SemaBase {

SourceLocation getEndLoc() const { return ClauseRange.getEnd(); }

OpenACCDefaultClauseKind getDefaultClauseKind() const {
assert(ClauseKind == OpenACCClauseKind::Default &&
"Parsed clause is not a default clause");
return std::get<DefaultDetails>(Details).DefaultClauseKind;
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

void setDefaultDetails(OpenACCDefaultClauseKind DefKind) {
assert(ClauseKind == OpenACCClauseKind::Default &&
"Parsed clause is not a default clause");
Details = DefaultDetails{DefKind};
}
};

SemaOpenACC(Sema &S);
Expand Down
19 changes: 19 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,22 @@
#include "clang/AST/ASTContext.h"

using namespace clang;

OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
OpenACCDefaultClauseKind K,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
void *Mem =
C.Allocate(sizeof(OpenACCDefaultClause), alignof(OpenACCDefaultClause));

return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
void OpenACCClausePrinter::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &C) {
OS << "default(" << C.getDefaultClauseKind() << ")";
}
5 changes: 5 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,7 +2456,12 @@ class OpenACCClauseProfiler
Visit(Clause);
}
}
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
};

/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &Clause) {}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,17 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
{
ColorScope Color(OS, ShowColors, AttrColor);
OS << C->getClauseKind();

// Handle clauses with parens for types that have no children, likely
// because there is no sub expression.
switch (C->getClauseKind()) {
case OpenACCClauseKind::Default:
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
default:
// Nothing to do here.
break;
}
}
dumpPointer(C);
dumpSourceRange(SourceRange(C->getBeginLoc(), C->getEndLoc()));
Expand Down
8 changes: 6 additions & 2 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,9 +831,13 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(

ConsumeToken();

if (getOpenACCDefaultClauseKind(DefKindTok) ==
OpenACCDefaultClauseKind::Invalid)
OpenACCDefaultClauseKind DefKind =
getOpenACCDefaultClauseKind(DefKindTok);

if (DefKind == OpenACCDefaultClauseKind::Invalid)
Diag(DefKindTok, diag::err_acc_invalid_default_clause_kind);
else
ParsedClause.setDefaultDetails(DefKind);

break;
}
Expand Down
65 changes: 58 additions & 7 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,27 @@ bool diagnoseConstructAppertainment(SemaOpenACC &S, OpenACCDirectiveKind K,

bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
OpenACCClauseKind ClauseKind) {
// FIXME: For each clause as we implement them, we can add the
// 'legalization' list here.

// Do nothing so we can go to the 'unimplemented' diagnostic instead.
return true;
switch (ClauseKind) {
// FIXME: For each clause as we implement them, we can add the
// 'legalization' list here.
case OpenACCClauseKind::Default:
switch (DirectiveKind) {
case OpenACCDirectiveKind::Parallel:
case OpenACCDirectiveKind::Serial:
case OpenACCDirectiveKind::Kernels:
case OpenACCDirectiveKind::ParallelLoop:
case OpenACCDirectiveKind::SerialLoop:
case OpenACCDirectiveKind::KernelsLoop:
case OpenACCDirectiveKind::Data:
return true;
default:
return false;
}
default:
// Do nothing so we can go to the 'unimplemented' diagnostic instead.
return true;
}
llvm_unreachable("Invalid clause kind");
}
} // namespace

Expand All @@ -63,8 +79,43 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
return nullptr;
}

// TODO OpenACC: Switch over the clauses we implement here and 'create'
// them.
switch (Clause.getClauseKind()) {
case OpenACCClauseKind::Default: {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Parallel &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Serial &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Kernels)
break;

// Don't add an invalid clause to the AST.
if (Clause.getDefaultClauseKind() == OpenACCDefaultClauseKind::Invalid)
return nullptr;

// OpenACC 3.3, Section 2.5.4:
// At most one 'default' clause may appear, and it must have a value of
// either 'none' or 'present'.
// Second half of the sentence is diagnosed during parsing.
auto Itr = llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Default;
});

if (Itr != ExistingClauses.end()) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_duplicate_clause_disallowed)
<< Clause.getDirectiveKind() << Clause.getClauseKind();
SemaRef.Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
return nullptr;
}

return OpenACCDefaultClause::Create(
getASTContext(), Clause.getDefaultClauseKind(), Clause.getBeginLoc(),
Clause.getLParenLoc(), Clause.getEndLoc());
}
default:
break;
}

Diag(Clause.getBeginLoc(), diag::warn_acc_clause_unimplemented)
<< Clause.getClauseKind();
Expand Down
42 changes: 39 additions & 3 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -4034,6 +4034,11 @@ class TreeTransform {
llvm::SmallVector<OpenACCClause *>
TransformOpenACCClauseList(OpenACCDirectiveKind DirKind,
ArrayRef<const OpenACCClause *> OldClauses);

OpenACCClause *
TransformOpenACCClause(ArrayRef<const OpenACCClause *> ExistingClauses,
OpenACCDirectiveKind DirKind,
const OpenACCClause *OldClause);
};

template <typename Derived>
Expand Down Expand Up @@ -11074,13 +11079,44 @@ OMPClause *TreeTransform<Derived>::TransformOMPXBareClause(OMPXBareClause *C) {
//===----------------------------------------------------------------------===//
// OpenACC transformation
//===----------------------------------------------------------------------===//
template <typename Derived>
OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
ArrayRef<const OpenACCClause *> ExistingClauses,
OpenACCDirectiveKind DirKind, const OpenACCClause *OldClause) {

SemaOpenACC::OpenACCParsedClause ParsedClause(
DirKind, OldClause->getClauseKind(), OldClause->getBeginLoc());
ParsedClause.setEndLoc(OldClause->getEndLoc());

if (const auto *WithParms = dyn_cast<OpenACCClauseWithParams>(OldClause))
ParsedClause.setLParenLoc(WithParms->getLParenLoc());

switch (OldClause->getClauseKind()) {
case OpenACCClauseKind::Default:
// There is nothing to do here as nothing dependent can appear in this
// clause. So just set the values so Sema can set the right value.
ParsedClause.setDefaultDetails(
cast<OpenACCDefaultClause>(OldClause)->getDefaultClauseKind());
break;
default:
assert(false && "Unhandled OpenACC clause in TreeTransform");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm_unreachable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to Aaron Ballman, this isn't really what 'unreachable' is for.

There is a pretty extensive discussion somewhere between 'assert(false...) ' and 'llvm_unreachable', as the latter performs optimizations on release builds that can result in an unstable compiler.

In this case, I'm using it somewhat for potential flow-control during development, but once I'm sure I have all the cases covered, this will either go away or switch to an unreachable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I do remember something related. That's why it is just a question. :)

return nullptr;
}

return getSema().OpenACC().ActOnClause(ExistingClauses, ParsedClause);
}

template <typename Derived>
llvm::SmallVector<OpenACCClause *>
TreeTransform<Derived>::TransformOpenACCClauseList(
OpenACCDirectiveKind DirKind, ArrayRef<const OpenACCClause *> OldClauses) {
// TODO OpenACC: Ensure we loop through the list and transform the individual
// clauses.
return {};
llvm::SmallVector<OpenACCClause *> TransformedClauses;
for (const auto *Clause : OldClauses) {
if (OpenACCClause *TransformedClause = getDerived().TransformOpenACCClause(
TransformedClauses, DirKind, Clause))
TransformedClauses.push_back(TransformedClause);
}
return TransformedClauses;
}

template <typename Derived>
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11763,7 +11763,12 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
[[maybe_unused]] SourceLocation EndLoc = readSourceLocation();

switch (ClauseKind) {
case OpenACCClauseKind::Default:
case OpenACCClauseKind::Default: {
SourceLocation LParenLoc = readSourceLocation();
OpenACCDefaultClauseKind DCK = readEnum<OpenACCDefaultClauseKind>();
return OpenACCDefaultClause::Create(getContext(), DCK, BeginLoc, LParenLoc,
EndLoc);
}
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/Serialization/ASTWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7406,7 +7406,12 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
writeSourceLocation(C->getEndLoc());

switch (C->getClauseKind()) {
case OpenACCClauseKind::Default:
case OpenACCClauseKind::Default: {
const auto *DC = cast<OpenACCDefaultClause>(C);
writeSourceLocation(DC->getLParenLoc());
writeEnum(DC->getDefaultClauseKind());
return;
}
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
Expand Down
Loading