Skip to content

Commit 74c2b41

Browse files
authored
[OpenACC][CIR] Implement 'device_type' clause lowering for 'init'/'sh… (#135102)
…utdown' This patch emits the lowering for 'device_type' on an 'init' or 'shutdown'. This one is fairly unique, as these directives have it as an attribute, rather than as a component of the individual operands, like the rest of the constructs. So this patch implements the lowering as an attribute. In order to do tis, a few refactorings had to happen: First, the 'emitOpenACCOp' functions needed to pick up th edirective kind/location so that the NYI diagnostic could be reasonable. Second, and most impactful, the `applyAttributes` function ends up needing to encode some of the appertainment rules, thanks to the way the OpenACC-MLIR operands get their attributes attached. Since they each use a special function (rather than something that can be legalized at runtime), the forms of 'setDefaultAttr' is only valid for some ops. SO this patch uses some `if constexpr` and a small type-trait to help legalize these.
1 parent dcb9078 commit 74c2b41

File tree

4 files changed

+139
-24
lines changed

4 files changed

+139
-24
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,15 +604,16 @@ class CIRGenFunction : public CIRGenTypeCache {
604604
private:
605605
template <typename Op>
606606
mlir::LogicalResult
607-
emitOpenACCOp(mlir::Location start,
607+
emitOpenACCOp(mlir::Location start, OpenACCDirectiveKind dirKind,
608+
SourceLocation dirLoc,
608609
llvm::ArrayRef<const OpenACCClause *> clauses);
609610
// Function to do the basic implementation of an operation with an Associated
610611
// Statement. Models AssociatedStmtConstruct.
611612
template <typename Op, typename TermOp>
612-
mlir::LogicalResult
613-
emitOpenACCOpAssociatedStmt(mlir::Location start, mlir::Location end,
614-
llvm::ArrayRef<const OpenACCClause *> clauses,
615-
const Stmt *associatedStmt);
613+
mlir::LogicalResult emitOpenACCOpAssociatedStmt(
614+
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
615+
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
616+
const Stmt *associatedStmt);
616617

617618
public:
618619
mlir::LogicalResult

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// Emit OpenACC Stmt nodes as CIR code.
1010
//
1111
//===----------------------------------------------------------------------===//
12+
#include <type_traits>
1213

1314
#include "CIRGenBuilder.h"
1415
#include "CIRGenFunction.h"
@@ -23,22 +24,39 @@ using namespace cir;
2324
using namespace mlir::acc;
2425

2526
namespace {
27+
// Simple type-trait to see if the first template arg is one of the list, so we
28+
// can tell whether to `if-constexpr` a bunch of stuff.
29+
template <typename ToTest, typename T, typename... Tys>
30+
constexpr bool isOneOfTypes =
31+
std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
32+
template <typename ToTest, typename T>
33+
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
34+
2635
class OpenACCClauseCIREmitter final
2736
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
2837
CIRGenModule &cgm;
38+
// This is necessary since a few of the clauses emit differently based on the
39+
// directive kind they are attached to.
40+
OpenACCDirectiveKind dirKind;
41+
SourceLocation dirLoc;
2942

3043
struct AttributeData {
3144
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
3245
// constructs as a 'default-attr'.
3346
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
47+
// For directives that have their device type architectures listed in
48+
// attributes (init/shutdown/etc), the list of architectures to be emitted.
49+
llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
3450
} attrData;
3551

3652
void clauseNotImplemented(const OpenACCClause &c) {
3753
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
3854
}
3955

4056
public:
41-
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
57+
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58+
SourceLocation dirLoc)
59+
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
4260

4361
void VisitClause(const OpenACCClause &clause) {
4462
clauseNotImplemented(clause);
@@ -57,31 +75,92 @@ class OpenACCClauseCIREmitter final
5775
}
5876
}
5977

78+
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
79+
// '*' case leaves no identifier-info, just a nullptr.
80+
if (!ii)
81+
return mlir::acc::DeviceType::Star;
82+
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
83+
.CaseLower("default", mlir::acc::DeviceType::Default)
84+
.CaseLower("host", mlir::acc::DeviceType::Host)
85+
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
86+
.CasesLower("nvidia", "acc_device_nvidia",
87+
mlir::acc::DeviceType::Nvidia)
88+
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
89+
}
90+
91+
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
92+
93+
switch (dirKind) {
94+
case OpenACCDirectiveKind::Init:
95+
case OpenACCDirectiveKind::Shutdown: {
96+
// Device type has a list that is either a 'star' (emitted as 'star'),
97+
// or an identifer list, all of which get added for attributes.
98+
99+
for (const DeviceTypeArgument &arg : clause.getArchitectures())
100+
attrData.deviceTypeArchs.push_back(decodeDeviceType(arg.first));
101+
break;
102+
}
103+
default:
104+
return clauseNotImplemented(clause);
105+
}
106+
}
107+
60108
// Apply any of the clauses that resulted in an 'attribute'.
61-
template <typename Op> void applyAttributes(Op &op) {
62-
if (attrData.defaultVal.has_value())
63-
op.setDefaultAttr(*attrData.defaultVal);
109+
template <typename Op>
110+
void applyAttributes(CIRGenBuilderTy &builder, Op &op) {
111+
112+
if (attrData.defaultVal.has_value()) {
113+
// FIXME: OpenACC: as we implement this for other directive kinds, we have
114+
// to expand this list.
115+
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
116+
// operations listed in the rest of the arguments.
117+
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
118+
op.setDefaultAttr(*attrData.defaultVal);
119+
else
120+
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind);
121+
}
122+
123+
if (!attrData.deviceTypeArchs.empty()) {
124+
// FIXME: OpenACC: as we implement this for other directive kinds, we have
125+
// to expand this list, or more likely, have a 'noop' branch as most other
126+
// uses of this apply to the operands instead.
127+
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
128+
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
129+
llvm::SmallVector<mlir::Attribute> deviceTypes;
130+
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs)
131+
deviceTypes.push_back(
132+
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT));
133+
134+
op.setDeviceTypesAttr(
135+
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
136+
} else {
137+
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ",
138+
dirKind);
139+
}
140+
}
64141
}
65142
};
143+
66144
} // namespace
67145

68146
template <typename Op, typename TermOp>
69147
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
70-
mlir::Location start, mlir::Location end,
71-
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
148+
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
149+
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
150+
const Stmt *associatedStmt) {
72151
mlir::LogicalResult res = mlir::success();
73152

74153
llvm::SmallVector<mlir::Type> retTy;
75154
llvm::SmallVector<mlir::Value> operands;
76155

77156
// Clause-emitter must be here because it might modify operands.
78-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
157+
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
79158
clauseEmitter.VisitClauseList(clauses);
80159

81160
auto op = builder.create<Op>(start, retTy, operands);
82161

83162
// Apply the attributes derived from the clauses.
84-
clauseEmitter.applyAttributes(op);
163+
clauseEmitter.applyAttributes(builder, op);
85164

86165
mlir::Block &block = op.getRegion().emplaceBlock();
87166
mlir::OpBuilder::InsertionGuard guardCase(builder);
@@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
95174
}
96175

97176
template <typename Op>
98-
mlir::LogicalResult
99-
CIRGenFunction::emitOpenACCOp(mlir::Location start,
100-
llvm::ArrayRef<const OpenACCClause *> clauses) {
177+
mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
178+
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
179+
llvm::ArrayRef<const OpenACCClause *> clauses) {
101180
mlir::LogicalResult res = mlir::success();
102181

103182
llvm::SmallVector<mlir::Type> retTy;
104183
llvm::SmallVector<mlir::Value> operands;
105184

106185
// Clause-emitter must be here because it might modify operands.
107-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
186+
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
108187
clauseEmitter.VisitClauseList(clauses);
109188

110-
builder.create<Op>(start, retTy, operands);
189+
auto op = builder.create<Op>(start, retTy, operands);
190+
// Apply the attributes derived from the clauses.
191+
clauseEmitter.applyAttributes(builder, op);
111192
return res;
112193
}
113194

@@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
119200
switch (s.getDirectiveKind()) {
120201
case OpenACCDirectiveKind::Parallel:
121202
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
122-
start, end, s.clauses(), s.getStructuredBlock());
203+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
204+
s.getStructuredBlock());
123205
case OpenACCDirectiveKind::Serial:
124206
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
125-
start, end, s.clauses(), s.getStructuredBlock());
207+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
208+
s.getStructuredBlock());
126209
case OpenACCDirectiveKind::Kernels:
127210
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
128-
start, end, s.clauses(), s.getStructuredBlock());
211+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
212+
s.getStructuredBlock());
129213
default:
130214
llvm_unreachable("invalid compute construct kind");
131215
}
@@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
137221
mlir::Location end = getLoc(s.getSourceRange().getEnd());
138222

139223
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
140-
start, end, s.clauses(), s.getStructuredBlock());
224+
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
225+
s.getStructuredBlock());
141226
}
142227

143228
mlir::LogicalResult
144229
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
145230
mlir::Location start = getLoc(s.getSourceRange().getEnd());
146-
return emitOpenACCOp<InitOp>(start, s.clauses());
231+
return emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
232+
s.clauses());
147233
}
234+
148235
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
149236
const OpenACCShutdownConstruct &s) {
150237
mlir::Location start = getLoc(s.getSourceRange().getEnd());
151-
return emitOpenACCOp<ShutdownOp>(start, s.clauses());
238+
return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
239+
s.getDirectiveLoc(), s.clauses());
152240
}
153241

154242
mlir::LogicalResult

clang/test/CIR/CodeGenOpenACC/init.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,17 @@ void acc_init(void) {
44
// CHECK: cir.func @acc_init() {
55
#pragma acc init
66
// CHECK-NEXT: acc.init loc(#{{[a-zA-Z0-9]+}}){{$}}
7+
8+
#pragma acc init device_type(*)
9+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>]}
10+
#pragma acc init device_type(nvidia)
11+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
12+
#pragma acc init device_type(host, multicore)
13+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
14+
#pragma acc init device_type(NVIDIA)
15+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
16+
#pragma acc init device_type(HoSt, MuLtIcORe)
17+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
18+
#pragma acc init device_type(HoSt) device_type(MuLtIcORe)
19+
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
720
}

clang/test/CIR/CodeGenOpenACC/shutdown.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,17 @@ void acc_shutdown(void) {
44
// CHECK: cir.func @acc_shutdown() {
55
#pragma acc shutdown
66
// CHECK-NEXT: acc.shutdown loc(#{{[a-zA-Z0-9]+}}){{$}}
7+
8+
#pragma acc shutdown device_type(*)
9+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<star>]}
10+
#pragma acc shutdown device_type(nvidia)
11+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
12+
#pragma acc shutdown device_type(host, multicore)
13+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
14+
#pragma acc shutdown device_type(NVIDIA)
15+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<nvidia>]}
16+
#pragma acc shutdown device_type(HoSt, MuLtIcORe)
17+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
18+
#pragma acc shutdown device_type(HoSt) device_type(MuLtIcORe)
19+
// CHECK-NEXT: acc.shutdown attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
720
}

0 commit comments

Comments
 (0)