Skip to content

[OpenACC] Implement 'if' clause for Compute Constructs #88411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ class ASTNodeTraverser
void Visit(const OpenACCClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
// TODO OpenACC: Switch on clauses that have children, and add them.
for (const auto *S : C->children())
Visit(S);
});
}

Expand Down
81 changes: 77 additions & 4 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLVM_CLANG_AST_OPENACCCLAUSE_H
#define LLVM_CLANG_AST_OPENACCCLAUSE_H
#include "clang/AST/ASTContext.h"
#include "clang/AST/StmtIterator.h"
#include "clang/Basic/OpenACCKinds.h"

namespace clang {
Expand All @@ -34,6 +35,17 @@ class OpenACCClause {

static bool classof(const OpenACCClause *) { return true; }

using child_iterator = StmtIterator;
using const_child_iterator = ConstStmtIterator;
using child_range = llvm::iterator_range<child_iterator>;
using const_child_range = llvm::iterator_range<const_child_iterator>;

child_range children();
const_child_range children() const {
auto Children = const_cast<OpenACCClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

virtual ~OpenACCClause() = default;
};

Expand All @@ -49,6 +61,13 @@ class OpenACCClauseWithParams : public OpenACCClause {

public:
SourceLocation getLParenLoc() const { return LParenLoc; }

child_range children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

/// A 'default' clause, has the optional 'none' or 'present' argument.
Expand Down Expand Up @@ -81,6 +100,51 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
SourceLocation EndLoc);
};

/// Represents one of the handful of classes that has an optional/required
/// 'condition' expression as an argument.
class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
Expr *ConditionExpr = nullptr;

protected:
OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
ConditionExpr(ConditionExpr) {}

public:
bool hasConditionExpr() const { return ConditionExpr; }
const Expr *getConditionExpr() const { return ConditionExpr; }
Expr *getConditionExpr() { return ConditionExpr; }

child_range children() {
if (ConditionExpr)
return child_range(reinterpret_cast<Stmt **>(&ConditionExpr),
reinterpret_cast<Stmt **>(&ConditionExpr + 1));
return child_range(child_iterator(), child_iterator());
}

const_child_range children() const {
if (ConditionExpr)
return const_child_range(
reinterpret_cast<Stmt *const *>(&ConditionExpr),
reinterpret_cast<Stmt *const *>(&ConditionExpr + 1));
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

/// An 'if' clause, which has a required condition expression.
class OpenACCIfClause : public OpenACCClauseWithCondition {
protected:
OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *ConditionExpr, SourceLocation EndLoc);

public:
static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand All @@ -98,6 +162,9 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Default:
VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
return;
case OpenACCClauseKind::If:
VisitOpenACCIfClause(*cast<OpenACCIfClause>(C));
return;
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
Expand All @@ -106,7 +173,6 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
Expand Down Expand Up @@ -145,9 +211,13 @@ template <class Impl> class OpenACCClauseVisitor {
llvm_unreachable("Invalid Clause kind");
}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
return getDerived().VisitOpenACCDefaultClause(Clause);
#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause) { \
return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause); \
}

#include "clang/Basic/OpenACCClauses.def"
};

class OpenACCClausePrinter final
Expand All @@ -165,7 +235,10 @@ class OpenACCClausePrinter final
}
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause);
#include "clang/Basic/OpenACCClauses.def"
};

} // namespace clang
Expand Down
21 changes: 21 additions & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a list of currently implemented OpenACC Clauses (and
// eventually, the entire list) in a way that makes generating 'visitor' and
// other lists easier.
//
// The primary macro is a single-argument version taking the name of the Clause
// as used in Clang source (so `Default` instead of `default`).
//
// VISIT_CLAUSE(CLAUSE_NAME)

VISIT_CLAUSE(Default)
VISIT_CLAUSE(If)

#undef VISIT_CLAUSE
5 changes: 5 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler {
OpenACCClauseParseResult OpenACCCannotContinue();
OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause);

using OpenACCConditionExprParseResult =
std::pair<ExprResult, OpenACCParseCanContinue>;

/// Parses the OpenACC directive (the entire pragma) including the clause
/// list, but does not produce the main AST node.
OpenACCDirectiveParseInfo ParseOpenACCDirective();
Expand Down Expand Up @@ -3657,6 +3660,8 @@ class Parser : public CodeCompletionHandler {
bool ParseOpenACCGangArgList();
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg();
/// Parses a 'condition' expr, ensuring it results in a
ExprResult ParseOpenACCConditionExpr();

private:
//===--------------------------------------------------------------------===//
Expand Down
28 changes: 27 additions & 1 deletion clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase {
OpenACCDefaultClauseKind DefaultClauseKind;
};

std::variant<DefaultDetails> Details;
struct ConditionDetails {
Expr *ConditionExpr;
};

std::variant<DefaultDetails, ConditionDetails> Details;

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand All @@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase {
return std::get<DefaultDetails>(Details).DefaultClauseKind;
}

const Expr *getConditionExpr() const {
return const_cast<OpenACCParsedClause *>(this)->getConditionExpr();
}

Expr *getConditionExpr() {
assert(ClauseKind == OpenACCClauseKind::If &&
"Parsed clause kind does not have a condition expr");
return std::get<ConditionDetails>(Details).ConditionExpr;
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand All @@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase {
"Parsed clause is not a default clause");
Details = DefaultDetails{DefKind};
}

void setConditionDetails(Expr *ConditionExpr) {
assert(ClauseKind == OpenACCClauseKind::If &&
"Parsed clause kind does not have a condition expr");
// In C++ we can count on this being a 'bool', but in C this gets left as
// some sort of scalar that codegen will have to take care of converting.
assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
ConditionExpr->getType()->isScalarType()) &&
"Condition expression type not scalar/dependent");

Details = ConditionDetails{ConditionExpr};
}
};

SemaOpenACC(Sema &S);
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "clang/AST/OpenACCClause.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"

using namespace clang;

Expand All @@ -27,10 +28,48 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
}

OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *ConditionExpr,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
}

OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc)
: OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc,
ConditionExpr, EndLoc) {
assert(ConditionExpr && "if clause requires condition expr");
assert((ConditionExpr->isInstantiationDependent() ||
ConditionExpr->getType()->isScalarType()) &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be bool type only?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, sadly no. "C" mode actually just leaves it as a scalar since it doesn't have a bool type, and we don't manage to do any real conversions there. So I have to accept 'scalar', then follow language rules in CodeGen to get to the correct type.

"Condition expression type not scalar/dependent");
}

OpenACCClause::child_range OpenACCClause::children() {
switch (getClauseKind()) {
default:
assert(false && "Clause children function not implemented");
break;
#define VISIT_CLAUSE(CLAUSE_NAME) \
case OpenACCClauseKind::CLAUSE_NAME: \
return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children();

#include "clang/Basic/OpenACCClauses.def"
}
return child_range(child_iterator(), child_iterator());
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
void OpenACCClausePrinter::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &C) {
OS << "default(" << C.getDefaultClauseKind() << ")";
}

void OpenACCClausePrinter::VisitOpenACCIfClause(const OpenACCIfClause &C) {
OS << "if(" << C.getConditionExpr() << ")";
}
19 changes: 16 additions & 3 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
namespace {
class OpenACCClauseProfiler
: public OpenACCClauseVisitor<OpenACCClauseProfiler> {
StmtProfiler &Profiler;

public:
OpenACCClauseProfiler() = default;
OpenACCClauseProfiler(StmtProfiler &P) : Profiler(P) {}

void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
for (const OpenACCClause *Clause : Clauses) {
Expand All @@ -2456,20 +2457,32 @@ class OpenACCClauseProfiler
Visit(Clause);
}
}
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);

#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause);

#include "clang/Basic/OpenACCClauses.def"
};

/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &Clause) {}

void OpenACCClauseProfiler::VisitOpenACCIfClause(
const OpenACCIfClause &Clause) {
assert(Clause.hasConditionExpr() &&
"if clause requires a valid condition expr");
Profiler.VisitStmt(Clause.getConditionExpr());
}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
const OpenACCComputeConstruct *S) {
// VisitStmt handles children, so the AssociatedStmt is handled.
VisitStmt(S);

OpenACCClauseProfiler P;
OpenACCClauseProfiler P{*this};
P.VisitOpenACCClauseList(S->clauses());
}

Expand Down
5 changes: 5 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Default:
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
case OpenACCClauseKind::If:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
OS << " clause";
break;
default:
// Nothing to do here.
break;
Expand Down
Loading