Skip to content

[OpenACC] Implement 'num_workers' clause for compute constructs #89151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,46 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
Expr *ConditionExpr, SourceLocation EndLoc);
};

/// Represents oen of a handful of classes that have a single integer
/// expression.
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
Expr *IntExpr;

protected:
OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
IntExpr(IntExpr) {}

public:
bool hasIntExpr() const { return IntExpr; }
const Expr *getIntExpr() const { return IntExpr; }

Expr *getIntExpr() { return IntExpr; };

child_range children() {
return child_range(reinterpret_cast<Stmt **>(&IntExpr),
reinterpret_cast<Stmt **>(&IntExpr + 1));
}

const_child_range children() const {
return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
reinterpret_cast<Stmt *const *>(&IntExpr + 1));
}
};

class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);

public:
static OpenACCNumWorkersClause *Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12268,4 +12268,16 @@ def warn_acc_if_self_conflict
: Warning<"OpenACC construct 'self' has no effect when an 'if' clause "
"evaluates to true">,
InGroup<DiagGroup<"openacc-self-if-potential-conflict">>;
def err_acc_int_expr_requires_integer
: Error<"OpenACC %select{clause|directive}0 '%1' requires expression of "
"integer type (%2 invalid)">;
def err_acc_int_expr_incomplete_class_type
: Error<"OpenACC integer expression has incomplete class type %0">;
def err_acc_int_expr_explicit_conversion
: Error<"OpenACC integer expression type %0 requires explicit conversion "
"to %1">;
def note_acc_int_expr_conversion
: Note<"conversion to %select{integral|enumeration}0 type %1">;
def err_acc_int_expr_multiple_conversions
: Error<"multiple conversions from expression type %0 to an integral type">;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
VISIT_CLAUSE(Default)
VISIT_CLAUSE(If)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(NumWorkers)

#undef VISIT_CLAUSE
9 changes: 5 additions & 4 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3640,13 +3640,14 @@ class Parser : public CodeCompletionHandler {
/// Parses the clause-list for an OpenACC directive.
SmallVector<OpenACCClause *>
ParseOpenACCClauseList(OpenACCDirectiveKind DirKind);
bool ParseOpenACCWaitArgument();
bool ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective);
/// Parses the clause of the 'bind' argument, which can be a string literal or
/// an ID expression.
ExprResult ParseOpenACCBindClauseArgument();
/// Parses the clause kind of 'int-expr', which can be any integral
/// expression.
ExprResult ParseOpenACCIntExpr();
ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc);
/// Parses the 'device-type-list', which is a list of identifiers.
bool ParseOpenACCDeviceTypeList();
/// Parses the 'async-argument', which is an integral value with two
Expand All @@ -3657,9 +3658,9 @@ class Parser : public CodeCompletionHandler {
/// Parses a comma delimited list of 'size-expr's.
bool ParseOpenACCSizeExprList();
/// Parses a 'gang-arg-list', used for the 'gang' clause.
bool ParseOpenACCGangArgList();
bool ParseOpenACCGangArgList(SourceLocation GangLoc);
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg();
bool ParseOpenACCGangArg(SourceLocation GangLoc);
/// Parses a 'condition' expr, ensuring it results in a
ExprResult ParseOpenACCConditionExpr();

Expand Down
36 changes: 34 additions & 2 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ class SemaOpenACC : public SemaBase {
Expr *ConditionExpr;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails> Details =
std::monostate{};
struct IntExprDetails {
SmallVector<Expr *> IntExprs;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails>
Details = std::monostate{};

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand Down Expand Up @@ -87,6 +92,22 @@ class SemaOpenACC : public SemaBase {
return std::get<ConditionDetails>(Details).ConditionExpr;
}

unsigned getNumIntExprs() const {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
return std::get<IntExprDetails>(Details).IntExprs.size();
}

ArrayRef<Expr *> getIntExprs() {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
return std::get<IntExprDetails>(Details).IntExprs;
}

ArrayRef<Expr *> getIntExprs() const {
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand All @@ -109,6 +130,12 @@ class SemaOpenACC : public SemaBase {

Details = ConditionDetails{ConditionExpr};
}

void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
}
};

SemaOpenACC(Sema &S);
Expand Down Expand Up @@ -148,6 +175,11 @@ class SemaOpenACC : public SemaBase {
/// Called after the directive has been completely parsed, including the
/// declaration group or associated statement.
DeclGroupRef ActOnEndDeclDirective();

/// Called when encountering an 'int-expr' for OpenACC, and manages
/// conversions and diagnostics to 'int'.
ExprResult ActOnIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc, Expr *IntExpr);
};

} // namespace clang
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,27 @@ OpenACCClause::child_range OpenACCClause::children() {
return child_range(child_iterator(), child_iterator());
}

OpenACCNumWorkersClause::OpenACCNumWorkersClause(SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *IntExpr,
SourceLocation EndLoc)
: OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::NumWorkers, BeginLoc,
LParenLoc, IntExpr, EndLoc) {
assert((!IntExpr || IntExpr->isInstantiationDependent() ||
IntExpr->getType()->isIntegerType()) &&
"Condition expression type not scalar/dependent");
}

OpenACCNumWorkersClause *
OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCNumWorkersClause),
alignof(OpenACCNumWorkersClause));
return new (Mem)
OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand All @@ -98,3 +119,8 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
if (const Expr *CondExpr = C.getConditionExpr())
OS << "(" << CondExpr << ")";
}

void OpenACCClausePrinter::VisitNumWorkersClause(
const OpenACCNumWorkersClause &C) {
OS << "num_workers(" << C.getIntExpr() << ")";
}
7 changes: 7 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,13 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
if (Clause.hasConditionExpr())
Profiler.VisitStmt(Clause.getConditionExpr());
}

void OpenACCClauseProfiler::VisitNumWorkersClause(
const OpenACCNumWorkersClause &Clause) {
assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");
Profiler.VisitStmt(Clause.getIntExpr());
}

} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
break;
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::NumWorkers:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
OS << " clause";
Expand Down
57 changes: 40 additions & 17 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,16 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
return Clauses;
}

ExprResult Parser::ParseOpenACCIntExpr() {
// FIXME: this is required to be an integer expression (or dependent), so we
// should ensure that is the case by passing this to SEMA here.
return getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
OpenACCClauseKind CK,
SourceLocation Loc) {
ExprResult ER =
getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());

if (!ER.isUsable())
return ER;

return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
}

bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
Expand Down Expand Up @@ -739,7 +745,7 @@ bool Parser::ParseOpenACCSizeExprList() {
/// [num:]int-expr
/// dim:int-expr
/// static:size-expr
bool Parser::ParseOpenACCGangArg() {
bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {

if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Static, getCurToken()) &&
NextToken().is(tok::colon)) {
Expand All @@ -753,7 +759,9 @@ bool Parser::ParseOpenACCGangArg() {
NextToken().is(tok::colon)) {
ConsumeToken();
ConsumeToken();
return ParseOpenACCIntExpr().isInvalid();
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
OpenACCClauseKind::Gang, GangLoc)
.isInvalid();
}

if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
Expand All @@ -763,11 +771,13 @@ bool Parser::ParseOpenACCGangArg() {
// Fallthrough to the 'int-expr' handling for when 'num' is omitted.
}
// This is just the 'num' case where 'num' is optional.
return ParseOpenACCIntExpr().isInvalid();
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
OpenACCClauseKind::Gang, GangLoc)
.isInvalid();
}

bool Parser::ParseOpenACCGangArgList() {
if (ParseOpenACCGangArg()) {
bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
if (ParseOpenACCGangArg(GangLoc)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
Expand All @@ -776,7 +786,7 @@ bool Parser::ParseOpenACCGangArgList() {
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
ExpectAndConsume(tok::comma);

if (ParseOpenACCGangArg()) {
if (ParseOpenACCGangArg(GangLoc)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
Expand Down Expand Up @@ -941,11 +951,18 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::DeviceNum:
case OpenACCClauseKind::DefaultAsync:
case OpenACCClauseKind::VectorLength: {
ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
ClauseKind, ClauseLoc);
if (IntExpr.isInvalid()) {
Parens.skipToEnd();
return OpenACCCanContinue();
}

// TODO OpenACC: as we implement the 'rest' of the above, this 'if' should
// be removed leaving just the 'setIntExprDetails'.
if (ClauseKind == OpenACCClauseKind::NumWorkers)
ParsedClause.setIntExprDetails(IntExpr.get());

break;
}
case OpenACCClauseKind::DType:
Expand Down Expand Up @@ -998,7 +1015,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
? OpenACCSpecialTokenKind::Length
: OpenACCSpecialTokenKind::Num,
ClauseKind);
ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
ClauseKind, ClauseLoc);
if (IntExpr.isInvalid()) {
Parens.skipToEnd();
return OpenACCCanContinue();
Expand All @@ -1014,13 +1032,14 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
break;
}
case OpenACCClauseKind::Gang:
if (ParseOpenACCGangArgList()) {
if (ParseOpenACCGangArgList(ClauseLoc)) {
Parens.skipToEnd();
return OpenACCCanContinue();
}
break;
case OpenACCClauseKind::Wait:
if (ParseOpenACCWaitArgument()) {
if (ParseOpenACCWaitArgument(ClauseLoc,
/*IsDirective=*/false)) {
Parens.skipToEnd();
return OpenACCCanContinue();
}
Expand Down Expand Up @@ -1052,7 +1071,7 @@ ExprResult Parser::ParseOpenACCAsyncArgument() {
/// In this section and throughout the specification, the term wait-argument
/// means:
/// [ devnum : int-expr : ] [ queues : ] async-argument-list
bool Parser::ParseOpenACCWaitArgument() {
bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
// [devnum : int-expr : ]
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::DevNum, Tok) &&
NextToken().is(tok::colon)) {
Expand All @@ -1061,7 +1080,11 @@ bool Parser::ParseOpenACCWaitArgument() {
// Consume colon.
ConsumeToken();

ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(
IsDirective ? OpenACCDirectiveKind::Wait
: OpenACCDirectiveKind::Invalid,
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
Loc);
if (IntExpr.isInvalid())
return true;

Expand Down Expand Up @@ -1245,7 +1268,7 @@ Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
break;
case OpenACCDirectiveKind::Wait:
// OpenACC has an optional paren-wrapped 'wait-argument'.
if (ParseOpenACCWaitArgument())
if (ParseOpenACCWaitArgument(StartLoc, /*IsDirective=*/true))
T.skipToEnd();
else
T.consumeClose();
Expand Down
Loading