9
9
// Emit OpenACC Stmt nodes as CIR code.
10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
+ #include < type_traits>
12
13
13
14
#include " CIRGenBuilder.h"
14
15
#include " CIRGenFunction.h"
@@ -23,22 +24,39 @@ using namespace cir;
23
24
using namespace mlir ::acc;
24
25
25
26
namespace {
27
+ // Simple type-trait to see if the first template arg is one of the list, so we
28
+ // can tell whether to `if-constexpr` a bunch of stuff.
29
+ template <typename ToTest, typename T, typename ... Tys>
30
+ constexpr bool isOneOfTypes =
31
+ std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>;
32
+ template <typename ToTest, typename T>
33
+ constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
34
+
26
35
class OpenACCClauseCIREmitter final
27
36
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
28
37
CIRGenModule &cgm;
38
+ // This is necessary since a few of the clauses emit differently based on the
39
+ // directive kind they are attached to.
40
+ OpenACCDirectiveKind dirKind;
41
+ SourceLocation dirLoc;
29
42
30
43
struct AttributeData {
31
44
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
32
45
// constructs as a 'default-attr'.
33
46
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
47
+ // For directives that have their device type architectures listed in
48
+ // attributes (init/shutdown/etc), the list of architectures to be emitted.
49
+ llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
34
50
} attrData;
35
51
36
52
void clauseNotImplemented (const OpenACCClause &c) {
37
53
cgm.errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
38
54
}
39
55
40
56
public:
41
- OpenACCClauseCIREmitter (CIRGenModule &cgm) : cgm(cgm) {}
57
+ OpenACCClauseCIREmitter (CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58
+ SourceLocation dirLoc)
59
+ : cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
42
60
43
61
void VisitClause (const OpenACCClause &clause) {
44
62
clauseNotImplemented (clause);
@@ -57,31 +75,92 @@ class OpenACCClauseCIREmitter final
57
75
}
58
76
}
59
77
78
+ mlir::acc::DeviceType decodeDeviceType (const IdentifierInfo *ii) {
79
+ // '*' case leaves no identifier-info, just a nullptr.
80
+ if (!ii)
81
+ return mlir::acc::DeviceType::Star;
82
+ return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName ())
83
+ .CaseLower (" default" , mlir::acc::DeviceType::Default)
84
+ .CaseLower (" host" , mlir::acc::DeviceType::Host)
85
+ .CaseLower (" multicore" , mlir::acc::DeviceType::Multicore)
86
+ .CasesLower (" nvidia" , " acc_device_nvidia" ,
87
+ mlir::acc::DeviceType::Nvidia)
88
+ .CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
89
+ }
90
+
91
+ void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
92
+
93
+ switch (dirKind) {
94
+ case OpenACCDirectiveKind::Init:
95
+ case OpenACCDirectiveKind::Shutdown: {
96
+ // Device type has a list that is either a 'star' (emitted as 'star'),
97
+ // or an identifer list, all of which get added for attributes.
98
+
99
+ for (const DeviceTypeArgument &arg : clause.getArchitectures ())
100
+ attrData.deviceTypeArchs .push_back (decodeDeviceType (arg.first ));
101
+ break ;
102
+ }
103
+ default :
104
+ return clauseNotImplemented (clause);
105
+ }
106
+ }
107
+
60
108
// Apply any of the clauses that resulted in an 'attribute'.
61
- template <typename Op> void applyAttributes (Op &op) {
62
- if (attrData.defaultVal .has_value ())
63
- op.setDefaultAttr (*attrData.defaultVal );
109
+ template <typename Op>
110
+ void applyAttributes (CIRGenBuilderTy &builder, Op &op) {
111
+
112
+ if (attrData.defaultVal .has_value ()) {
113
+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
114
+ // to expand this list.
115
+ // This type-trait checks if 'op'(the first arg) is one of the mlir::acc
116
+ // operations listed in the rest of the arguments.
117
+ if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
118
+ op.setDefaultAttr (*attrData.defaultVal );
119
+ else
120
+ cgm.errorNYI (dirLoc, " OpenACC 'default' clause lowering for " , dirKind);
121
+ }
122
+
123
+ if (!attrData.deviceTypeArchs .empty ()) {
124
+ // FIXME: OpenACC: as we implement this for other directive kinds, we have
125
+ // to expand this list, or more likely, have a 'noop' branch as most other
126
+ // uses of this apply to the operands instead.
127
+ // This type-trait checks if 'op'(the first arg) is one of the mlir::acc
128
+ if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
129
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
130
+ for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs )
131
+ deviceTypes.push_back (
132
+ mlir::acc::DeviceTypeAttr::get (builder.getContext (), DT));
133
+
134
+ op.setDeviceTypesAttr (
135
+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
136
+ } else {
137
+ cgm.errorNYI (dirLoc, " OpenACC 'device_type' clause lowering for " ,
138
+ dirKind);
139
+ }
140
+ }
64
141
}
65
142
};
143
+
66
144
} // namespace
67
145
68
146
template <typename Op, typename TermOp>
69
147
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt (
70
- mlir::Location start, mlir::Location end,
71
- llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
148
+ mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
149
+ SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
150
+ const Stmt *associatedStmt) {
72
151
mlir::LogicalResult res = mlir::success ();
73
152
74
153
llvm::SmallVector<mlir::Type> retTy;
75
154
llvm::SmallVector<mlir::Value> operands;
76
155
77
156
// Clause-emitter must be here because it might modify operands.
78
- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
157
+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
79
158
clauseEmitter.VisitClauseList (clauses);
80
159
81
160
auto op = builder.create <Op>(start, retTy, operands);
82
161
83
162
// Apply the attributes derived from the clauses.
84
- clauseEmitter.applyAttributes (op);
163
+ clauseEmitter.applyAttributes (builder, op);
85
164
86
165
mlir::Block &block = op.getRegion ().emplaceBlock ();
87
166
mlir::OpBuilder::InsertionGuard guardCase (builder);
@@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
95
174
}
96
175
97
176
template <typename Op>
98
- mlir::LogicalResult
99
- CIRGenFunction::emitOpenACCOp ( mlir::Location start,
100
- llvm::ArrayRef<const OpenACCClause *> clauses) {
177
+ mlir::LogicalResult CIRGenFunction::emitOpenACCOp (
178
+ mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc ,
179
+ llvm::ArrayRef<const OpenACCClause *> clauses) {
101
180
mlir::LogicalResult res = mlir::success ();
102
181
103
182
llvm::SmallVector<mlir::Type> retTy;
104
183
llvm::SmallVector<mlir::Value> operands;
105
184
106
185
// Clause-emitter must be here because it might modify operands.
107
- OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule ());
186
+ OpenACCClauseCIREmitter clauseEmitter (getCIRGenModule (), dirKind, dirLoc );
108
187
clauseEmitter.VisitClauseList (clauses);
109
188
110
- builder.create <Op>(start, retTy, operands);
189
+ auto op = builder.create <Op>(start, retTy, operands);
190
+ // Apply the attributes derived from the clauses.
191
+ clauseEmitter.applyAttributes (builder, op);
111
192
return res;
112
193
}
113
194
@@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
119
200
switch (s.getDirectiveKind ()) {
120
201
case OpenACCDirectiveKind::Parallel:
121
202
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
122
- start, end, s.clauses (), s.getStructuredBlock ());
203
+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
204
+ s.getStructuredBlock ());
123
205
case OpenACCDirectiveKind::Serial:
124
206
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
125
- start, end, s.clauses (), s.getStructuredBlock ());
207
+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
208
+ s.getStructuredBlock ());
126
209
case OpenACCDirectiveKind::Kernels:
127
210
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
128
- start, end, s.clauses (), s.getStructuredBlock ());
211
+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
212
+ s.getStructuredBlock ());
129
213
default :
130
214
llvm_unreachable (" invalid compute construct kind" );
131
215
}
@@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
137
221
mlir::Location end = getLoc (s.getSourceRange ().getEnd ());
138
222
139
223
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
140
- start, end, s.clauses (), s.getStructuredBlock ());
224
+ start, end, s.getDirectiveKind (), s.getDirectiveLoc (), s.clauses (),
225
+ s.getStructuredBlock ());
141
226
}
142
227
143
228
mlir::LogicalResult
144
229
CIRGenFunction::emitOpenACCInitConstruct (const OpenACCInitConstruct &s) {
145
230
mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
146
- return emitOpenACCOp<InitOp>(start, s.clauses ());
231
+ return emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
232
+ s.clauses ());
147
233
}
234
+
148
235
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct (
149
236
const OpenACCShutdownConstruct &s) {
150
237
mlir::Location start = getLoc (s.getSourceRange ().getEnd ());
151
- return emitOpenACCOp<ShutdownOp>(start, s.clauses ());
238
+ return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
239
+ s.getDirectiveLoc (), s.clauses ());
152
240
}
153
241
154
242
mlir::LogicalResult
0 commit comments