Skip to content

[OpenACC][CIR] Implement 'device_type' clause lowering for 'init'/'shutdown' #135102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,16 @@ class CIRGenFunction : public CIRGenTypeCache {
private:
template <typename Op>
mlir::LogicalResult
emitOpenACCOp(mlir::Location start,
emitOpenACCOp(mlir::Location start, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc,
llvm::ArrayRef<const OpenACCClause *> clauses);
// Function to do the basic implementation of an operation with an Associated
// Statement. Models AssociatedStmtConstruct.
template <typename Op, typename TermOp>
mlir::LogicalResult
emitOpenACCOpAssociatedStmt(mlir::Location start, mlir::Location end,
llvm::ArrayRef<const OpenACCClause *> clauses,
const Stmt *associatedStmt);
mlir::LogicalResult emitOpenACCOpAssociatedStmt(
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
const Stmt *associatedStmt);

public:
mlir::LogicalResult
Expand Down
126 changes: 107 additions & 19 deletions clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// Emit OpenACC Stmt nodes as CIR code.
//
//===----------------------------------------------------------------------===//
#include <type_traits>

#include "CIRGenBuilder.h"
#include "CIRGenFunction.h"
Expand All @@ -23,22 +24,39 @@ using namespace cir;
using namespace mlir::acc;

namespace {
// Compile-time predicate: true when `ToTest` is exactly the same type as `T`
// or as any member of `Tys`. Intended for gating `if constexpr` branches on
// the concrete type being handled.
template <typename ToTest, typename T, typename... Tys>
constexpr bool isOneOfTypes =
    (std::is_same_v<ToTest, T> || ... || std::is_same_v<ToTest, Tys>);

class OpenACCClauseCIREmitter final
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
CIRGenModule &cgm;
// This is necessary since a few of the clauses emit differently based on the
// directive kind they are attached to.
OpenACCDirectiveKind dirKind;
SourceLocation dirLoc;

// Clause-derived state that is applied to the operation as attributes.
struct AttributeData {
  // Value for the 'default' attribute, which is added on 'data' and
  // 'compute'/etc constructs as a 'default-attr'. Empty until a value is
  // recorded.
  std::optional<ClauseDefaultValue> defaultVal;
  // Architectures to emit for directives that list their device types as
  // attributes (init/shutdown/etc), kept in the order they were recorded.
  llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs;
} attrData;

// Reports the given clause as not yet implemented: forwards its source range
// and clause kind to CIRGenModule::errorNYI under the "OpenACC Clause" label.
void clauseNotImplemented(const OpenACCClause &c) {
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
}

public:
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc)
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}

void VisitClause(const OpenACCClause &clause) {
clauseNotImplemented(clause);
Expand All @@ -57,31 +75,92 @@ class OpenACCClauseCIREmitter final
}
}

// Translates one 'device_type' architecture argument into its
// mlir::acc::DeviceType value. Names are matched without regard to case.
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
  using AccDeviceType = mlir::acc::DeviceType;

  // The '*' architecture carries no identifier-info at all, only a nullptr.
  if (ii == nullptr)
    return AccDeviceType::Star;

  // NOTE(review): there is no Default case here; presumably earlier checking
  // guarantees that only the names below reach this point -- confirm.
  return llvm::StringSwitch<AccDeviceType>(ii->getName())
      .CaseLower("default", AccDeviceType::Default)
      .CaseLower("host", AccDeviceType::Host)
      .CaseLower("multicore", AccDeviceType::Multicore)
      .CasesLower("nvidia", "acc_device_nvidia", AccDeviceType::Nvidia)
      .CaseLower("radeon", AccDeviceType::Radeon);
}

void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {

switch (dirKind) {
case OpenACCDirectiveKind::Init:
case OpenACCDirectiveKind::Shutdown: {
// Device type has a list that is either a 'star' (emitted as 'star'),
// or an identifer list, all of which get added for attributes.

for (const DeviceTypeArgument &arg : clause.getArchitectures())
attrData.deviceTypeArchs.push_back(decodeDeviceType(arg.first));
break;
}
default:
return clauseNotImplemented(clause);
}
}

// Apply any of the clauses that resulted in an 'attribute'.
template <typename Op> void applyAttributes(Op &op) {
if (attrData.defaultVal.has_value())
op.setDefaultAttr(*attrData.defaultVal);
template <typename Op>
void applyAttributes(CIRGenBuilderTy &builder, Op &op) {

if (attrData.defaultVal.has_value()) {
// FIXME: OpenACC: as we implement this for other directive kinds, we have
// to expand this list.
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
// operations listed in the rest of the arguments.
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right that ParallelOp is acc::ParallelOp?

I was confused about what was happening here because I didn't read the comment where isOneOfTypes was defined. A brief comment here explaining what you're looking for ("Is Op one of the expected types?") would be helpful.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, ParallelOp is mlir::acc::ParallelOp. We have a using namespace mlir::acc above. I'll add the comment.

op.setDefaultAttr(*attrData.defaultVal);
else
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind);
}

if (!attrData.deviceTypeArchs.empty()) {
// FIXME: OpenACC: as we implement this for other directive kinds, we have
// to expand this list, or more likely, have a 'noop' branch as most other
// uses of this apply to the operands instead.
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
// operations listed in the rest of the arguments.
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
llvm::SmallVector<mlir::Attribute> deviceTypes;
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs)
deviceTypes.push_back(
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT));

op.setDeviceTypesAttr(
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
} else {
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really "NYI"? It feels more like we shouldn't get here. In either case, since the condition is static, can this be a static compile error?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, there are entries for Device_type that aren't implemented yet, the list on 126 is incomplete. Eventually, this will become a no-op/unreachable.

It can't be a static_assert or anything because the condition on line 122 is not constexpr. So we would still have to evaluate the static_assert in every other case.

dirKind);
}
}
}
};

} // namespace

template <typename Op, typename TermOp>
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
mlir::Location start, mlir::Location end,
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
const Stmt *associatedStmt) {
mlir::LogicalResult res = mlir::success();

llvm::SmallVector<mlir::Type> retTy;
llvm::SmallVector<mlir::Value> operands;

// Clause-emitter must be here because it might modify operands.
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
clauseEmitter.VisitClauseList(clauses);

auto op = builder.create<Op>(start, retTy, operands);

// Apply the attributes derived from the clauses.
clauseEmitter.applyAttributes(op);
clauseEmitter.applyAttributes(builder, op);

mlir::Block &block = op.getRegion().emplaceBlock();
mlir::OpBuilder::InsertionGuard guardCase(builder);
Expand All @@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
}

template <typename Op>
mlir::LogicalResult
CIRGenFunction::emitOpenACCOp(mlir::Location start,
llvm::ArrayRef<const OpenACCClause *> clauses) {
mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
llvm::ArrayRef<const OpenACCClause *> clauses) {
mlir::LogicalResult res = mlir::success();

llvm::SmallVector<mlir::Type> retTy;
llvm::SmallVector<mlir::Value> operands;

// Clause-emitter must be here because it might modify operands.
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
clauseEmitter.VisitClauseList(clauses);

builder.create<Op>(start, retTy, operands);
auto op = builder.create<Op>(start, retTy, operands);
// Apply the attributes derived from the clauses.
clauseEmitter.applyAttributes(builder, op);
return res;
}

Expand All @@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
switch (s.getDirectiveKind()) {
case OpenACCDirectiveKind::Parallel:
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
start, end, s.clauses(), s.getStructuredBlock());
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
s.getStructuredBlock());
case OpenACCDirectiveKind::Serial:
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
start, end, s.clauses(), s.getStructuredBlock());
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
s.getStructuredBlock());
case OpenACCDirectiveKind::Kernels:
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
start, end, s.clauses(), s.getStructuredBlock());
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
s.getStructuredBlock());
default:
llvm_unreachable("invalid compute construct kind");
}
Expand All @@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
mlir::Location end = getLoc(s.getSourceRange().getEnd());

return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
start, end, s.clauses(), s.getStructuredBlock());
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
s.getStructuredBlock());
}

mlir::LogicalResult
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
mlir::Location start = getLoc(s.getSourceRange().getEnd());
return emitOpenACCOp<InitOp>(start, s.clauses());
return emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
s.clauses());
}

mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
const OpenACCShutdownConstruct &s) {
mlir::Location start = getLoc(s.getSourceRange().getEnd());
return emitOpenACCOp<ShutdownOp>(start, s.clauses());
return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
s.getDirectiveLoc(), s.clauses());
}

mlir::LogicalResult
Expand Down
13 changes: 13 additions & 0 deletions clang/test/CIR/CodeGenOpenACC/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,17 @@ void acc_init(void) {
// CHECK: cir.func @acc_init() {
#pragma acc init
// CHECK-NEXT: acc.init loc(#{{[a-zA-Z0-9]+}}){{$}}

#pragma acc init device_type(*)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you have this?

#pragma acc init device_type(*) device_type(nvidia)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will faithfully represent them, like this:

  // CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>, #acc.device_type<nvidia>]}

There isn't really anything that prohibits it by standard, so it seems reasonable to do.

// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>]}
#pragma acc init device_type(nvidia)
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
#pragma acc init device_type(host, multicore)
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
#pragma acc init device_type(NVIDIA)
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
#pragma acc init device_type(HoSt, MuLtIcORe)
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
#pragma acc init device_type(HoSt) device_type(MuLtIcORe)
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
}
13 changes: 13 additions & 0 deletions clang/test/CIR/CodeGenOpenACC/shutdown.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,17 @@ void acc_shutdown(void) {
// CHECK: cir.func @acc_shutdown() {
#pragma acc shutdown
// CHECK-NEXT: acc.shutdown loc(#{{[a-zA-Z0-9]+}}){{$}}

#pragma acc shutdown device_type(*)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<star>]}
#pragma acc shutdown device_type(nvidia)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
#pragma acc shutdown device_type(host, multicore)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
#pragma acc shutdown device_type(NVIDIA)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
#pragma acc shutdown device_type(HoSt, MuLtIcORe)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
#pragma acc shutdown device_type(HoSt) device_type(MuLtIcORe)
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
}