Skip to content

Commit af63e1b

Browse files
authored
[OpenACC][CIR] Implement 'self' lowering on compute constructs (llvm#135851)
This is our first attempt at lowering a clause that is an 'operand' in the OpenACC operation, so it does quite a bit of refactoring. My previous plans for how to emit the clauses were not viable, so we instead do 'create the op, then use the visitor to fill in the operands'. This resulted in the 'applyAttributes' function getting removed and a few other functions being simplified. Additionally, it requires adjusting the insertion point a little to make sure we're inserting 'around' the operation correctly. Finally, since the OpenACC dialect only understands the MLIR types, we had to introduce a use of the unrealized-conversion-cast, which we'll probably be getting good use out of in the future.
1 parent 16980d5 commit af63e1b

File tree

5 files changed

+205
-106
lines changed

5 files changed

+205
-106
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,11 @@ class OpenACCSelfClause final
430430
}
431431

432432
bool isConditionExprClause() const { return HasConditionExpr.has_value(); }
433+
bool isVarListClause() const { return !isConditionExprClause(); }
434+
bool isEmptySelfClause() const {
435+
return (isConditionExprClause() && !hasConditionExpr()) ||
436+
(!isConditionExprClause() && getVarList().empty());
437+
}
433438

434439
bool hasConditionExpr() const {
435440
assert(HasConditionExpr.has_value() &&

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 116 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -32,46 +32,51 @@ constexpr bool isOneOfTypes =
3232
template <typename ToTest, typename T>
3333
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>;
3434

35+
template <typename OpTy>
3536
class OpenACCClauseCIREmitter final
36-
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
37-
CIRGenModule &cgm;
37+
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
38+
OpTy &operation;
39+
CIRGenFunction &cgf;
40+
CIRGenBuilderTy &builder;
41+
3842
// This is necessary since a few of the clauses emit differently based on the
3943
// directive kind they are attached to.
4044
OpenACCDirectiveKind dirKind;
45+
// TODO(cir): This source location should be able to go away once the NYI
46+
// diagnostics are gone.
4147
SourceLocation dirLoc;
4248

43-
struct AttributeData {
44-
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
45-
// constructs as a 'default-attr'.
46-
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
47-
// For directives that have their device type architectures listed in
48-
// attributes (init/shutdown/etc), the list of architectures to be emitted.
49-
llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{};
50-
} attrData;
51-
5249
void clauseNotImplemented(const OpenACCClause &c) {
53-
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
50+
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
5451
}
5552

5653
public:
57-
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind,
58-
SourceLocation dirLoc)
59-
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {}
54+
OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
55+
CIRGenBuilderTy &builder,
56+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc)
57+
: operation(operation), cgf(cgf), builder(builder), dirKind(dirKind),
58+
dirLoc(dirLoc) {}
6059

6160
void VisitClause(const OpenACCClause &clause) {
6261
clauseNotImplemented(clause);
6362
}
6463

6564
void VisitDefaultClause(const OpenACCDefaultClause &clause) {
66-
switch (clause.getDefaultClauseKind()) {
67-
case OpenACCDefaultClauseKind::None:
68-
attrData.defaultVal = ClauseDefaultValue::None;
69-
break;
70-
case OpenACCDefaultClauseKind::Present:
71-
attrData.defaultVal = ClauseDefaultValue::Present;
72-
break;
73-
case OpenACCDefaultClauseKind::Invalid:
74-
break;
65+
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
66+
// operations listed in the rest of the arguments.
67+
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
68+
switch (clause.getDefaultClauseKind()) {
69+
case OpenACCDefaultClauseKind::None:
70+
operation.setDefaultAttr(ClauseDefaultValue::None);
71+
break;
72+
case OpenACCDefaultClauseKind::Present:
73+
operation.setDefaultAttr(ClauseDefaultValue::Present);
74+
break;
75+
case OpenACCDefaultClauseKind::Invalid:
76+
break;
77+
}
78+
} else {
79+
return clauseNotImplemented(clause);
7580
}
7681
}
7782

@@ -89,64 +94,70 @@ class OpenACCClauseCIREmitter final
8994
}
9095

9196
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
97+
if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
98+
llvm::SmallVector<mlir::Attribute> deviceTypes;
99+
std::optional<mlir::ArrayAttr> existingDeviceTypes =
100+
operation.getDeviceTypes();
101+
102+
// Ensure we keep the existing ones, and in the correct 'new' order.
103+
if (existingDeviceTypes) {
104+
for (const mlir::Attribute &Attr : *existingDeviceTypes)
105+
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
106+
builder.getContext(),
107+
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
108+
}
92109

93-
switch (dirKind) {
94-
case OpenACCDirectiveKind::Init:
95-
case OpenACCDirectiveKind::Set:
96-
case OpenACCDirectiveKind::Shutdown: {
97-
// Device type has a list that is either a 'star' (emitted as 'star'),
98-
// or an identifier list, all of which get added for attributes.
99-
100-
for (const DeviceTypeArgument &arg : clause.getArchitectures())
101-
attrData.deviceTypeArchs.push_back(decodeDeviceType(arg.first));
102-
break;
103-
}
104-
default:
110+
for (const DeviceTypeArgument &arg : clause.getArchitectures()) {
111+
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
112+
builder.getContext(), decodeDeviceType(arg.first)));
113+
}
114+
operation.removeDeviceTypesAttr();
115+
operation.setDeviceTypesAttr(
116+
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
117+
} else if constexpr (isOneOfTypes<OpTy, SetOp>) {
118+
assert(!operation.getDeviceTypeAttr() && "already have device-type?");
119+
assert(clause.getArchitectures().size() <= 1);
120+
121+
if (!clause.getArchitectures().empty())
122+
operation.setDeviceType(
123+
decodeDeviceType(clause.getArchitectures()[0].first));
124+
} else {
105125
return clauseNotImplemented(clause);
106126
}
107127
}
108128

109-
// Apply any of the clauses that resulted in an 'attribute'.
110-
template <typename Op>
111-
void applyAttributes(CIRGenBuilderTy &builder, Op &op) {
112-
113-
if (attrData.defaultVal.has_value()) {
114-
// FIXME: OpenACC: as we implement this for other directive kinds, we have
115-
// to expand this list.
116-
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
117-
// operations listed in the rest of the arguments.
118-
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>)
119-
op.setDefaultAttr(*attrData.defaultVal);
120-
else
121-
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind);
122-
}
123-
124-
if (!attrData.deviceTypeArchs.empty()) {
125-
// FIXME: OpenACC: as we implement this for other directive kinds, we have
126-
// to expand this list, or more likely, have a 'noop' branch as most other
127-
// uses of this apply to the operands instead.
128-
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
129-
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) {
130-
llvm::SmallVector<mlir::Attribute> deviceTypes;
131-
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs)
132-
deviceTypes.push_back(
133-
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT));
134-
135-
op.setDeviceTypesAttr(
136-
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
137-
} else if constexpr (isOneOfTypes<Op, SetOp>) {
138-
assert(attrData.deviceTypeArchs.size() <= 1 &&
139-
"Set can only have a single architecture");
140-
if (!attrData.deviceTypeArchs.empty())
141-
op.setDeviceType(attrData.deviceTypeArchs[0]);
129+
void VisitSelfClause(const OpenACCSelfClause &clause) {
130+
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
131+
if (clause.isEmptySelfClause()) {
132+
operation.setSelfAttr(true);
133+
} else if (clause.isConditionExprClause()) {
134+
assert(clause.hasConditionExpr());
135+
mlir::Value condition =
136+
cgf.evaluateExprAsBool(clause.getConditionExpr());
137+
138+
mlir::Location exprLoc =
139+
cgf.cgm.getLoc(clause.getConditionExpr()->getBeginLoc());
140+
mlir::IntegerType targetType = mlir::IntegerType::get(
141+
&cgf.getMLIRContext(), /*width=*/1,
142+
mlir::IntegerType::SignednessSemantics::Signless);
143+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
144+
exprLoc, targetType, condition);
145+
operation.getSelfCondMutable().append(conversionOp.getResult(0));
142146
} else {
143-
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ",
144-
dirKind);
147+
llvm_unreachable("var-list version of self shouldn't get here");
145148
}
149+
} else {
150+
return clauseNotImplemented(clause);
146151
}
147152
}
148153
};
149154

155+
template <typename OpTy>
156+
auto makeClauseEmitter(OpTy &op, CIRGenFunction &cgf, CIRGenBuilderTy &builder,
157+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc) {
158+
return OpenACCClauseCIREmitter<OpTy>(op, cgf, builder, dirKind, dirLoc);
159+
}
160+
150161
} // namespace
151162

152163
template <typename Op, typename TermOp>
@@ -158,24 +169,27 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
158169

159170
llvm::SmallVector<mlir::Type> retTy;
160171
llvm::SmallVector<mlir::Value> operands;
161-
162-
// Clause-emitter must be here because it might modify operands.
163-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
164-
clauseEmitter.VisitClauseList(clauses);
165-
166172
auto op = builder.create<Op>(start, retTy, operands);
167173

168-
// Apply the attributes derived from the clauses.
169-
clauseEmitter.applyAttributes(builder, op);
174+
{
175+
mlir::OpBuilder::InsertionGuard guardCase(builder);
176+
// Sets insertion point before the 'op', since every new expression needs to
177+
// be before the operation.
178+
builder.setInsertionPoint(op);
179+
makeClauseEmitter(op, *this, builder, dirKind, dirLoc)
180+
.VisitClauseList(clauses);
181+
}
170182

171-
mlir::Block &block = op.getRegion().emplaceBlock();
172-
mlir::OpBuilder::InsertionGuard guardCase(builder);
173-
builder.setInsertionPointToEnd(&block);
183+
{
184+
mlir::Block &block = op.getRegion().emplaceBlock();
185+
mlir::OpBuilder::InsertionGuard guardCase(builder);
186+
builder.setInsertionPointToEnd(&block);
174187

175-
LexicalScope ls{*this, start, builder.getInsertionBlock()};
176-
res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
188+
LexicalScope ls{*this, start, builder.getInsertionBlock()};
189+
res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
177190

178-
builder.create<TermOp>(end);
191+
builder.create<TermOp>(end);
192+
}
179193
return res;
180194
}
181195

@@ -187,14 +201,16 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
187201

188202
llvm::SmallVector<mlir::Type> retTy;
189203
llvm::SmallVector<mlir::Value> operands;
190-
191-
// Clause-emitter must be here because it might modify operands.
192-
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc);
193-
clauseEmitter.VisitClauseList(clauses);
194-
195204
auto op = builder.create<Op>(start, retTy, operands);
196-
// Apply the attributes derived from the clauses.
197-
clauseEmitter.applyAttributes(builder, op);
205+
206+
{
207+
mlir::OpBuilder::InsertionGuard guardCase(builder);
208+
// Sets insertion point before the 'op', since every new expression needs to
209+
// be before the operation.
210+
builder.setInsertionPoint(op);
211+
makeClauseEmitter(op, *this, builder, dirKind, dirLoc)
212+
.VisitClauseList(clauses);
213+
}
198214
return res;
199215
}
200216

@@ -254,46 +270,46 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
254270

255271
mlir::LogicalResult
256272
CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) {
257-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Loop Construct");
273+
cgm.errorNYI(s.getSourceRange(), "OpenACC Loop Construct");
258274
return mlir::failure();
259275
}
260276
mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
261277
const OpenACCCombinedConstruct &s) {
262-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Combined Construct");
278+
cgm.errorNYI(s.getSourceRange(), "OpenACC Combined Construct");
263279
return mlir::failure();
264280
}
265281
mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
266282
const OpenACCEnterDataConstruct &s) {
267-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC EnterData Construct");
283+
cgm.errorNYI(s.getSourceRange(), "OpenACC EnterData Construct");
268284
return mlir::failure();
269285
}
270286
mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct(
271287
const OpenACCExitDataConstruct &s) {
272-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC ExitData Construct");
288+
cgm.errorNYI(s.getSourceRange(), "OpenACC ExitData Construct");
273289
return mlir::failure();
274290
}
275291
mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
276292
const OpenACCHostDataConstruct &s) {
277-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC HostData Construct");
293+
cgm.errorNYI(s.getSourceRange(), "OpenACC HostData Construct");
278294
return mlir::failure();
279295
}
280296
mlir::LogicalResult
281297
CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
282-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Wait Construct");
298+
cgm.errorNYI(s.getSourceRange(), "OpenACC Wait Construct");
283299
return mlir::failure();
284300
}
285301
mlir::LogicalResult
286302
CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {
287-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Update Construct");
303+
cgm.errorNYI(s.getSourceRange(), "OpenACC Update Construct");
288304
return mlir::failure();
289305
}
290306
mlir::LogicalResult
291307
CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) {
292-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");
308+
cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");
293309
return mlir::failure();
294310
}
295311
mlir::LogicalResult
296312
CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) {
297-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Cache Construct");
313+
cgm.errorNYI(s.getSourceRange(), "OpenACC Cache Construct");
298314
return mlir::failure();
299315
}

clang/test/CIR/CodeGenOpenACC/kernels.c

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
22

3-
void acc_kernels(void) {
4-
// CHECK: cir.func @acc_kernels() {
3+
void acc_kernels(int cond) {
4+
// CHECK: cir.func @acc_kernels(%[[ARG:.*]]: !s32i{{.*}}) {
5+
// CHECK-NEXT: %[[COND:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["cond", init]
6+
// CHECK-NEXT: cir.store %[[ARG]], %[[COND]] : !s32i, !cir.ptr<!s32i>
57
#pragma acc kernels
68
{}
79

@@ -38,5 +40,29 @@ void acc_kernels(void) {
3840
// CHECK-NEXT: acc.terminator
3941
// CHECK-NEXT:}
4042

43+
#pragma acc kernels self
44+
{}
45+
// CHECK-NEXT: acc.kernels {
46+
// CHECK-NEXT: acc.terminator
47+
// CHECK-NEXT: } attributes {selfAttr}
48+
49+
#pragma acc kernels self(cond)
50+
{}
51+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
52+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[COND_LOAD]] : !s32i), !cir.bool
53+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
54+
// CHECK-NEXT: acc.kernels self(%[[CONV_CAST]]) {
55+
// CHECK-NEXT: acc.terminator
56+
// CHECK-NEXT: } loc
57+
58+
#pragma acc kernels self(0)
59+
{}
60+
// CHECK-NEXT: %[[ZERO_LITERAL:.*]] = cir.const #cir.int<0> : !s32i
61+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[ZERO_LITERAL]] : !s32i), !cir.bool
62+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
63+
// CHECK-NEXT: acc.kernels self(%[[CONV_CAST]]) {
64+
// CHECK-NEXT: acc.terminator
65+
// CHECK-NEXT: } loc
66+
4167
// CHECK-NEXT: cir.return
4268
}

0 commit comments

Comments
 (0)