Skip to content

[OpenACC][NFC] Add OpenACC Clause AST Nodes/infrastructure #87675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct {
void Visit(TypeLoc);
void Visit(const Decl *D);
void Visit(const CXXCtorInitializer *Init);
void Visit(const OpenACCClause *C);
void Visit(const OMPClause *C);
void Visit(const BlockDecl::Capture &C);
void Visit(const GenericSelectionExpr::ConstAssociation &A);
Expand Down Expand Up @@ -239,6 +240,13 @@ class ASTNodeTraverser
});
}

void Visit(const OpenACCClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
// TODO OpenACC: Switch on clauses that have children, and add them.
});
}

void Visit(const OMPClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
Expand Down Expand Up @@ -799,6 +807,11 @@ class ASTNodeTraverser
Visit(C);
}

void VisitOpenACCConstructStmt(const OpenACCConstructStmt *Node) {
for (const auto *C : Node->clauses())
Visit(C);
}

void VisitInitListExpr(const InitListExpr *ILE) {
if (auto *Filler = ILE->getArrayFiller()) {
Visit(Filler, "array_filler");
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/JSONNodeDumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class JSONNodeDumper
void Visit(const TemplateArgument &TA, SourceRange R = {},
const Decl *From = nullptr, StringRef Label = {});
void Visit(const CXXCtorInitializer *Init);
void Visit(const OpenACCClause *C);
void Visit(const OMPClause *C);
void Visit(const BlockDecl::Capture &C);
void Visit(const GenericSelectionExpr::ConstAssociation &A);
Expand Down
135 changes: 135 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
//===- OpenACCClause.h - Classes for OpenACC clauses ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file defines OpenACC AST classes for clauses.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_AST_OPENACCCLAUSE_H
#define LLVM_CLANG_AST_OPENACCCLAUSE_H
#include "clang/AST/ASTContext.h"
#include "clang/Basic/OpenACCKinds.h"

namespace clang {
/// This is the base type for all OpenACC Clauses.
class OpenACCClause {
OpenACCClauseKind Kind;
SourceRange Location;

protected:
OpenACCClause(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation EndLoc)
: Kind(K), Location(BeginLoc, EndLoc) {}

public:
OpenACCClauseKind getClauseKind() const { return Kind; }
SourceLocation getBeginLoc() const { return Location.getBegin(); }
SourceLocation getEndLoc() const { return Location.getEnd(); }

static bool classof(const OpenACCClause *) { return true; }

virtual ~OpenACCClause() = default;
};

/// Represents a clause that has a list of parameters.
class OpenACCClauseWithParams : public OpenACCClause {
/// Location of the '('.
SourceLocation LParenLoc;

protected:
OpenACCClauseWithParams(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
: OpenACCClause(K, BeginLoc, EndLoc), LParenLoc(LParenLoc) {}

public:
SourceLocation getLParenLoc() const { return LParenLoc; }
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

public:
void VisitClauseList(ArrayRef<const OpenACCClause *> List) {
for (const OpenACCClause *Clause : List)
Visit(Clause);
}

void Visit(const OpenACCClause *C) {
if (!C)
return;

switch (C->getClauseKind()) {
case OpenACCClauseKind::Default:
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
case OpenACCClauseKind::Independent:
case OpenACCClauseKind::Auto:
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::Delete:
case OpenACCClauseKind::Detach:
case OpenACCClauseKind::Device:
case OpenACCClauseKind::DevicePtr:
case OpenACCClauseKind::DeviceResident:
case OpenACCClauseKind::FirstPrivate:
case OpenACCClauseKind::Host:
case OpenACCClauseKind::Link:
case OpenACCClauseKind::NoCreate:
case OpenACCClauseKind::Present:
case OpenACCClauseKind::Private:
case OpenACCClauseKind::CopyOut:
case OpenACCClauseKind::CopyIn:
case OpenACCClauseKind::Create:
case OpenACCClauseKind::Reduction:
case OpenACCClauseKind::Collapse:
case OpenACCClauseKind::Bind:
case OpenACCClauseKind::VectorLength:
case OpenACCClauseKind::NumGangs:
case OpenACCClauseKind::NumWorkers:
case OpenACCClauseKind::DeviceNum:
case OpenACCClauseKind::DefaultAsync:
case OpenACCClauseKind::DeviceType:
case OpenACCClauseKind::DType:
case OpenACCClauseKind::Async:
case OpenACCClauseKind::Tile:
case OpenACCClauseKind::Gang:
case OpenACCClauseKind::Wait:
case OpenACCClauseKind::Invalid:
llvm_unreachable("Clause visitor not yet implemented");
}
llvm_unreachable("Invalid Clause kind");
}
};

class OpenACCClausePrinter final
: public OpenACCClauseVisitor<OpenACCClausePrinter> {
raw_ostream &OS;

public:
void VisitClauseList(ArrayRef<const OpenACCClause *> List) {
for (const OpenACCClause *Clause : List) {
Visit(Clause);

if (Clause != List.back())
OS << ' ';
}
}
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}
};

} // namespace clang

#endif // LLVM_CLANG_AST_OPENACCCLAUSE_H
13 changes: 11 additions & 2 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ template <typename Derived> class RecursiveASTVisitor {
bool TraverseOpenACCConstructStmt(OpenACCConstructStmt *S);
bool
TraverseOpenACCAssociatedStmtConstruct(OpenACCAssociatedStmtConstruct *S);
bool VisitOpenACCClauseList(ArrayRef<const OpenACCClause *>);
};

template <typename Derived>
Expand Down Expand Up @@ -3936,8 +3937,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPXBareClause(OMPXBareClause *C) {

template <typename Derived>
bool RecursiveASTVisitor<Derived>::TraverseOpenACCConstructStmt(
OpenACCConstructStmt *) {
// TODO OpenACC: When we implement clauses, ensure we traverse them here.
OpenACCConstructStmt *C) {
TRY_TO(VisitOpenACCClauseList(C->clauses()));
return true;
}

Expand All @@ -3949,6 +3950,14 @@ bool RecursiveASTVisitor<Derived>::TraverseOpenACCAssociatedStmtConstruct(
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOpenACCClauseList(
ArrayRef<const OpenACCClause *>) {
// TODO OpenACC: When we have Clauses with expressions, we should visit them
// here.
return true;
}

DEF_TRAVERSE_STMT(OpenACCComputeConstruct,
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })

Expand Down
55 changes: 46 additions & 9 deletions clang/include/clang/AST/StmtOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#ifndef LLVM_CLANG_AST_STMTOPENACC_H
#define LLVM_CLANG_AST_STMTOPENACC_H

#include "clang/AST/OpenACCClause.h"
#include "clang/AST/Stmt.h"
#include "clang/Basic/OpenACCKinds.h"
#include "clang/Basic/SourceLocation.h"
#include <memory>

namespace clang {
/// This is the base class for an OpenACC statement-level construct, other
Expand All @@ -30,13 +32,23 @@ class OpenACCConstructStmt : public Stmt {
/// the directive.
SourceRange Range;

// TODO OPENACC: Clauses should probably be collected in this class.
/// The list of clauses. This is stored here as an ArrayRef, as this is the
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of this patch is just boilerplate, but THIS decision I think is the important one. I could EITHER have the base clause store an llvm::SmallVector of clause pointers, OR do the trailing storage trick I'm doing here. The trailing-storage seemed closest to being what OMP does, but is a bit extra rigamarole to make work.

/// most convienient place to access the list, however the list itself should
/// be stored in leaf nodes, likely in trailing-storage.
MutableArrayRef<const OpenACCClause *> Clauses;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need to keep as MutbleArrayRef, not just ArrayRef?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately de-Serialization means we have to be able to modify the elements. See ASTReaderStmt.cpp:2790 here.


protected:
OpenACCConstructStmt(StmtClass SC, OpenACCDirectiveKind K,
SourceLocation Start, SourceLocation End)
: Stmt(SC), Kind(K), Range(Start, End) {}

// Used only for initialization, the leaf class can initialize this to
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ends up being necessary because the trailing storage isn't initialized when we do construction, so the derived classes need to set this after this class is constructed (and their trailing storage base is legal to do stuff with).

// trailing storage.
void setClauseList(MutableArrayRef<const OpenACCClause *> NewClauses) {
assert(Clauses.empty() && "Cannot change clause list");
Clauses = NewClauses;
}

public:
OpenACCDirectiveKind getDirectiveKind() const { return Kind; }

Expand All @@ -47,6 +59,7 @@ class OpenACCConstructStmt : public Stmt {

SourceLocation getBeginLoc() const { return Range.getBegin(); }
SourceLocation getEndLoc() const { return Range.getEnd(); }
ArrayRef<const OpenACCClause *> clauses() const { return Clauses; }

child_range children() {
return child_range(child_iterator(), child_iterator());
Expand Down Expand Up @@ -101,24 +114,46 @@ class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
/// those three, as they are semantically identical, and have only minor
/// differences in the permitted list of clauses, which can be differentiated by
/// the 'Kind'.
class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
class OpenACCComputeConstruct final
: public OpenACCAssociatedStmtConstruct,
public llvm::TrailingObjects<OpenACCComputeConstruct,
const OpenACCClause *> {
friend class ASTStmtWriter;
friend class ASTStmtReader;
friend class ASTContext;
OpenACCComputeConstruct()
: OpenACCAssociatedStmtConstruct(
OpenACCComputeConstructClass, OpenACCDirectiveKind::Invalid,
SourceLocation{}, SourceLocation{}, /*AssociatedStmt=*/nullptr) {}
OpenACCComputeConstruct(unsigned NumClauses)
: OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass,
OpenACCDirectiveKind::Invalid,
SourceLocation{}, SourceLocation{},
/*AssociatedStmt=*/nullptr) {
// We cannot send the TrailingObjects storage to the base class (which holds
// a reference to the data) until it is constructed, so we have to set it
// separately here.
std::uninitialized_value_construct(
getTrailingObjects<const OpenACCClause *>(),
getTrailingObjects<const OpenACCClause *>() + NumClauses);
setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
NumClauses));
}

OpenACCComputeConstruct(OpenACCDirectiveKind K, SourceLocation Start,
SourceLocation End, Stmt *StructuredBlock)
SourceLocation End,
ArrayRef<const OpenACCClause *> Clauses,
Stmt *StructuredBlock)
: OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass, K, Start,
End, StructuredBlock) {
assert((K == OpenACCDirectiveKind::Parallel ||
K == OpenACCDirectiveKind::Serial ||
K == OpenACCDirectiveKind::Kernels) &&
"Only parallel, serial, and kernels constructs should be "
"represented by this type");

// Initialize the trailing storage.
std::uninitialized_copy(Clauses.begin(), Clauses.end(),
getTrailingObjects<const OpenACCClause *>());

setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
Clauses.size()));
}

void setStructuredBlock(Stmt *S) { setAssociatedStmt(S); }
Expand All @@ -128,10 +163,12 @@ class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
return T->getStmtClass() == OpenACCComputeConstructClass;
}

static OpenACCComputeConstruct *CreateEmpty(const ASTContext &C, EmptyShell);
static OpenACCComputeConstruct *CreateEmpty(const ASTContext &C,
unsigned NumClauses);
static OpenACCComputeConstruct *
Create(const ASTContext &C, OpenACCDirectiveKind K, SourceLocation BeginLoc,
SourceLocation EndLoc, Stmt *StructuredBlock);
SourceLocation EndLoc, ArrayRef<const OpenACCClause *> Clauses,
Stmt *StructuredBlock);

Stmt *getStructuredBlock() { return getAssociatedStmt(); }
const Stmt *getStructuredBlock() const {
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/AST/TextNodeDumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ class TextNodeDumper

void Visit(const OMPClause *C);

void Visit(const OpenACCClause *C);

void Visit(const BlockDecl::Capture &C);

void Visit(const GenericSelectionExpr::ConstAssociation &A);
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Serialization/ASTRecordReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/APSInt.h"

namespace clang {
class OpenACCClause;
class OMPTraitInfo;
class OMPChildren;

Expand Down Expand Up @@ -278,6 +279,12 @@ class ASTRecordReader
/// Read an OpenMP children, advancing Idx.
void readOMPChildren(OMPChildren *Data);

/// Read an OpenACC clause, advancing Idx.
OpenACCClause *readOpenACCClause();

/// Read a list of OpenACC clauses into the passed SmallVector.
void readOpenACCClauseList(MutableArrayRef<const OpenACCClause *> Clauses);

/// Read a source location, advancing Idx.
SourceLocation readSourceLocation(LocSeq *Seq = nullptr) {
return Reader->ReadSourceLocation(*F, Record, Idx, Seq);
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Serialization/ASTRecordWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

namespace clang {

class OpenACCClause;
class TypeLoc;

/// An object for streaming information to a record.
Expand Down Expand Up @@ -292,6 +293,12 @@ class ASTRecordWriter
/// Writes data related to the OpenMP directives.
void writeOMPChildren(OMPChildren *Data);

/// Writes out a single OpenACC Clause.
void writeOpenACCClause(const OpenACCClause *C);

/// Writes out a list of OpenACC clauses.
void writeOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses);

/// Emit a string.
void AddString(StringRef Str) {
return Writer->AddString(Str, *Record);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ add_clang_library(clangAST
NSAPI.cpp
ODRDiagsEmitter.cpp
ODRHash.cpp
OpenACCClause.cpp
OpenMPClause.cpp
OSLog.cpp
ParentMap.cpp
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/JSONNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ void JSONNodeDumper::Visit(const CXXCtorInitializer *Init) {
llvm_unreachable("Unknown initializer type");
}

void JSONNodeDumper::Visit(const OpenACCClause *C) {}

void JSONNodeDumper::Visit(const OMPClause *C) {}

void JSONNodeDumper::Visit(const BlockDecl::Capture &C) {
Expand Down
Loading