Skip to content

[OpenACC] Implement 'loop' 'vector' clause #112259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,32 +119,6 @@ class OpenACCSeqClause : public OpenACCClause {
}
};

// Not yet implemented, but the type name is necessary for 'seq' diagnostics, so
// this provides a basic, do-nothing implementation. We still need to add this
// type to the visitors/etc, as well as get it to take its proper arguments.
class OpenACCVectorClause : public OpenACCClause {
protected:
// Forwards the Vector kind and locations to OpenACCClause, then
// unconditionally hits llvm_unreachable: this placeholder must never be
// constructed.
OpenACCVectorClause(SourceLocation BeginLoc, SourceLocation EndLoc)
: OpenACCClause(OpenACCClauseKind::Vector, BeginLoc, EndLoc) {
llvm_unreachable("Not yet implemented");
}

public:
// Kind test: true when C's clause kind is Vector.
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Vector;
}

// Factory declaration only; the definition is not part of this class body.
static OpenACCVectorClause *
Create(const ASTContext &Ctx, SourceLocation BeginLoc, SourceLocation EndLoc);

// Both overloads return a range built from two default-constructed
// iterators, i.e. this placeholder exposes no child nodes.
child_range children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

/// Represents a clause that has a list of parameters.
class OpenACCClauseWithParams : public OpenACCClause {
/// Location of the '('.
Expand Down Expand Up @@ -531,6 +505,22 @@ class OpenACCWorkerClause : public OpenACCClauseWithSingleIntExpr {
SourceLocation EndLoc);
};

/// AST node for the OpenACC 'vector' clause. It derives from
/// OpenACCClauseWithSingleIntExpr, so it carries a single integer expression
/// argument. NOTE(review): the argument looks optional (the constructor's
/// assert accepts a null IntExpr and users guard with hasIntExpr()) — confirm
/// against the base class.
class OpenACCVectorClause : public OpenACCClauseWithSingleIntExpr {
protected:
// Protected so that instances are obtained through Create().
OpenACCVectorClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);

public:
/// Kind test: true when \p C has clause kind Vector.
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::Vector;
}

/// Factory for the clause; takes the ASTContext plus the same locations and
/// integer expression that the constructor takes.
static OpenACCVectorClause *Create(const ASTContext &Ctx,
SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc);
};

class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12702,7 +12702,7 @@ def err_acc_gang_dim_value
// %0 is the clause being diagnosed; %1 selects which clause on the enclosing
// 'kernels' construct it conflicts with (0 = num_gangs, 1 = num_workers,
// 2 = vector_length).
def err_acc_num_arg_conflict
    : Error<"'num' argument to '%0' clause not allowed on a 'loop' construct "
            "associated with a 'kernels' construct that has a "
            "'%select{num_gangs|num_workers|vector_length}1' "
            "clause">;
def err_acc_clause_in_clause_region
: Error<"loop with a '%0' clause may not exist in the region of a '%1' "
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ VISIT_CLAUSE(Reduction)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(Seq)
VISIT_CLAUSE(Tile)
VISIT_CLAUSE(Vector)
VISIT_CLAUSE(VectorLength)
VISIT_CLAUSE(Wait)
VISIT_CLAUSE(Worker)
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ class SemaOpenACC : public SemaBase {
/// permits us to implement the restriction of no further 'gang' or 'worker'
/// clauses.
SourceLocation LoopWorkerClauseLoc;
/// If there is a current 'active' loop construct with a 'vector' clause on it
/// (on any sort of construct), this has the source location for it. This
/// permits us to implement the restriction of no further 'gang', 'vector', or
/// 'worker' clauses.
SourceLocation LoopVectorClauseLoc;

// Redeclaration of the version in OpenACCClause.h.
using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;
Expand Down Expand Up @@ -679,6 +684,7 @@ class SemaOpenACC : public SemaBase {
OpenACCDirectiveKind DirKind;
SourceLocation OldLoopGangClauseOnKernelLoc;
SourceLocation OldLoopWorkerClauseLoc;
SourceLocation OldLoopVectorClauseLoc;
llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
LoopInConstructRAII LoopRAII;

Expand Down
31 changes: 27 additions & 4 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
/// Kind test for the single-int-expr clause family: returns true when \p C
/// satisfies the classof() of any clause type listed below. Every clause type
/// that derives from OpenACCClauseWithSingleIntExpr must appear here, including
/// OpenACCVectorClause.
bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
  return OpenACCNumWorkersClause::classof(C) ||
         OpenACCVectorLengthClause::classof(C) ||
         OpenACCVectorClause::classof(C) || OpenACCWorkerClause::classof(C) ||
         OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
}
OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
OpenACCDefaultClauseKind K,
Expand Down Expand Up @@ -424,11 +424,24 @@ OpenACCWorkerClause *OpenACCWorkerClause::Create(const ASTContext &C,
return new (Mem) OpenACCWorkerClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
}

// Builds a 'vector' clause node by forwarding the Vector clause kind, the
// source locations and the integer expression to
// OpenACCClauseWithSingleIntExpr.
OpenACCVectorClause::OpenACCVectorClause(SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc)
: OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Vector, BeginLoc,
LParenLoc, IntExpr, EndLoc) {
// A null IntExpr is accepted; a non-null one must be either
// instantiation-dependent or of integer type.
assert((!IntExpr || IntExpr->isInstantiationDependent() ||
IntExpr->getType()->isIntegerType()) &&
"Int expression type not scalar/dependent");
}

/// Allocates storage for an OpenACCVectorClause from \p C, using the type's
/// size and alignment, and constructs the clause in place.
///
/// \param C         Context that provides the allocation.
/// \param BeginLoc  Forwarded to the constructor.
/// \param LParenLoc Forwarded to the constructor.
/// \param IntExpr   Forwarded to the constructor; the constructor accepts null.
/// \param EndLoc    Forwarded to the constructor.
/// \returns The newly constructed clause.
OpenACCVectorClause *OpenACCVectorClause::Create(const ASTContext &C,
                                                 SourceLocation BeginLoc,
                                                 SourceLocation LParenLoc,
                                                 Expr *IntExpr,
                                                 SourceLocation EndLoc) {
  void *Mem =
      C.Allocate(sizeof(OpenACCVectorClause), alignof(OpenACCVectorClause));
  return new (Mem) OpenACCVectorClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -662,3 +675,13 @@ void OpenACCClausePrinter::VisitWorkerClause(const OpenACCWorkerClause &C) {
OS << ")";
}
}

/// Prints a 'vector' clause. The keyword is always written; when the clause
/// has an integer expression it is appended as "(length: <expr>)".
void OpenACCClausePrinter::VisitVectorClause(const OpenACCVectorClause &C) {
  OS << "vector";

  // A bare 'vector' has no argument list to emit.
  if (!C.hasIntExpr())
    return;

  OS << "(length: ";
  printExpr(C.getIntExpr());
  OS << ")";
}
6 changes: 6 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2635,6 +2635,12 @@ void OpenACCClauseProfiler::VisitWorkerClause(
Profiler.VisitStmt(Clause.getIntExpr());
}

/// Profiles a 'vector' clause: the integer expression is visited when one is
/// present; a clause without one contributes nothing here.
void OpenACCClauseProfiler::VisitVectorClause(
    const OpenACCVectorClause &Clause) {
  if (!Clause.hasIntExpr())
    return;

  Profiler.VisitStmt(Clause.getIntExpr());
}

void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
if (Clause.hasDevNumExpr())
Profiler.VisitStmt(Clause.getDevNumExpr());
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Seq:
case OpenACCClauseKind::Tile:
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::VectorLength:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
Expand Down
146 changes: 138 additions & 8 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
return false;
}
}
case OpenACCClauseKind::Vector: {
switch (DirectiveKind) {
case OpenACCDirectiveKind::Loop:
case OpenACCDirectiveKind::ParallelLoop:
case OpenACCDirectiveKind::SerialLoop:
case OpenACCDirectiveKind::KernelsLoop:
case OpenACCDirectiveKind::Routine:
return true;
default:
return false;
}
}
}

default:
Expand Down Expand Up @@ -512,14 +524,6 @@ class SemaOpenACCClauseVisitor {

OpenACCClause *Visit(SemaOpenACC::OpenACCParsedClause &Clause) {
switch (Clause.getClauseKind()) {
case OpenACCClauseKind::Vector: {
// TODO OpenACC: These are only implemented enough for the 'seq'
// diagnostic, otherwise treats itself as unimplemented. When we
// implement these, we can remove them from here.
DiagIfSeqClause(Clause);
return isNotImplemented();
}

#define VISIT_CLAUSE(CLAUSE_NAME) \
case OpenACCClauseKind::CLAUSE_NAME: \
return Visit##CLAUSE_NAME##Clause(Clause);
Expand Down Expand Up @@ -1035,6 +1039,97 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitIndependentClause(
Clause.getEndLoc());
}

/// Semantic analysis for a parsed 'vector' clause.
///
/// Returns the newly created OpenACCVectorClause on success. Returns nullptr
/// when DiagIfSeqClause(Clause) returns true, or after diagnosing that the
/// clause sits inside the region of a loop that already carries a 'gang'
/// clause (tracked for kernels), a 'worker' clause, or a 'vector' clause. For
/// any directive other than 'loop' the result of isNotImplemented() is
/// returned.
///
/// A 'length' argument that is not permitted in the active compute construct
/// is diagnosed and then dropped, so the clause is still created, just without
/// its integer expression.
OpenACCClause *SemaOpenACCClauseVisitor::VisitVectorClause(
    SemaOpenACC::OpenACCParsedClause &Clause) {
  if (DiagIfSeqClause(Clause))
    return nullptr;
  // Restrictions only properly implemented on 'loop' constructs, and it is
  // the only construct that can do anything with this, so skip/treat as
  // unimplemented for the combined constructs.
  if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop)
    return isNotImplemented();

  // The argument is optional; a null IntExpr means a bare 'vector'.
  Expr *IntExpr =
      Clause.getNumIntExprs() != 0 ? Clause.getIntExprs()[0] : nullptr;
  if (IntExpr) {
    switch (SemaRef.getActiveComputeConstructInfo().Kind) {
    case OpenACCDirectiveKind::Invalid:
    case OpenACCDirectiveKind::Parallel:
      // No restriction on when 'parallel' can contain an argument.
      break;
    case OpenACCDirectiveKind::Serial:
      // GCC disallows this, and there is no real good reason for us to permit
      // it, so disallow until we come up with a use case that makes sense.
      // The argument of 'vector' is spelled 'length' (not 'num', which is the
      // spelling used by 'worker'), so name it that way in the diagnostic.
      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
          << OpenACCClauseKind::Vector << "length" << /*serial=*/3;
      IntExpr = nullptr;
      break;
    case OpenACCDirectiveKind::Kernels: {
      // On 'kernels', the argument conflicts with a 'vector_length' clause on
      // the compute construct itself.
      const auto *Itr =
          llvm::find_if(SemaRef.getActiveComputeConstructInfo().Clauses,
                        llvm::IsaPred<OpenACCVectorLengthClause>);
      if (Itr != SemaRef.getActiveComputeConstructInfo().Clauses.end()) {
        SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_num_arg_conflict)
            << OpenACCClauseKind::Vector << /*vector_length=*/2;
        SemaRef.Diag((*Itr)->getBeginLoc(),
                     diag::note_acc_previous_clause_here);

        IntExpr = nullptr;
      }
      break;
    }
    default:
      llvm_unreachable("Non compute construct in active compute construct");
    }
  }

  // OpenACC 3.3 2.9.2: When the parent compute construct is a kernels
  // construct, the gang clause behaves as follows. ... The region of a loop
  // with a gang clause may not contain another loop with a gang clause unless
  // within a nested compute region.
  if (SemaRef.LoopGangClauseOnKernelLoc.isValid()) {
    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
    // one of these until we get to the end of the construct.
    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
        << OpenACCClauseKind::Vector << OpenACCClauseKind::Gang
        << /*skip kernels construct info*/ 0;
    SemaRef.Diag(SemaRef.LoopGangClauseOnKernelLoc,
                 diag::note_acc_previous_clause_here);
    return nullptr;
  }

  // OpenACC 3.3 2.9.3: The region of a loop with a 'worker' clause may not
  // contain a loop with a gang or worker clause unless within a nested compute
  // region.
  if (SemaRef.LoopWorkerClauseLoc.isValid()) {
    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
    // one of these until we get to the end of the construct.
    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
        << OpenACCClauseKind::Vector << OpenACCClauseKind::Worker
        << /*skip kernels construct info*/ 0;
    SemaRef.Diag(SemaRef.LoopWorkerClauseLoc,
                 diag::note_acc_previous_clause_here);
    return nullptr;
  }

  // OpenACC 3.3 2.9.4: The region of a loop with a 'vector' clause may not
  // contain a loop with a gang, worker, or vector clause unless within a nested
  // compute region.
  if (SemaRef.LoopVectorClauseLoc.isValid()) {
    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
    // one of these until we get to the end of the construct.
    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
        << OpenACCClauseKind::Vector << OpenACCClauseKind::Vector
        << /*skip kernels construct info*/ 0;
    SemaRef.Diag(SemaRef.LoopVectorClauseLoc,
                 diag::note_acc_previous_clause_here);
    return nullptr;
  }

  return OpenACCVectorClause::Create(Ctx, Clause.getBeginLoc(),
                                     Clause.getLParenLoc(), IntExpr,
                                     Clause.getEndLoc());
}

OpenACCClause *SemaOpenACCClauseVisitor::VisitWorkerClause(
SemaOpenACC::OpenACCParsedClause &Clause) {
if (DiagIfSeqClause(Clause))
Expand Down Expand Up @@ -1099,6 +1194,20 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitWorkerClause(
return nullptr;
}

// OpenACC 3.3 2.9.4: The region of a loop with a 'vector' clause may not
// contain a loop with a gang, worker, or vector clause unless within a nested
// compute region.
if (SemaRef.LoopVectorClauseLoc.isValid()) {
// This handles the 'inner loop' diagnostic, but we cannot set that we're on
// one of these until we get to the end of the construct.
SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
<< OpenACCClauseKind::Worker << OpenACCClauseKind::Vector
<< /*skip kernels construct info*/ 0;
SemaRef.Diag(SemaRef.LoopVectorClauseLoc,
diag::note_acc_previous_clause_here);
return nullptr;
}

return OpenACCWorkerClause::Create(Ctx, Clause.getBeginLoc(),
Clause.getLParenLoc(), IntExpr,
Clause.getEndLoc());
Expand Down Expand Up @@ -1193,6 +1302,20 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
return nullptr;
}

// OpenACC 3.3 2.9.4: The region of a loop with a 'vector' clause may not
// contain a loop with a gang, worker, or vector clause unless within a nested
// compute region.
if (SemaRef.LoopVectorClauseLoc.isValid()) {
// This handles the 'inner loop' diagnostic, but we cannot set that we're on
// one of these until we get to the end of the construct.
SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
<< OpenACCClauseKind::Gang << OpenACCClauseKind::Vector
<< /*kernels construct info*/ 1;
SemaRef.Diag(SemaRef.LoopVectorClauseLoc,
diag::note_acc_previous_clause_here);
return nullptr;
}

return OpenACCGangClause::Create(Ctx, Clause.getBeginLoc(),
Clause.getLParenLoc(), GangKinds, IntExprs,
Clause.getEndLoc());
Expand Down Expand Up @@ -1313,6 +1436,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
: SemaRef(S), OldActiveComputeConstructInfo(S.ActiveComputeConstructInfo),
DirKind(DK), OldLoopGangClauseOnKernelLoc(S.LoopGangClauseOnKernelLoc),
OldLoopWorkerClauseLoc(S.LoopWorkerClauseLoc),
OldLoopVectorClauseLoc(S.LoopVectorClauseLoc),
LoopRAII(SemaRef, /*PreserveDepth=*/false) {
// Compute constructs end up taking their 'loop'.
if (DirKind == OpenACCDirectiveKind::Parallel ||
Expand All @@ -1330,6 +1454,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
// Implement the 'unless within a nested compute region' part.
SemaRef.LoopGangClauseOnKernelLoc = {};
SemaRef.LoopWorkerClauseLoc = {};
SemaRef.LoopVectorClauseLoc = {};
} else if (DirKind == OpenACCDirectiveKind::Loop) {
SetCollapseInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
SetTileInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
Expand All @@ -1355,6 +1480,10 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
auto *Itr = llvm::find_if(Clauses, llvm::IsaPred<OpenACCWorkerClause>);
if (Itr != Clauses.end())
SemaRef.LoopWorkerClauseLoc = (*Itr)->getBeginLoc();

auto *Itr2 = llvm::find_if(Clauses, llvm::IsaPred<OpenACCVectorClause>);
if (Itr2 != Clauses.end())
SemaRef.LoopVectorClauseLoc = (*Itr2)->getBeginLoc();
}
}
}
Expand Down Expand Up @@ -1429,6 +1558,7 @@ SemaOpenACC::AssociatedStmtRAII::~AssociatedStmtRAII() {
SemaRef.ActiveComputeConstructInfo = OldActiveComputeConstructInfo;
SemaRef.LoopGangClauseOnKernelLoc = OldLoopGangClauseOnKernelLoc;
SemaRef.LoopWorkerClauseLoc = OldLoopWorkerClauseLoc;
SemaRef.LoopVectorClauseLoc = OldLoopVectorClauseLoc;

if (DirKind == OpenACCDirectiveKind::Parallel ||
DirKind == OpenACCDirectiveKind::Serial ||
Expand Down
Loading
Loading