Skip to content

[NFC][OpenACC] Refactor clause emission- #140586

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,21 @@ class CIRGenFunction : public CIRGenTypeCache {
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
const Stmt *loopStmt);

// Emits the given OpenACC `clauses` for a directive that is represented by a
// single operation `op`. `dirKind` and `dirLoc` identify the directive the
// clauses belong to. Only the declaration appears here; the template is
// defined and explicitly instantiated out of line for each supported `Op`.
template <typename Op>
void emitOpenACCClauses(Op &op, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc,
ArrayRef<const OpenACCClause *> clauses);
// Overload that emits `clauses` for a compute operation `op` together with an
// associated loop operation `loopOp`.
//
// The second template argument doesn't need to be a template, since it should
// always be an mlir::acc::LoopOp. It is a template parameter only because this
// function is a template anyway, and doing so lets this header avoid including
// the OpenACC MLIR headers. We count on explicit instantiation (and the linker
// failures that result from anything else) to ensure this isn't misused; it
// is only called from 1 place, and instantiated 3x.
template <typename ComputeOp, typename LoopOp>
void emitOpenACCClauses(ComputeOp &op, LoopOp &loopOp,
OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
ArrayRef<const OpenACCClause *> clauses);

public:
mlir::LogicalResult
emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@

#include <type_traits>

#include "CIRGenFunction.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "llvm/ADT/TypeSwitch.h"
namespace clang {

using namespace clang;
using namespace clang::CIRGen;

namespace {
// Simple type-trait to see if the first template arg is one of the list, so we
// can tell whether to `if-constexpr` a bunch of stuff.
template <typename ToTest, typename T, typename... Tys>
Expand All @@ -32,77 +38,10 @@ template <typename CompOpTy> struct CombinedConstructClauseInfo {
ComputeOpTy computeOp;
mlir::acc::LoopOp loopOp;
};

// Compile-time trait used to detect a CombinedConstructClauseInfo. The primary
// template yields false for every type.
template <typename ToTest> constexpr bool isCombinedType = false;
// Partial specialization: yields true exactly when the tested type is some
// CombinedConstructClauseInfo<T>.
template <typename T>
constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true;

namespace {
// Result of getDataOperandInfo: what is known about one operand of a data
// clause. Field order matters, since the struct is brace-initialized
// positionally.
struct DataOperandInfo {
// MLIR location computed from the begin location of the operand expression.
mlir::Location beginLoc;
// Pointer to the referenced variable; left value-initialized on the
// not-yet-implemented paths.
mlir::Value varValue;
// Name of the referenced variable declaration; empty on the
// not-yet-implemented paths.
llvm::StringRef name;
};

// Emits `intExpr` as a scalar and wraps the result in an
// UnrealizedConversionCastOp whose result is an MLIR integer type with the
// same bit width and signedness as the expression's Clang type. Returns the
// single result of that cast.
inline mlir::Value emitOpenACCIntExpr(CIRGen::CIRGenFunction &cgf,
                                      CIRGen::CIRGenBuilderTy &builder,
                                      const Expr *intExpr) {
  // The scalar is emitted first, ahead of the cast that consumes it.
  mlir::Value scalar = cgf.emitScalarExpr(intExpr);
  mlir::Location castLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());

  // Mirror the Clang type's width and signedness in the MLIR integer type.
  const auto clangTy = intExpr->getType();
  const auto bitWidth = cgf.getContext().getIntWidth(clangTy);
  auto signedness = mlir::IntegerType::SignednessSemantics::Unsigned;
  if (clangTy->isSignedIntegerOrEnumerationType())
    signedness = mlir::IntegerType::SignednessSemantics::Signed;

  mlir::IntegerType castTy =
      mlir::IntegerType::get(&cgf.getMLIRContext(), bitWidth, signedness);

  return builder
      .create<mlir::UnrealizedConversionCastOp>(castLoc, castTy, scalar)
      .getResult(0);
}

// Collects what is needed to emit a data operation for one data-clause
// operand: a location, the pointer to the variable, and the variable's name.
// Operand forms that are not implemented yet report an NYI error and yield an
// info that carries only a location.
inline DataOperandInfo getDataOperandInfo(CIRGen::CIRGenFunction &cgf,
                                          CIRGen::CIRGenBuilderTy &builder,
                                          OpenACCDirectiveKind dk,
                                          const Expr *e) {
  // TODO: OpenACC: Cache was different enough as to need a separate
  // `ActOnCacheVar`, so some investigation is needed before this can be
  // implemented for cache.
  if (dk == OpenACCDirectiveKind::Cache) {
    cgf.cgm.errorNYI(e->getSourceRange(),
                     "OpenACC data operand for 'cache' directive");
    return DataOperandInfo{cgf.cgm.getLoc(e->getBeginLoc()), {}, {}};
  }

  // Everything below works on the operand with parens/implicit casts removed.
  const Expr *operand = e->IgnoreParenImpCasts();
  mlir::Location operandLoc = cgf.cgm.getLoc(operand->getBeginLoc());

  // TODO: OpenACC: Assemble the list of bounds.
  const bool isArrayForm =
      isa<ArraySectionExpr>(operand) || isa<ArraySubscriptExpr>(operand);
  if (isArrayForm) {
    cgf.cgm.errorNYI(operand->getSourceRange(),
                     "OpenACC data clause array subscript/section");
    return DataOperandInfo{operandLoc, {}, {}};
  }

  // TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
  if (isa<MemberExpr>(operand)) {
    cgf.cgm.errorNYI(operand->getSourceRange(),
                     "OpenACC Data clause member expr");
    return DataOperandInfo{operandLoc, {}, {}};
  }

  // Sema only lets four shapes reach this point: array subscript, array
  // section, member expr, or a DRE to a var decl (or one of the first three
  // wrapping a var decl). The other three returned above, so the casts below
  // are expected to hold.
  const auto *declRef = cast<DeclRefExpr>(operand);
  const auto *varDecl =
      cast<VarDecl>(declRef->getFoundDecl()->getCanonicalDecl());
  mlir::Value varPtr = cgf.emitDeclRefLValue(declRef).getPointer();
  return DataOperandInfo{operandLoc, varPtr, varDecl->getName()};
}
} // namespace

template <typename OpTy>
class OpenACCClauseCIREmitter final
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
Expand All @@ -127,6 +66,10 @@ class OpenACCClauseCIREmitter final
// Keep track of the data operands so that we can update their async clauses.
llvm::SmallVector<mlir::Operation *> dataOperands;

// Reports clause `c` as not yet implemented, anchored at the clause's source
// range and tagged with its clause kind.
void clauseNotImplemented(const OpenACCClause &c) {
  const auto clauseRange = c.getSourceRange();
  const auto clauseKind = c.getClauseKind();
  cgf.cgm.errorNYI(clauseRange, "OpenACC Clause", clauseKind);
}

void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
lastDeviceTypeValues.clear();

Expand All @@ -137,12 +80,19 @@ class OpenACCClauseCIREmitter final
});
}

// Emits an NYI error for clause `c`, using the clause's source range as the
// location and its clause kind as the reported name.
void clauseNotImplemented(const OpenACCClause &c) {
  const auto nyiRange = c.getSourceRange();
  const auto nyiKind = c.getClauseKind();
  cgf.cgm.errorNYI(nyiRange, "OpenACC Clause", nyiKind);
}
mlir::Value emitIntExpr(const Expr *intExpr) {
mlir::Value expr = cgf.emitScalarExpr(intExpr);
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());

mlir::Value emitOpenACCIntExpr(const Expr *intExpr) {
return clang::emitOpenACCIntExpr(cgf, builder, intExpr);
mlir::IntegerType targetType = mlir::IntegerType::get(
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
intExpr->getType()->isSignedIntegerOrEnumerationType()
? mlir::IntegerType::SignednessSemantics::Signed
: mlir::IntegerType::SignednessSemantics::Unsigned);

auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
exprLoc, targetType, expr);
return conversionOp.getResult(0);
}

// 'condition' as an OpenACC grammar production is used for 'if' and (some
Expand Down Expand Up @@ -218,11 +168,56 @@ class OpenACCClauseCIREmitter final
computeEmitter.Visit(&c);
}

// Result of getDataOperandInfo: what is known about one operand of a data
// clause. Field order matters, since the struct is brace-initialized
// positionally.
struct DataOperandInfo {
// MLIR location computed from the begin location of the operand expression.
mlir::Location beginLoc;
// Pointer to the referenced variable; left value-initialized on the
// not-yet-implemented paths.
mlir::Value varValue;
// Name of the referenced variable declaration; empty on the
// not-yet-implemented paths.
llvm::StringRef name;
};

// Gathers the location, variable pointer, and variable name for one operand
// of a data clause, so the data operations can be emitted from them. Operand
// forms that are not implemented yet report an NYI error through `cgf` and
// return an info holding only a location.
inline DataOperandInfo getDataOperandInfo(OpenACCDirectiveKind dk,
                                          const Expr *e) {
  // TODO: OpenACC: Cache was different enough as to need a separate
  // `ActOnCacheVar`, so some investigation is needed before this can be
  // implemented for cache.
  if (dk == OpenACCDirectiveKind::Cache) {
    cgf.cgm.errorNYI(e->getSourceRange(),
                     "OpenACC data operand for 'cache' directive");
    return DataOperandInfo{cgf.cgm.getLoc(e->getBeginLoc()), {}, {}};
  }

  // From here on, look through parentheses and implicit casts.
  const Expr *strippedExpr = e->IgnoreParenImpCasts();
  mlir::Location strippedLoc = cgf.cgm.getLoc(strippedExpr->getBeginLoc());

  // TODO: OpenACC: Assemble the list of bounds.
  const bool isSubscriptOrSection = isa<ArraySectionExpr>(strippedExpr) ||
                                    isa<ArraySubscriptExpr>(strippedExpr);
  if (isSubscriptOrSection) {
    cgf.cgm.errorNYI(strippedExpr->getSourceRange(),
                     "OpenACC data clause array subscript/section");
    return DataOperandInfo{strippedLoc, {}, {}};
  }

  // TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
  if (isa<MemberExpr>(strippedExpr)) {
    cgf.cgm.errorNYI(strippedExpr->getSourceRange(),
                     "OpenACC Data clause member expr");
    return DataOperandInfo{strippedLoc, {}, {}};
  }

  // Sema guarantees only four shapes arrive here: array subscript, array
  // section, member expr, or a DRE to a var decl (or one of the first three
  // wrapping a var decl). With the first three handled above, what remains is
  // assumed to be the DRE case.
  const auto *declRef = cast<DeclRefExpr>(strippedExpr);
  const auto *varDecl =
      cast<VarDecl>(declRef->getFoundDecl()->getCanonicalDecl());
  mlir::Value varPtr = cgf.emitDeclRefLValue(declRef).getPointer();
  return DataOperandInfo{strippedLoc, varPtr, varDecl->getName()};
}

template <typename BeforeOpTy, typename AfterOpTy>
void addDataOperand(const Expr *varOperand, mlir::acc::DataClause dataClause,
bool structured, bool implicit) {
DataOperandInfo opInfo =
getDataOperandInfo(cgf, builder, dirKind, varOperand);
DataOperandInfo opInfo = getDataOperandInfo(dirKind, varOperand);
mlir::ValueRange bounds;

// TODO: OpenACC: we should comprehend the 'modifier-list' here for the data
Expand Down Expand Up @@ -394,7 +389,7 @@ class OpenACCClauseCIREmitter final
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
mlir::acc::KernelsOp>) {
operation.addNumWorkersOperand(builder.getContext(),
emitOpenACCIntExpr(clause.getIntExpr()),
emitIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
} else if constexpr (isCombinedType<OpTy>) {
applyToComputeOp(clause);
Expand All @@ -407,7 +402,7 @@ class OpenACCClauseCIREmitter final
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
mlir::acc::KernelsOp>) {
operation.addVectorLengthOperand(builder.getContext(),
emitOpenACCIntExpr(clause.getIntExpr()),
emitIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
} else if constexpr (isCombinedType<OpTy>) {
applyToComputeOp(clause);
Expand All @@ -432,7 +427,7 @@ class OpenACCClauseCIREmitter final
mlir::OpBuilder::InsertionGuard guardCase(builder);
if (!dataOperands.empty())
builder.setInsertionPoint(dataOperands.front());
intExpr = emitOpenACCIntExpr(clause.getIntExpr());
intExpr = emitIntExpr(clause.getIntExpr());
}
operation.addAsyncOperand(builder.getContext(), intExpr,
lastDeviceTypeValues);
Expand All @@ -444,7 +439,7 @@ class OpenACCClauseCIREmitter final
operation.setAsync(true);
else
operation.getAsyncOperandMutable().append(
emitOpenACCIntExpr(clause.getIntExpr()));
emitIntExpr(clause.getIntExpr()));
} else if constexpr (isCombinedType<OpTy>) {
applyToComputeOp(clause);
} else {
Expand Down Expand Up @@ -499,8 +494,7 @@ class OpenACCClauseCIREmitter final
void VisitDeviceNumClause(const OpenACCDeviceNumClause &clause) {
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
mlir::acc::SetOp>) {
operation.getDeviceNumMutable().append(
emitOpenACCIntExpr(clause.getIntExpr()));
operation.getDeviceNumMutable().append(emitIntExpr(clause.getIntExpr()));
} else {
llvm_unreachable(
"init, shutdown, set, are only valid device_num constructs");
Expand All @@ -512,7 +506,7 @@ class OpenACCClauseCIREmitter final
mlir::acc::KernelsOp>) {
llvm::SmallVector<mlir::Value> values;
for (const Expr *E : clause.getIntExprs())
values.push_back(emitOpenACCIntExpr(E));
values.push_back(emitIntExpr(E));

operation.addNumGangsOperands(builder.getContext(), values,
lastDeviceTypeValues);
Expand All @@ -531,9 +525,9 @@ class OpenACCClauseCIREmitter final
} else {
llvm::SmallVector<mlir::Value> values;
if (clause.hasDevNumExpr())
values.push_back(emitOpenACCIntExpr(clause.getDevNumExpr()));
values.push_back(emitIntExpr(clause.getDevNumExpr()));
for (const Expr *E : clause.getQueueIdExprs())
values.push_back(emitOpenACCIntExpr(E));
values.push_back(emitIntExpr(E));
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
values, lastDeviceTypeValues);
}
Expand All @@ -549,7 +543,7 @@ class OpenACCClauseCIREmitter final
void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
operation.getDefaultAsyncMutable().append(
emitOpenACCIntExpr(clause.getIntExpr()));
emitIntExpr(clause.getIntExpr()));
} else {
llvm_unreachable("set, is only valid device_num constructs");
}
Expand Down Expand Up @@ -639,7 +633,7 @@ class OpenACCClauseCIREmitter final
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
if (clause.hasIntExpr())
operation.addWorkerNumOperand(builder.getContext(),
emitOpenACCIntExpr(clause.getIntExpr()),
emitIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
else
operation.addEmptyWorker(builder.getContext(), lastDeviceTypeValues);
Expand All @@ -657,7 +651,7 @@ class OpenACCClauseCIREmitter final
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
if (clause.hasIntExpr())
operation.addVectorOperand(builder.getContext(),
emitOpenACCIntExpr(clause.getIntExpr()),
emitIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
else
operation.addEmptyVector(builder.getContext(), lastDeviceTypeValues);
Expand Down Expand Up @@ -693,7 +687,7 @@ class OpenACCClauseCIREmitter final
} else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
values.push_back(createConstantInt(exprLoc, 64, -1));
} else {
values.push_back(emitOpenACCIntExpr(expr));
values.push_back(emitIntExpr(expr));
}
}

Expand Down Expand Up @@ -728,5 +722,54 @@ auto makeClauseEmitter(OpTy &op, CIRGen::CIRGenFunction &cgf,
OpenACCDirectiveKind dirKind, SourceLocation dirLoc) {
return OpenACCClauseCIREmitter<OpTy>(op, cgf, builder, dirKind, dirLoc);
}
} // namespace

// Emits `clauses` for a directive that is represented by the single operation
// `op`. The builder's insertion point is restored when the function returns.
template <typename Op>
void CIRGenFunction::emitOpenACCClauses(
    Op &op, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
    ArrayRef<const OpenACCClause *> clauses) {
  mlir::OpBuilder::InsertionGuard restoreInsertionPoint(builder);

  // Every new expression created while handling the clauses has to come
  // before the operation, so move the insertion point ahead of 'op' first.
  builder.setInsertionPoint(op);

  auto clauseEmitter = makeClauseEmitter(op, *this, builder, dirKind, dirLoc);
  clauseEmitter.emitClauses(clauses);
}

// Explicit instantiations of the single-operation emitOpenACCClauses overload,
// one per operation type listed below. Despite the macro's name, each
// expansion is an explicit instantiation (`template void ...`), not a
// specialization. The macro is undefined again once the list ends.
#define EXPL_SPEC(N) \
template void CIRGenFunction::emitOpenACCClauses<N>( \
N &, OpenACCDirectiveKind, SourceLocation, \
ArrayRef<const OpenACCClause *>);
EXPL_SPEC(mlir::acc::ParallelOp)
EXPL_SPEC(mlir::acc::SerialOp)
EXPL_SPEC(mlir::acc::KernelsOp)
EXPL_SPEC(mlir::acc::LoopOp)
EXPL_SPEC(mlir::acc::DataOp)
EXPL_SPEC(mlir::acc::InitOp)
EXPL_SPEC(mlir::acc::ShutdownOp)
EXPL_SPEC(mlir::acc::SetOp)
EXPL_SPEC(mlir::acc::WaitOp)
#undef EXPL_SPEC

template <typename ComputeOp, typename LoopOp>
void CIRGenFunction::emitOpenACCClauses(
ComputeOp &op, LoopOp &loopOp, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc, ArrayRef<const OpenACCClause *> clauses) {
static_assert(std::is_same_v<mlir::acc::LoopOp, LoopOp>);

CombinedConstructClauseInfo<ComputeOp> inf{op, loopOp};
// We cannot set the insertion point here and do so in the emitter, but make
// sure we reset it with the 'guard' anyway.
mlir::OpBuilder::InsertionGuard guardCase(builder);
makeClauseEmitter(inf, *this, builder, dirKind, dirLoc).emitClauses(clauses);
}

// Helper macro: expands to an explicit instantiation of the two-operation
// emitOpenACCClauses overload for compute operation type N paired with
// mlir::acc::LoopOp.
#define EXPL_SPEC(N) \
template void CIRGenFunction::emitOpenACCClauses<N, mlir::acc::LoopOp>( \
N &, mlir::acc::LoopOp &, OpenACCDirectiveKind, SourceLocation, \
ArrayRef<const OpenACCClause *>);

} // namespace clang
EXPL_SPEC(mlir::acc::ParallelOp)
EXPL_SPEC(mlir::acc::SerialOp)
EXPL_SPEC(mlir::acc::KernelsOp)
#undef EXPL_SPEC
Loading