Skip to content

Commit fa273e1

Browse files
authored
[OpenACC][CIR] Implement 'data' construct lowering (#135038)
This patch does the lowering of the OpenACC 'data' construct, which requires getting the `default` clause (as `data` requires at least 1 of a list of clauses, and this is the easiest one). The lowering of the clauses appears to happen in 1 of 2 ways: a- as an operand. or b- as an attribute. This patch adds infrastructure to lower as an attribute, as that is how 'data' works. In addition to that, it changes the OpenACCClauseVisitor a bit, which previously just required that each of the derived classes have all of the clauses covered. This patch modifies it so that the visitor directly calls the derived class from its visitor function, which leaves the base-class ones the ability to defer to a generic function. This was previously like this because I had some use cases that I didn't end up using, and the 'generic' function here seems much more useful.
1 parent 6d4d017 commit fa273e1

File tree

7 files changed

+138
-33
lines changed

7 files changed

+138
-33
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,11 +1316,13 @@ template <class Impl> class OpenACCClauseVisitor {
13161316
switch (C->getClauseKind()) {
13171317
#define VISIT_CLAUSE(CLAUSE_NAME) \
13181318
case OpenACCClauseKind::CLAUSE_NAME: \
1319-
Visit##CLAUSE_NAME##Clause(*cast<OpenACC##CLAUSE_NAME##Clause>(C)); \
1319+
getDerived().Visit##CLAUSE_NAME##Clause( \
1320+
*cast<OpenACC##CLAUSE_NAME##Clause>(C)); \
13201321
return;
13211322
#define CLAUSE_ALIAS(ALIAS_NAME, CLAUSE_NAME, DEPRECATED) \
13221323
case OpenACCClauseKind::ALIAS_NAME: \
1323-
Visit##CLAUSE_NAME##Clause(*cast<OpenACC##CLAUSE_NAME##Clause>(C)); \
1324+
getDerived().Visit##CLAUSE_NAME##Clause( \
1325+
*cast<OpenACC##CLAUSE_NAME##Clause>(C)); \
13241326
return;
13251327
#include "clang/Basic/OpenACCClauses.def"
13261328

@@ -1333,7 +1335,7 @@ template <class Impl> class OpenACCClauseVisitor {
13331335
#define VISIT_CLAUSE(CLAUSE_NAME) \
13341336
void Visit##CLAUSE_NAME##Clause( \
13351337
const OpenACC##CLAUSE_NAME##Clause &Clause) { \
1336-
return getDerived().Visit##CLAUSE_NAME##Clause(Clause); \
1338+
return getDerived().VisitClause(Clause); \
13371339
}
13381340

13391341
#include "clang/Basic/OpenACCClauses.def"

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -571,14 +571,13 @@ class CIRGenFunction : public CIRGenTypeCache {
571571
// OpenACC Emission
572572
//===--------------------------------------------------------------------===//
573573
private:
574-
// Function to do the basic implementation of a 'compute' operation, including
575-
// the clauses/etc. This might be generalizable in the future to work for
576-
// other constructs, or at least be the base for construct emission.
574+
// Function to do the basic implementation of an operation with an Associated
575+
// Statement. Models AssociatedStmtConstruct.
577576
template <typename Op, typename TermOp>
578577
mlir::LogicalResult
579-
emitOpenACCComputeOp(mlir::Location start, mlir::Location end,
580-
llvm::ArrayRef<const OpenACCClause *> clauses,
581-
const Stmt *structuredBlock);
578+
emitOpenACCOpAssociatedStmt(mlir::Location start, mlir::Location end,
579+
llvm::ArrayRef<const OpenACCClause *> clauses,
580+
const Stmt *associatedStmt);
582581

583582
public:
584583
mlir::LogicalResult

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,68 @@ class OpenACCClauseCIREmitter final
2727
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
2828
CIRGenModule &cgm;
2929

30+
struct AttributeData {
31+
// Value of the 'default' attribute, added on 'data' and 'compute'/etc
32+
// constructs as a 'default-attr'.
33+
std::optional<ClauseDefaultValue> defaultVal = std::nullopt;
34+
} attrData;
35+
3036
void clauseNotImplemented(const OpenACCClause &c) {
3137
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
3238
}
3339

3440
public:
3541
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
3642

37-
#define VISIT_CLAUSE(CN) \
38-
void Visit##CN##Clause(const OpenACC##CN##Clause &clause) { \
39-
clauseNotImplemented(clause); \
43+
void VisitClause(const OpenACCClause &clause) {
44+
clauseNotImplemented(clause);
45+
}
46+
47+
void VisitDefaultClause(const OpenACCDefaultClause &clause) {
48+
switch (clause.getDefaultClauseKind()) {
49+
case OpenACCDefaultClauseKind::None:
50+
attrData.defaultVal = ClauseDefaultValue::None;
51+
break;
52+
case OpenACCDefaultClauseKind::Present:
53+
attrData.defaultVal = ClauseDefaultValue::Present;
54+
break;
55+
case OpenACCDefaultClauseKind::Invalid:
56+
break;
57+
}
58+
}
59+
60+
// Apply any of the clauses that resulted in an 'attribute'.
61+
template <typename Op> void applyAttributes(Op &op) {
62+
if (attrData.defaultVal.has_value())
63+
op.setDefaultAttr(*attrData.defaultVal);
4064
}
41-
#include "clang/Basic/OpenACCClauses.def"
4265
};
4366
} // namespace
4467

4568
template <typename Op, typename TermOp>
46-
mlir::LogicalResult CIRGenFunction::emitOpenACCComputeOp(
69+
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
4770
mlir::Location start, mlir::Location end,
48-
llvm::ArrayRef<const OpenACCClause *> clauses,
49-
const Stmt *structuredBlock) {
71+
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) {
5072
mlir::LogicalResult res = mlir::success();
5173

74+
llvm::SmallVector<mlir::Type> retTy;
75+
llvm::SmallVector<mlir::Value> operands;
76+
77+
// Clause-emitter must be here because it might modify operands.
5278
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
5379
clauseEmitter.VisitClauseList(clauses);
5480

55-
llvm::SmallVector<mlir::Type> retTy;
56-
llvm::SmallVector<mlir::Value> operands;
5781
auto op = builder.create<Op>(start, retTy, operands);
5882

83+
// Apply the attributes derived from the clauses.
84+
clauseEmitter.applyAttributes(op);
85+
5986
mlir::Block &block = op.getRegion().emplaceBlock();
6087
mlir::OpBuilder::InsertionGuard guardCase(builder);
6188
builder.setInsertionPointToEnd(&block);
6289

6390
LexicalScope ls{*this, start, builder.getInsertionBlock()};
64-
res = emitStmt(structuredBlock, /*useCurrentScope=*/true);
91+
res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
6592

6693
builder.create<TermOp>(end);
6794
return res;
@@ -74,19 +101,28 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
74101

75102
switch (s.getDirectiveKind()) {
76103
case OpenACCDirectiveKind::Parallel:
77-
return emitOpenACCComputeOp<ParallelOp, mlir::acc::YieldOp>(
104+
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
78105
start, end, s.clauses(), s.getStructuredBlock());
79106
case OpenACCDirectiveKind::Serial:
80-
return emitOpenACCComputeOp<SerialOp, mlir::acc::YieldOp>(
107+
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
81108
start, end, s.clauses(), s.getStructuredBlock());
82109
case OpenACCDirectiveKind::Kernels:
83-
return emitOpenACCComputeOp<KernelsOp, mlir::acc::TerminatorOp>(
110+
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
84111
start, end, s.clauses(), s.getStructuredBlock());
85112
default:
86113
llvm_unreachable("invalid compute construct kind");
87114
}
88115
}
89116

117+
mlir::LogicalResult
118+
CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
119+
mlir::Location start = getLoc(s.getSourceRange().getEnd());
120+
mlir::Location end = getLoc(s.getSourceRange().getEnd());
121+
122+
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
123+
start, end, s.clauses(), s.getStructuredBlock());
124+
}
125+
90126
mlir::LogicalResult
91127
CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) {
92128
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Loop Construct");
@@ -97,11 +133,6 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
97133
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Combined Construct");
98134
return mlir::failure();
99135
}
100-
mlir::LogicalResult
101-
CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
102-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Data Construct");
103-
return mlir::failure();
104-
}
105136
mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
106137
const OpenACCEnterDataConstruct &s) {
107138
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC EnterData Construct");

clang/test/CIR/CodeGenOpenACC/data.c

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
2+
3+
void acc_data(void) {
4+
// CHECK: cir.func @acc_data() {
5+
6+
#pragma acc data default(none)
7+
{
8+
int i = 0;
9+
++i;
10+
}
11+
// CHECK-NEXT: acc.data {
12+
// CHECK-NEXT: cir.alloca
13+
// CHECK-NEXT: cir.const
14+
// CHECK-NEXT: cir.store
15+
// CHECK-NEXT: cir.load
16+
// CHECK-NEXT: cir.unary
17+
// CHECK-NEXT: cir.store
18+
// CHECK-NEXT: acc.terminator
19+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
20+
21+
#pragma acc data default(present)
22+
{
23+
int i = 0;
24+
++i;
25+
}
26+
// CHECK-NEXT: acc.data {
27+
// CHECK-NEXT: cir.alloca
28+
// CHECK-NEXT: cir.const
29+
// CHECK-NEXT: cir.store
30+
// CHECK-NEXT: cir.load
31+
// CHECK-NEXT: cir.unary
32+
// CHECK-NEXT: cir.store
33+
// CHECK-NEXT: acc.terminator
34+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
35+
36+
// CHECK-NEXT: cir.return
37+
}

clang/test/CIR/CodeGenOpenACC/kernels.c

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,21 @@ void acc_kernels(void) {
66
{}
77

88
// CHECK-NEXT: acc.kernels {
9-
// CHECK-NEXT:acc.terminator
9+
// CHECK-NEXT: acc.terminator
1010
// CHECK-NEXT:}
1111

12+
#pragma acc kernels default(none)
13+
{}
14+
// CHECK-NEXT: acc.kernels {
15+
// CHECK-NEXT: acc.terminator
16+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
17+
18+
#pragma acc kernels default(present)
19+
{}
20+
// CHECK-NEXT: acc.kernels {
21+
// CHECK-NEXT: acc.terminator
22+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
23+
1224
#pragma acc kernels
1325
while(1){}
1426
// CHECK-NEXT: acc.kernels {
@@ -23,7 +35,7 @@ void acc_kernels(void) {
2335
// CHECK-NEXT: }
2436
// cir.scope end:
2537
// CHECK-NEXT: }
26-
// CHECK-NEXT:acc.terminator
38+
// CHECK-NEXT: acc.terminator
2739
// CHECK-NEXT:}
2840

2941
// CHECK-NEXT: cir.return

clang/test/CIR/CodeGenOpenACC/parallel.c

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,21 @@ void acc_parallel(void) {
55
#pragma acc parallel
66
{}
77
// CHECK-NEXT: acc.parallel {
8-
// CHECK-NEXT:acc.yield
8+
// CHECK-NEXT: acc.yield
99
// CHECK-NEXT:}
1010

11+
#pragma acc parallel default(none)
12+
{}
13+
// CHECK-NEXT: acc.parallel {
14+
// CHECK-NEXT: acc.yield
15+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
16+
17+
#pragma acc parallel default(present)
18+
{}
19+
// CHECK-NEXT: acc.parallel {
20+
// CHECK-NEXT: acc.yield
21+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
22+
1123
#pragma acc parallel
1224
while(1){}
1325
// CHECK-NEXT: acc.parallel {
@@ -22,7 +34,7 @@ void acc_parallel(void) {
2234
// CHECK-NEXT: }
2335
// cir.scope end:
2436
// CHECK-NEXT: }
25-
// CHECK-NEXT:acc.yield
37+
// CHECK-NEXT: acc.yield
2638
// CHECK-NEXT:}
2739

2840
// CHECK-NEXT: cir.return

clang/test/CIR/CodeGenOpenACC/serial.c

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,21 @@ void acc_serial(void) {
66
{}
77

88
// CHECK-NEXT: acc.serial {
9-
// CHECK-NEXT:acc.yield
9+
// CHECK-NEXT: acc.yield
1010
// CHECK-NEXT:}
1111

12+
#pragma acc serial default(none)
13+
{}
14+
// CHECK-NEXT: acc.serial {
15+
// CHECK-NEXT: acc.yield
16+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
17+
18+
#pragma acc serial default(present)
19+
{}
20+
// CHECK-NEXT: acc.serial {
21+
// CHECK-NEXT: acc.yield
22+
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
23+
1224
#pragma acc serial
1325
while(1){}
1426
// CHECK-NEXT: acc.serial {
@@ -23,7 +35,7 @@ void acc_serial(void) {
2335
// CHECK-NEXT: }
2436
// cir.scope end:
2537
// CHECK-NEXT: }
26-
// CHECK-NEXT:acc.yield
38+
// CHECK-NEXT: acc.yield
2739
// CHECK-NEXT:}
2840

2941
// CHECK-NEXT: cir.return

0 commit comments

Comments
 (0)