Skip to content

Commit 6e7c40b

Browse files
authored
[OpenACC][CIR] Initial patch to do OpenACC->IR lowering (#134936)
This patch adds some lowering code for Compute Constructs, plus the infrastructure to someday do clauses. Doing this requires adding the dialect to the CIRGenerator. This patch does not however implement/correctly initialize lowering from OpenACC-Dialect to anything lower however.
1 parent 750da48 commit 6e7c40b

File tree

8 files changed

+169
-4
lines changed

8 files changed

+169
-4
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class OpenACCClause {
3838
OpenACCClauseKind getClauseKind() const { return Kind; }
3939
SourceLocation getBeginLoc() const { return Location.getBegin(); }
4040
SourceLocation getEndLoc() const { return Location.getEnd(); }
41+
SourceRange getSourceRange() const { return Location; }
4142

4243
static bool classof(const OpenACCClause *) { return true; }
4344

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,16 @@ class CIRGenFunction : public CIRGenTypeCache {
570570
//===--------------------------------------------------------------------===//
571571
// OpenACC Emission
572572
//===--------------------------------------------------------------------===//
573+
private:
574+
// Function to do the basic implementation of a 'compute' operation, including
575+
// the clauses/etc. This might be generalizable in the future to work for
576+
// other constructs, or at least be the base for construct emission.
577+
template <typename Op, typename TermOp>
578+
mlir::LogicalResult
579+
emitOpenACCComputeOp(mlir::Location start, mlir::Location end,
580+
llvm::ArrayRef<const OpenACCClause *> clauses,
581+
const Stmt *structuredBlock);
582+
573583
public:
574584
mlir::LogicalResult
575585
emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s);

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,79 @@
1212

1313
#include "CIRGenBuilder.h"
1414
#include "CIRGenFunction.h"
15+
#include "clang/AST/OpenACCClause.h"
1516
#include "clang/AST/StmtOpenACC.h"
1617

18+
#include "mlir/Dialect/OpenACC/OpenACC.h"
19+
1720
using namespace clang;
1821
using namespace clang::CIRGen;
1922
using namespace cir;
23+
using namespace mlir::acc;
24+
25+
namespace {
26+
class OpenACCClauseCIREmitter final
27+
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> {
28+
CIRGenModule &cgm;
29+
30+
void clauseNotImplemented(const OpenACCClause &c) {
31+
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
32+
}
33+
34+
public:
35+
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {}
36+
37+
#define VISIT_CLAUSE(CN) \
38+
void Visit##CN##Clause(const OpenACC##CN##Clause &clause) { \
39+
clauseNotImplemented(clause); \
40+
}
41+
#include "clang/Basic/OpenACCClauses.def"
42+
};
43+
} // namespace
44+
45+
template <typename Op, typename TermOp>
46+
mlir::LogicalResult CIRGenFunction::emitOpenACCComputeOp(
47+
mlir::Location start, mlir::Location end,
48+
llvm::ArrayRef<const OpenACCClause *> clauses,
49+
const Stmt *structuredBlock) {
50+
mlir::LogicalResult res = mlir::success();
51+
52+
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule());
53+
clauseEmitter.VisitClauseList(clauses);
54+
55+
llvm::SmallVector<mlir::Type> retTy;
56+
llvm::SmallVector<mlir::Value> operands;
57+
auto op = builder.create<Op>(start, retTy, operands);
58+
59+
mlir::Block &block = op.getRegion().emplaceBlock();
60+
mlir::OpBuilder::InsertionGuard guardCase(builder);
61+
builder.setInsertionPointToEnd(&block);
62+
63+
LexicalScope ls{*this, start, builder.getInsertionBlock()};
64+
res = emitStmt(structuredBlock, /*useCurrentScope=*/true);
65+
66+
builder.create<TermOp>(end);
67+
return res;
68+
}
2069

2170
mlir::LogicalResult
2271
CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
23-
getCIRGenModule().errorNYI(s.getSourceRange(), "OpenACC Compute Construct");
24-
return mlir::failure();
72+
mlir::Location start = getLoc(s.getSourceRange().getEnd());
73+
mlir::Location end = getLoc(s.getSourceRange().getEnd());
74+
75+
switch (s.getDirectiveKind()) {
76+
case OpenACCDirectiveKind::Parallel:
77+
return emitOpenACCComputeOp<ParallelOp, mlir::acc::YieldOp>(
78+
start, end, s.clauses(), s.getStructuredBlock());
79+
case OpenACCDirectiveKind::Serial:
80+
return emitOpenACCComputeOp<SerialOp, mlir::acc::YieldOp>(
81+
start, end, s.clauses(), s.getStructuredBlock());
82+
case OpenACCDirectiveKind::Kernels:
83+
return emitOpenACCComputeOp<KernelsOp, mlir::acc::TerminatorOp>(
84+
start, end, s.clauses(), s.getStructuredBlock());
85+
default:
86+
llvm_unreachable("invalid compute construct kind");
87+
}
2588
}
2689

2790
mlir::LogicalResult

clang/lib/CIR/CodeGen/CIRGenerator.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "CIRGenModule.h"
1414

15+
#include "mlir/Dialect/OpenACC/OpenACC.h"
1516
#include "mlir/IR/MLIRContext.h"
1617

1718
#include "clang/AST/DeclGroup.h"
@@ -36,6 +37,7 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
3637

3738
mlirContext = std::make_unique<mlir::MLIRContext>();
3839
mlirContext->loadDialect<cir::CIRDialect>();
40+
mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>();
3941
cgm = std::make_unique<clang::CIRGen::CIRGenModule>(
4042
*mlirContext.get(), astContext, codeGenOpts, diags);
4143
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
2+
3+
void acc_kernels(void) {
4+
// CHECK: cir.func @acc_kernels() {
5+
#pragma acc kernels
6+
{}
7+
8+
// CHECK-NEXT: acc.kernels {
9+
// CHECK-NEXT:acc.terminator
10+
// CHECK-NEXT:}
11+
12+
#pragma acc kernels
13+
while(1){}
14+
// CHECK-NEXT: acc.kernels {
15+
// CHECK-NEXT: cir.scope {
16+
// CHECK-NEXT: cir.while {
17+
// CHECK-NEXT: %[[INT:.*]] = cir.const #cir.int<1>
18+
// CHECK-NEXT: %[[CAST:.*]] = cir.cast(int_to_bool, %[[INT]] :
19+
// CHECK-NEXT: cir.condition(%[[CAST]])
20+
// CHECK-NEXT: } do {
21+
// CHECK-NEXT: cir.yield
22+
// cir.while do end:
23+
// CHECK-NEXT: }
24+
// cir.scope end:
25+
// CHECK-NEXT: }
26+
// CHECK-NEXT:acc.terminator
27+
// CHECK-NEXT:}
28+
29+
// CHECK-NEXT: cir.return
30+
}

clang/test/CIR/CodeGenOpenACC/openacc-not-implemented.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
void HelloWorld(int *A, int *B, int *C, int N) {
55

6-
// expected-error@+2{{ClangIR code gen Not Yet Implemented: OpenACC Compute Construct}}
6+
// expected-error@+2{{ClangIR code gen Not Yet Implemented: OpenACC Combined Construct}}
77
// expected-error@+1{{ClangIR code gen Not Yet Implemented: statement}}
8-
#pragma acc parallel
8+
#pragma acc parallel loop
99
for (unsigned I = 0; I < N; ++I)
1010
A[I] = B[I] + C[I];
1111

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
2+
3+
void acc_parallel(void) {
4+
// CHECK: cir.func @acc_parallel() {
5+
#pragma acc parallel
6+
{}
7+
// CHECK-NEXT: acc.parallel {
8+
// CHECK-NEXT:acc.yield
9+
// CHECK-NEXT:}
10+
11+
#pragma acc parallel
12+
while(1){}
13+
// CHECK-NEXT: acc.parallel {
14+
// CHECK-NEXT: cir.scope {
15+
// CHECK-NEXT: cir.while {
16+
// CHECK-NEXT: %[[INT:.*]] = cir.const #cir.int<1>
17+
// CHECK-NEXT: %[[CAST:.*]] = cir.cast(int_to_bool, %[[INT]] :
18+
// CHECK-NEXT: cir.condition(%[[CAST]])
19+
// CHECK-NEXT: } do {
20+
// CHECK-NEXT: cir.yield
21+
// cir.while do end:
22+
// CHECK-NEXT: }
23+
// cir.scope end:
24+
// CHECK-NEXT: }
25+
// CHECK-NEXT:acc.yield
26+
// CHECK-NEXT:}
27+
28+
// CHECK-NEXT: cir.return
29+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
2+
3+
void acc_serial(void) {
4+
// CHECK: cir.func @acc_serial() {
5+
#pragma acc serial
6+
{}
7+
8+
// CHECK-NEXT: acc.serial {
9+
// CHECK-NEXT:acc.yield
10+
// CHECK-NEXT:}
11+
12+
#pragma acc serial
13+
while(1){}
14+
// CHECK-NEXT: acc.serial {
15+
// CHECK-NEXT: cir.scope {
16+
// CHECK-NEXT: cir.while {
17+
// CHECK-NEXT: %[[INT:.*]] = cir.const #cir.int<1>
18+
// CHECK-NEXT: %[[CAST:.*]] = cir.cast(int_to_bool, %[[INT]] :
19+
// CHECK-NEXT: cir.condition(%[[CAST]])
20+
// CHECK-NEXT: } do {
21+
// CHECK-NEXT: cir.yield
22+
// cir.while do end:
23+
// CHECK-NEXT: }
24+
// cir.scope end:
25+
// CHECK-NEXT: }
26+
// CHECK-NEXT:acc.yield
27+
// CHECK-NEXT:}
28+
29+
// CHECK-NEXT: cir.return
30+
}

0 commit comments

Comments
 (0)