@@ -23,9 +23,25 @@ constexpr bool isOneOfTypes =
23
23
template <typename ToTest, typename T>
24
24
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
25
25
26
+ // Holds information for emitting clauses for a combined construct. We
27
+ // instantiate the clause emitter with this type so that it can use
28
+ // if-constexpr to specially handle these.
29
+ template <typename CompOpTy> struct CombinedConstructClauseInfo {
30
+ using ComputeOpTy = CompOpTy;
31
+ ComputeOpTy computeOp;
32
+ mlir::acc::LoopOp loopOp;
33
+ };
34
+
35
+ template <typename ToTest> constexpr bool isCombinedType = false ;
36
+ template <typename T>
37
+ constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true ;
38
+
26
39
template <typename OpTy>
27
40
class OpenACCClauseCIREmitter final
28
41
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
42
+ // Necessary for combined constructs.
43
+ template <typename FriendOpTy> friend class OpenACCClauseCIREmitter ;
44
+
29
45
OpTy &operation;
30
46
CIRGen::CIRGenFunction &cgf;
31
47
CIRGen::CIRGenBuilderTy &builder;
@@ -119,6 +135,26 @@ class OpenACCClauseCIREmitter final
119
135
llvm_unreachable (" unknown gang kind" );
120
136
}
121
137
138
+ template <typename U = void ,
139
+ typename = std::enable_if_t <isCombinedType<OpTy>, U>>
140
+ void applyToLoopOp (const OpenACCClause &c) {
141
+ // TODO OpenACC: we have to set the insertion scope here correctly still.
142
+ OpenACCClauseCIREmitter<mlir::acc::LoopOp> loopEmitter{
143
+ operation.loopOp , cgf, builder, dirKind, dirLoc};
144
+ loopEmitter.lastDeviceTypeValues = lastDeviceTypeValues;
145
+ loopEmitter.Visit (&c);
146
+ }
147
+
148
+ template <typename U = void ,
149
+ typename = std::enable_if_t <isCombinedType<OpTy>, U>>
150
+ void applyToComputeOp (const OpenACCClause &c) {
151
+ // TODO OpenACC: we have to set the insertion scope here correctly still.
152
+ OpenACCClauseCIREmitter<typename OpTy::ComputeOpTy> computeEmitter{
153
+ operation.computeOp , cgf, builder, dirKind, dirLoc};
154
+ computeEmitter.lastDeviceTypeValues = lastDeviceTypeValues;
155
+ computeEmitter.Visit (&c);
156
+ }
157
+
122
158
public:
123
159
OpenACCClauseCIREmitter (OpTy &operation, CIRGen::CIRGenFunction &cgf,
124
160
CIRGen::CIRGenBuilderTy &builder,
@@ -145,10 +181,10 @@ class OpenACCClauseCIREmitter final
145
181
case OpenACCDefaultClauseKind::Invalid:
146
182
break ;
147
183
}
184
+ } else if constexpr (isCombinedType<OpTy>) {
185
+ applyToComputeOp (clause);
148
186
} else {
149
- // TODO: When we've implemented this for everything, switch this to an
150
- // unreachable. Combined constructs remain.
151
- return clauseNotImplemented (clause);
187
+ llvm_unreachable (" Unknown construct kind in VisitDefaultClause" );
152
188
}
153
189
}
154
190
@@ -175,9 +211,12 @@ class OpenACCClauseCIREmitter final
175
211
// Nothing to do here, these constructs don't have any IR for these, as
176
212
// they just modify the other clauses IR. So setting of
177
213
// `lastDeviceTypeValues` (done above) is all we need.
214
+ } else if constexpr (isCombinedType<OpTy>) {
215
+ // Nothing to do here either, combined constructs are just going to use
216
+ // 'lastDeviceTypeValues' to set the value for the child visitor.
178
217
} else {
179
218
// TODO: When we've implemented this for everything, switch this to an
180
- // unreachable. update, data, routine, combined constructs remain.
219
+ // unreachable. update, data, routine constructs remain.
181
220
return clauseNotImplemented (clause);
182
221
}
183
222
}
@@ -334,9 +373,11 @@ class OpenACCClauseCIREmitter final
334
373
void VisitSeqClause (const OpenACCSeqClause &clause) {
335
374
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
336
375
operation.addSeq (builder.getContext (), lastDeviceTypeValues);
376
+ } else if constexpr (isCombinedType<OpTy>) {
377
+ applyToLoopOp (clause);
337
378
} else {
338
379
// TODO: When we've implemented this for everything, switch this to an
339
- // unreachable. Routine, Combined constructs remain .
380
+ // unreachable. Routine construct remains .
340
381
return clauseNotImplemented (clause);
341
382
}
342
383
}
0 commit comments