Skip to content

[OpenACC] Implement loop 'gang' clause. #112006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,32 +119,6 @@ class OpenACCSeqClause : public OpenACCClause {
}
};

// Not yet implemented, but the type name is necessary for 'seq' diagnostics, so
// this provides a basic, do-nothing implementation. We still need to add this
// type to the visitors/etc, as well as get it to take its proper arguments.
class OpenACCGangClause : public OpenACCClause {
protected:
// Forwards the Gang clause kind and the source range to the base class. The
// body unconditionally hits llvm_unreachable, so this placeholder must never
// actually be constructed.
OpenACCGangClause(SourceLocation BeginLoc, SourceLocation EndLoc)
: OpenACCClause(OpenACCClauseKind::Gang, BeginLoc, EndLoc) {
llvm_unreachable("Not yet implemented");
}

public:
// Kind test: true when C's clause kind is Gang.
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Gang;
}

// Factory for the placeholder; only declared here.
static OpenACCGangClause *
Create(const ASTContext &Ctx, SourceLocation BeginLoc, SourceLocation EndLoc);

// The placeholder carries no sub-expressions: both overloads return a range
// built from two default-constructed iterators.
child_range children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

// Not yet implemented, but the type name is necessary for 'seq' diagnostics, so
// this provides a basic, do-nothing implementation. We still need to add this
// type to the visitors/etc, as well as get it to take its proper arguments.
Expand Down Expand Up @@ -177,7 +151,7 @@ class OpenACCVectorClause : public OpenACCClause {
class OpenACCWorkerClause : public OpenACCClause {
protected:
OpenACCWorkerClause(SourceLocation BeginLoc, SourceLocation EndLoc)
: OpenACCClause(OpenACCClauseKind::Gang, BeginLoc, EndLoc) {
: OpenACCClause(OpenACCClauseKind::Worker, BeginLoc, EndLoc) {
llvm_unreachable("Not yet implemented");
}

Expand Down Expand Up @@ -535,6 +509,38 @@ class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithExprs {
Expr *getIntExpr() { return hasIntExpr() ? getExprs()[0] : nullptr; };
};

/// AST node for a 'gang' clause. Each argument is an integer expression
/// tagged with the OpenACCGangKind it was written with; both lists are kept
/// as trailing objects (the Expr * entries first, then the OpenACCGangKind
/// entries).
class OpenACCGangClause final
: public OpenACCClauseWithExprs,
public llvm::TrailingObjects<OpenACCGangClause, Expr *, OpenACCGangKind> {
protected:
/// \p GangKinds and \p IntExprs are parallel lists: entry I of one describes
/// entry I of the other.
OpenACCGangClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);

/// Returns the kind stored for the I'th argument. No bounds check is
/// performed here.
OpenACCGangKind getGangKind(unsigned I) const {
return getTrailingObjects<OpenACCGangKind>()[I];
}

public:
/// Kind test: true when \p C's clause kind is Gang.
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Gang;
}

/// TrailingObjects hook: the number of trailing Expr * entries is the number
/// of expressions currently registered on the clause.
size_t numTrailingObjects(OverloadToken<Expr *>) const {
return getNumExprs();
}

/// Number of arguments on the clause (zero for a bare 'gang').
unsigned getNumExprs() const { return getExprs().size(); }
/// Returns the I'th argument as a (kind, expression) pair.
std::pair<OpenACCGangKind, const Expr *> getExpr(unsigned I) const {
return {getGangKind(I), getExprs()[I]};
}

/// Factory for a gang clause holding the given parallel lists of kinds and
/// expressions.
static OpenACCGangClause *
Create(const ASTContext &Ctx, SourceLocation BeginLoc,
SourceLocation LParenLoc, ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
};

class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);
Expand Down
19 changes: 19 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12576,6 +12576,7 @@ def err_acc_duplicate_clause_disallowed
: Error<"OpenACC '%1' clause cannot appear more than once on a '%0' "
"directive">;
// Notes pointing back at an earlier clause / expression; neither takes
// arguments.
def note_acc_previous_clause_here : Note<"previous clause is here">;
def note_acc_previous_expr_here : Note<"previous expression is here">;
// %0 selects the kind of jump (0 = branch, 1 = return, 2 = throw); %1 selects
// the direction (0 = out of, 1 = into).
def err_acc_branch_in_out_compute_construct
: Error<"invalid %select{branch|return|throw}0 %select{out of|into}1 "
"OpenACC Compute Construct">;
Expand Down Expand Up @@ -12682,6 +12683,24 @@ def err_acc_insufficient_loops
// %0 is the name of the clause that requires the tight nesting.
def err_acc_intervening_code
: Error<"inner loops must be tightly nested inside a '%0' clause on "
"a 'loop' construct">;
// %0 selects which argument kind was repeated: 0 = unnamed or 'num',
// 1 = 'dim', 2 = 'static'.
def err_acc_gang_multiple_elt
: Error<"OpenACC 'gang' clause may have at most one %select{unnamed or "
"'num'|'dim'|'static'}0 argument">;
// %0 is the spelling of the rejected argument. %1 selects the context of the
// 'loop' construct: 0 = orphaned, 1 = 'parallel', 2 = 'kernels', 3 = 'serial'.
// Both %select groups are driven by the same %1 so that the article
// ("a"/"an orphaned") and the trailing "associated with ..." text agree.
def err_acc_gang_arg_invalid
: Error<"'%0' argument on 'gang' clause is not permitted on a%select{n "
"orphaned|||}1 'loop' construct %select{|associated with a "
"'parallel' compute construct|associated with a 'kernels' compute "
"construct|associated with a 'serial' compute construct}1">;
// %0 selects the failure: 0 = the argument is not a constant expression,
// 1 = the value is not 1, 2, or 3, in which case %1 is the evaluated value.
def err_acc_gang_dim_value
: Error<"argument to 'gang' clause dimension must be %select{a constant "
"expression|1, 2, or 3: evaluated to %1}0">;
// Takes no arguments.
def err_acc_gang_num_gangs_conflict
: Error<"'num' argument to 'gang' clause not allowed on a 'loop' construct "
"associated with a 'kernels' construct that has a 'num_gangs' "
"clause">;
// Takes no arguments.
def err_acc_gang_inside_gang
: Error<"loop with a 'gang' clause may not exist in the region of a 'gang' "
"clause on a 'kernels' compute construct">;

// AMDGCN builtins diagnostics
def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ VISIT_CLAUSE(DevicePtr)
VISIT_CLAUSE(DeviceType)
CLAUSE_ALIAS(DType, DeviceType, false)
VISIT_CLAUSE(FirstPrivate)
VISIT_CLAUSE(Gang)
VISIT_CLAUSE(If)
VISIT_CLAUSE(Independent)
VISIT_CLAUSE(NoCreate)
Expand Down
29 changes: 29 additions & 0 deletions clang/include/clang/Basic/OpenACCKinds.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,35 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
OpenACCReductionOperator Op) {
return printOpenACCReductionOperator(Out, Op);
}

/// The kinds of argument that can be written inside a 'gang' clause. Each
/// enumerator's comment shows the keyword prefix that selects it.
enum class OpenACCGangKind : uint8_t {
/// num:
Num,
/// dim:
Dim,
/// static:
Static
};

template <typename StreamTy>
inline StreamTy &printOpenACCGangKind(StreamTy &Out, OpenACCGangKind GK) {
switch (GK) {
case OpenACCGangKind::Num:
return Out << "num";
case OpenACCGangKind::Dim:
return Out << "dim";
case OpenACCGangKind::Static:
return Out << "static";
}
}
/// Lets an OpenACCGangKind be streamed into a diagnostic; forwards to
/// printOpenACCGangKind, which writes the kind's spelling.
inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
OpenACCGangKind Op) {
return printOpenACCGangKind(Out, Op);
}
/// Lets an OpenACCGangKind be streamed into a raw_ostream; forwards to
/// printOpenACCGangKind, which writes the kind's spelling.
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
OpenACCGangKind Op) {
return printOpenACCGangKind(Out, Op);
}
} // namespace clang

#endif // LLVM_CLANG_BASIC_OPENACCKINDS_H
12 changes: 9 additions & 3 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3797,9 +3797,15 @@ class Parser : public CodeCompletionHandler {
bool ParseOpenACCSizeExprList(OpenACCClauseKind CK,
llvm::SmallVectorImpl<Expr *> &SizeExprs);
/// Parses a 'gang-arg-list', used for the 'gang' clause.
bool ParseOpenACCGangArgList(SourceLocation GangLoc);
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg(SourceLocation GangLoc);
bool ParseOpenACCGangArgList(SourceLocation GangLoc,
llvm::SmallVectorImpl<OpenACCGangKind> &GKs,
llvm::SmallVectorImpl<Expr *> &IntExprs);

using OpenACCGangArgRes = std::pair<OpenACCGangKind, ExprResult>;
/// Parses a 'gang-arg', used for the 'gang' clause. Returns a pair of the
/// ExprResult (which contains the validity of the expression), plus the gang
/// kind for the current argument.
OpenACCGangArgRes ParseOpenACCGangArg(SourceLocation GangLoc);
/// Parses a 'condition' expr, ensuring it results in a
ExprResult ParseOpenACCConditionExpr();

Expand Down
80 changes: 73 additions & 7 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,20 @@ class SemaOpenACC : public SemaBase {
/// haven't had their 'parent' compute construct set yet. Entries will only be
/// made to this list in the case where we know the loop isn't an orphan.
llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
/// Whether we are inside of a compute construct, and should add loops to the
/// above collection.
bool InsideComputeConstruct = false;

struct ComputeConstructInfo {
/// Which type of compute construct we are inside of, which we can use to
/// determine whether we should add loops to the ParentlessLoopConstructs
/// collection. We can also use it to diagnose loop construct clauses.
/// Invalid (the default) means no compute construct is active.
OpenACCDirectiveKind Kind = OpenACCDirectiveKind::Invalid;
/// If we have an active compute construct, stores the list of clauses we've
/// prepared for it, so that we can diagnose limitations on child constructs.
/// NOTE(review): this is a non-owning view -- confirm the clause storage
/// outlives this record.
ArrayRef<OpenACCClause *> Clauses;
} ActiveComputeConstructInfo;

bool isInComputeConstruct() const {
return ActiveComputeConstructInfo.Kind != OpenACCDirectiveKind::Invalid;
}

/// Certain clauses care about the same things that aren't specific to the
/// individual clause, but can be shared by a few, so store them here. All
Expand Down Expand Up @@ -99,6 +110,15 @@ class SemaOpenACC : public SemaBase {
} TileInfo;

public:
/// Gives mutable access to the record describing the active compute
/// construct (its directive kind and its clause list).
ComputeConstructInfo &getActiveComputeConstructInfo() {
return ActiveComputeConstructInfo;
}

/// If there is a current 'active' loop construct with a 'gang' clause on a
/// 'kernels' construct, this will have the source location for it. This
/// permits us to implement the restriction of no further 'gang' clauses.
SourceLocation LoopGangClauseOnKernelLoc;

// Redeclaration of the version in OpenACCClause.h.
using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;

Expand Down Expand Up @@ -149,9 +169,14 @@ class SemaOpenACC : public SemaBase {
Expr *LoopCount;
};

/// Parsed arguments of a 'gang' clause, kept as two parallel lists:
/// GangKinds[I] is the kind of argument I and IntExprs[I] is its expression.
struct GangDetails {
SmallVector<OpenACCGangKind> GangKinds;
SmallVector<Expr *> IntExprs;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
ReductionDetails, CollapseDetails>
ReductionDetails, CollapseDetails, GangDetails>
Details = std::monostate{};

public:
Expand Down Expand Up @@ -245,9 +270,18 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::NumWorkers ||
ClauseKind == OpenACCClauseKind::Async ||
ClauseKind == OpenACCClauseKind::Tile ||
ClauseKind == OpenACCClauseKind::Gang ||
ClauseKind == OpenACCClauseKind::VectorLength) &&
"Parsed clause kind does not have a int exprs");

if (ClauseKind == OpenACCClauseKind::Gang) {
// There might not be any gang int exprs, as this is an optional
// argument.
if (std::holds_alternative<std::monostate>(Details))
return {};
return std::get<GangDetails>(Details).IntExprs;
}

return std::get<IntExprDetails>(Details).IntExprs;
}

Expand All @@ -259,6 +293,16 @@ class SemaOpenACC : public SemaBase {
return std::get<ReductionDetails>(Details).Op;
}

/// Returns the argument kinds recorded for a 'gang' clause, or an empty list
/// when no details have been stored.
ArrayRef<OpenACCGangKind> getGangKinds() const {
  assert(ClauseKind == OpenACCClauseKind::Gang &&
         "Parsed clause kind does not have gang kind");
  // Arguments on 'gang' are optional, so Details may still be in its default
  // monostate; only reach into GangDetails when something was stored.
  const bool HasStoredDetails =
      !std::holds_alternative<std::monostate>(Details);
  if (HasStoredDetails)
    return std::get<GangDetails>(Details).GangKinds;
  return {};
}

ArrayRef<Expr *> getVarList() {
assert((ClauseKind == OpenACCClauseKind::Private ||
ClauseKind == OpenACCClauseKind::NoCreate ||
Expand Down Expand Up @@ -371,6 +415,25 @@ class SemaOpenACC : public SemaBase {
Details = IntExprDetails{std::move(IntExprs)};
}

/// Records the arguments of a 'gang' clause by copying the two parallel
/// lists (kind I belongs to expression I) into the clause details.
void setGangDetails(ArrayRef<OpenACCGangKind> GKs,
                    ArrayRef<Expr *> IntExprs) {
  assert(ClauseKind == OpenACCClauseKind::Gang &&
         "Parsed Clause kind does not have gang details");
  assert(GKs.size() == IntExprs.size() && "Mismatched kind/size?");

  // Materialise owning copies first, then hand them to the variant.
  SmallVector<OpenACCGangKind> KindCopy(GKs.begin(), GKs.end());
  SmallVector<Expr *> ExprCopy(IntExprs.begin(), IntExprs.end());
  Details = GangDetails{std::move(KindCopy), std::move(ExprCopy)};
}

/// Records the arguments of a 'gang' clause, taking ownership of the two
/// parallel lists (kind I belongs to expression I) instead of copying them.
void setGangDetails(llvm::SmallVector<OpenACCGangKind> &&GKs,
                    llvm::SmallVector<Expr *> &&IntExprs) {
  assert(ClauseKind == OpenACCClauseKind::Gang &&
         "Parsed Clause kind does not have gang details");
  assert(GKs.size() == IntExprs.size() && "Mismatched kind/size?");

  GangDetails Taken;
  Taken.GangKinds = std::move(GKs);
  Taken.IntExprs = std::move(IntExprs);
  Details = std::move(Taken);
}

void setVarListDetails(ArrayRef<Expr *> VarList, bool IsReadOnly,
bool IsZero) {
assert((ClauseKind == OpenACCClauseKind::Private ||
Expand Down Expand Up @@ -545,10 +608,12 @@ class SemaOpenACC : public SemaBase {
SourceLocation RBLoc);
/// Checks the loop depth value for a collapse clause.
ExprResult CheckCollapseLoopCount(Expr *LoopCount);
/// Checks a single size expr for a tile clause. 'gang' could possibly call
/// this, but has slightly stricter rules as to valid values.
/// Checks a single size expr for a tile clause.
ExprResult CheckTileSizeExpr(Expr *SizeExpr);

// Check a single expression on a gang clause.
ExprResult CheckGangExpr(OpenACCGangKind GK, Expr *E);

ExprResult BuildOpenACCAsteriskSizeExpr(SourceLocation AsteriskLoc);
ExprResult ActOnOpenACCAsteriskSizeExpr(SourceLocation AsteriskLoc);

Expand Down Expand Up @@ -595,8 +660,9 @@ class SemaOpenACC : public SemaBase {
/// Loop needing its parent construct.
class AssociatedStmtRAII {
SemaOpenACC &SemaRef;
bool WasInsideComputeConstruct;
ComputeConstructInfo OldActiveComputeConstructInfo;
OpenACCDirectiveKind DirKind;
SourceLocation OldLoopGangClauseOnKernelLoc;
llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
LoopInConstructRAII LoopRAII;

Expand Down
50 changes: 44 additions & 6 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bool OpenACCClauseWithExprs::classof(const OpenACCClause *C) {
return OpenACCWaitClause::classof(C) || OpenACCNumGangsClause::classof(C) ||
OpenACCTileClause::classof(C) ||
OpenACCClauseWithSingleIntExpr::classof(C) ||
OpenACCClauseWithVarList::classof(C);
OpenACCGangClause::classof(C) || OpenACCClauseWithVarList::classof(C);
}
bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
return OpenACCPrivateClause::classof(C) ||
Expand Down Expand Up @@ -125,6 +125,21 @@ OpenACCNumWorkersClause::OpenACCNumWorkersClause(SourceLocation BeginLoc,
"Condition expression type not scalar/dependent");
}

// Builds a 'gang' clause from parallel lists of argument kinds and integer
// expressions, copying both into the object's trailing storage.
OpenACCGangClause::OpenACCGangClause(SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs,
SourceLocation EndLoc)
: OpenACCClauseWithExprs(OpenACCClauseKind::Gang, BeginLoc, LParenLoc,
EndLoc) {
assert(GangKinds.size() == IntExprs.size() && "Mismatch exprs/kind?");
// Copy the expressions into the first trailing array and register that
// array with the base class.
std::uninitialized_copy(IntExprs.begin(), IntExprs.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
// Order matters: this must come after setExprs(). The class reports its
// Expr * trailing count via getExprs().size(), and (per the TrailingObjects
// contract -- confirm) that count is what positions the kind array.
std::uninitialized_copy(GangKinds.begin(), GangKinds.end(),
getTrailingObjects<OpenACCGangKind>());
}

OpenACCNumWorkersClause *
OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
Expand Down Expand Up @@ -376,11 +391,16 @@ OpenACCSeqClause *OpenACCSeqClause::Create(const ASTContext &C,
return new (Mem) OpenACCSeqClause(BeginLoc, EndLoc);
}

OpenACCGangClause *OpenACCGangClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCGangClause));
return new (Mem) OpenACCGangClause(BeginLoc, EndLoc);
OpenACCGangClause *
OpenACCGangClause::Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc) {
void *Mem =
C.Allocate(OpenACCGangClause::totalSizeToAlloc<Expr *, OpenACCGangKind>(
IntExprs.size(), GangKinds.size()));
return new (Mem)
OpenACCGangClause(BeginLoc, LParenLoc, GangKinds, IntExprs, EndLoc);
}

OpenACCWorkerClause *OpenACCWorkerClause::Create(const ASTContext &C,
Expand Down Expand Up @@ -600,3 +620,21 @@ void OpenACCClausePrinter::VisitCollapseClause(const OpenACCCollapseClause &C) {
printExpr(C.getLoopCount());
OS << ")";
}

// Prints a 'gang' clause. A clause with no arguments is printed as the bare
// keyword; otherwise the arguments follow in parentheses as a comma-separated
// list of "<kind>: <expr>" entries.
void OpenACCClausePrinter::VisitGangClause(const OpenACCGangClause &C) {
  OS << "gang";

  const unsigned NumArgs = C.getNumExprs();
  if (NumArgs == 0)
    return;

  OS << "(";
  for (unsigned Idx = 0; Idx != NumArgs; ++Idx) {
    // Every entry except the first is preceded by a separator.
    if (Idx != 0)
      OS << ", ";

    const auto Arg = C.getExpr(Idx);
    OS << Arg.first << ": ";
    printExpr(Arg.second);
  }
  OS << ")";
}
Loading
Loading