Skip to content

Commit f171e05

Browse files
authored
[NFC][OpenACC] Refactor clause emission- (#140586)
Keeping the whole clause emission in a header file was pragmatic at first, but it ended up being a sizable negative for a variety of reasons. This patch moves it to its own .cpp file and makes CIRGenFunction call into the visitor via a template instead. This is possible because the list of valid construct kinds is finite and easy to enumerate.
1 parent 5db4aea commit f171e05

File tree

5 files changed

+154
-125
lines changed

5 files changed

+154
-125
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,21 @@ class CIRGenFunction : public CIRGenTypeCache {
739739
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
740740
const Stmt *loopStmt);
741741

742+
template <typename Op>
743+
void emitOpenACCClauses(Op &op, OpenACCDirectiveKind dirKind,
744+
SourceLocation dirLoc,
745+
ArrayRef<const OpenACCClause *> clauses);
746+
// The second template argument doesn't need to be a template, since it should
747+
// always be an mlir::acc::LoopOp, but as this is a template anyway, we make
748+
// it a template argument as this way we can avoid including the OpenACC MLIR
749+
// headers here. We will count on linker failures/explicit instantiation to
750+
// ensure we don't mess this up, but it is only called from 1 place, and
751+
// instantiated 3x.
752+
template <typename ComputeOp, typename LoopOp>
753+
void emitOpenACCClauses(ComputeOp &op, LoopOp &loopOp,
754+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
755+
ArrayRef<const OpenACCClause *> clauses);
756+
742757
public:
743758
mlir::LogicalResult
744759
emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s);

clang/lib/CIR/CodeGen/CIRGenOpenACCClause.h renamed to clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp

Lines changed: 132 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,16 @@
1212

1313
#include <type_traits>
1414

15+
#include "CIRGenFunction.h"
16+
1517
#include "mlir/Dialect/Arith/IR/Arith.h"
1618
#include "mlir/Dialect/OpenACC/OpenACC.h"
1719
#include "llvm/ADT/TypeSwitch.h"
18-
namespace clang {
20+
21+
using namespace clang;
22+
using namespace clang::CIRGen;
23+
24+
namespace {
1925
// Simple type-trait to see if the first template arg is one of the list, so we
2026
// can tell whether to `if-constexpr` a bunch of stuff.
2127
template <typename ToTest, typename T, typename... Tys>
@@ -32,77 +38,10 @@ template <typename CompOpTy> struct CombinedConstructClauseInfo {
3238
ComputeOpTy computeOp;
3339
mlir::acc::LoopOp loopOp;
3440
};
35-
3641
template <typename ToTest> constexpr bool isCombinedType = false;
3742
template <typename T>
3843
constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true;
3944

40-
namespace {
41-
struct DataOperandInfo {
42-
mlir::Location beginLoc;
43-
mlir::Value varValue;
44-
llvm::StringRef name;
45-
};
46-
47-
inline mlir::Value emitOpenACCIntExpr(CIRGen::CIRGenFunction &cgf,
48-
CIRGen::CIRGenBuilderTy &builder,
49-
const Expr *intExpr) {
50-
mlir::Value expr = cgf.emitScalarExpr(intExpr);
51-
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());
52-
53-
mlir::IntegerType targetType = mlir::IntegerType::get(
54-
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
55-
intExpr->getType()->isSignedIntegerOrEnumerationType()
56-
? mlir::IntegerType::SignednessSemantics::Signed
57-
: mlir::IntegerType::SignednessSemantics::Unsigned);
58-
59-
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
60-
exprLoc, targetType, expr);
61-
return conversionOp.getResult(0);
62-
}
63-
64-
// A helper function that gets the information from an operand to a data
65-
// clause, so that it can be used to emit the data operations.
66-
inline DataOperandInfo getDataOperandInfo(CIRGen::CIRGenFunction &cgf,
67-
CIRGen::CIRGenBuilderTy &builder,
68-
OpenACCDirectiveKind dk,
69-
const Expr *e) {
70-
// TODO: OpenACC: Cache was different enough as to need a separate
71-
// `ActOnCacheVar`, so we are going to need to do some investigations here
72-
// when it comes to implement this for cache.
73-
if (dk == OpenACCDirectiveKind::Cache) {
74-
cgf.cgm.errorNYI(e->getSourceRange(),
75-
"OpenACC data operand for 'cache' directive");
76-
return {cgf.cgm.getLoc(e->getBeginLoc()), {}, {}};
77-
}
78-
79-
const Expr *curVarExpr = e->IgnoreParenImpCasts();
80-
81-
mlir::Location exprLoc = cgf.cgm.getLoc(curVarExpr->getBeginLoc());
82-
83-
// TODO: OpenACC: Assemble the list of bounds.
84-
if (isa<ArraySectionExpr, ArraySubscriptExpr>(curVarExpr)) {
85-
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
86-
"OpenACC data clause array subscript/section");
87-
return {exprLoc, {}, {}};
88-
}
89-
90-
// TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
91-
if (isa<MemberExpr>(curVarExpr)) {
92-
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
93-
"OpenACC Data clause member expr");
94-
return {exprLoc, {}, {}};
95-
}
96-
97-
// Sema has made sure that only 4 types of things can get here, array
98-
// subscript, array section, member expr, or DRE to a var decl (or the former
99-
// 3 wrapping a var-decl), so we should be able to assume this is right.
100-
const auto *dre = cast<DeclRefExpr>(curVarExpr);
101-
const auto *vd = cast<VarDecl>(dre->getFoundDecl()->getCanonicalDecl());
102-
return {exprLoc, cgf.emitDeclRefLValue(dre).getPointer(), vd->getName()};
103-
}
104-
} // namespace
105-
10645
template <typename OpTy>
10746
class OpenACCClauseCIREmitter final
10847
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
@@ -127,6 +66,10 @@ class OpenACCClauseCIREmitter final
12766
// Keep track of the data operands so that we can update their async clauses.
12867
llvm::SmallVector<mlir::Operation *> dataOperands;
12968

69+
void clauseNotImplemented(const OpenACCClause &c) {
70+
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
71+
}
72+
13073
void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
13174
lastDeviceTypeValues.clear();
13275

@@ -137,12 +80,19 @@ class OpenACCClauseCIREmitter final
13780
});
13881
}
13982

140-
void clauseNotImplemented(const OpenACCClause &c) {
141-
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
142-
}
83+
mlir::Value emitIntExpr(const Expr *intExpr) {
84+
mlir::Value expr = cgf.emitScalarExpr(intExpr);
85+
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());
14386

144-
mlir::Value emitOpenACCIntExpr(const Expr *intExpr) {
145-
return clang::emitOpenACCIntExpr(cgf, builder, intExpr);
87+
mlir::IntegerType targetType = mlir::IntegerType::get(
88+
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
89+
intExpr->getType()->isSignedIntegerOrEnumerationType()
90+
? mlir::IntegerType::SignednessSemantics::Signed
91+
: mlir::IntegerType::SignednessSemantics::Unsigned);
92+
93+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
94+
exprLoc, targetType, expr);
95+
return conversionOp.getResult(0);
14696
}
14797

14898
// 'condition' as an OpenACC grammar production is used for 'if' and (some
@@ -218,11 +168,56 @@ class OpenACCClauseCIREmitter final
218168
computeEmitter.Visit(&c);
219169
}
220170

171+
struct DataOperandInfo {
172+
mlir::Location beginLoc;
173+
mlir::Value varValue;
174+
llvm::StringRef name;
175+
};
176+
177+
// A helper function that gets the information from an operand to a data
178+
// clause, so that it can be used to emit the data operations.
179+
inline DataOperandInfo getDataOperandInfo(OpenACCDirectiveKind dk,
180+
const Expr *e) {
181+
// TODO: OpenACC: Cache was different enough as to need a separate
182+
// `ActOnCacheVar`, so we are going to need to do some investigations here
183+
// when it comes to implement this for cache.
184+
if (dk == OpenACCDirectiveKind::Cache) {
185+
cgf.cgm.errorNYI(e->getSourceRange(),
186+
"OpenACC data operand for 'cache' directive");
187+
return {cgf.cgm.getLoc(e->getBeginLoc()), {}, {}};
188+
}
189+
190+
const Expr *curVarExpr = e->IgnoreParenImpCasts();
191+
192+
mlir::Location exprLoc = cgf.cgm.getLoc(curVarExpr->getBeginLoc());
193+
194+
// TODO: OpenACC: Assemble the list of bounds.
195+
if (isa<ArraySectionExpr, ArraySubscriptExpr>(curVarExpr)) {
196+
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
197+
"OpenACC data clause array subscript/section");
198+
return {exprLoc, {}, {}};
199+
}
200+
201+
// TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
202+
if (isa<MemberExpr>(curVarExpr)) {
203+
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
204+
"OpenACC Data clause member expr");
205+
return {exprLoc, {}, {}};
206+
}
207+
208+
// Sema has made sure that only 4 types of things can get here, array
209+
// subscript, array section, member expr, or DRE to a var decl (or the
210+
// former 3 wrapping a var-decl), so we should be able to assume this is
211+
// right.
212+
const auto *dre = cast<DeclRefExpr>(curVarExpr);
213+
const auto *vd = cast<VarDecl>(dre->getFoundDecl()->getCanonicalDecl());
214+
return {exprLoc, cgf.emitDeclRefLValue(dre).getPointer(), vd->getName()};
215+
}
216+
221217
template <typename BeforeOpTy, typename AfterOpTy>
222218
void addDataOperand(const Expr *varOperand, mlir::acc::DataClause dataClause,
223219
bool structured, bool implicit) {
224-
DataOperandInfo opInfo =
225-
getDataOperandInfo(cgf, builder, dirKind, varOperand);
220+
DataOperandInfo opInfo = getDataOperandInfo(dirKind, varOperand);
226221
mlir::ValueRange bounds;
227222

228223
// TODO: OpenACC: we should comprehend the 'modifier-list' here for the data
@@ -394,7 +389,7 @@ class OpenACCClauseCIREmitter final
394389
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
395390
mlir::acc::KernelsOp>) {
396391
operation.addNumWorkersOperand(builder.getContext(),
397-
emitOpenACCIntExpr(clause.getIntExpr()),
392+
emitIntExpr(clause.getIntExpr()),
398393
lastDeviceTypeValues);
399394
} else if constexpr (isCombinedType<OpTy>) {
400395
applyToComputeOp(clause);
@@ -407,7 +402,7 @@ class OpenACCClauseCIREmitter final
407402
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
408403
mlir::acc::KernelsOp>) {
409404
operation.addVectorLengthOperand(builder.getContext(),
410-
emitOpenACCIntExpr(clause.getIntExpr()),
405+
emitIntExpr(clause.getIntExpr()),
411406
lastDeviceTypeValues);
412407
} else if constexpr (isCombinedType<OpTy>) {
413408
applyToComputeOp(clause);
@@ -432,7 +427,7 @@ class OpenACCClauseCIREmitter final
432427
mlir::OpBuilder::InsertionGuard guardCase(builder);
433428
if (!dataOperands.empty())
434429
builder.setInsertionPoint(dataOperands.front());
435-
intExpr = emitOpenACCIntExpr(clause.getIntExpr());
430+
intExpr = emitIntExpr(clause.getIntExpr());
436431
}
437432
operation.addAsyncOperand(builder.getContext(), intExpr,
438433
lastDeviceTypeValues);
@@ -444,7 +439,7 @@ class OpenACCClauseCIREmitter final
444439
operation.setAsync(true);
445440
else
446441
operation.getAsyncOperandMutable().append(
447-
emitOpenACCIntExpr(clause.getIntExpr()));
442+
emitIntExpr(clause.getIntExpr()));
448443
} else if constexpr (isCombinedType<OpTy>) {
449444
applyToComputeOp(clause);
450445
} else {
@@ -499,8 +494,7 @@ class OpenACCClauseCIREmitter final
499494
void VisitDeviceNumClause(const OpenACCDeviceNumClause &clause) {
500495
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
501496
mlir::acc::SetOp>) {
502-
operation.getDeviceNumMutable().append(
503-
emitOpenACCIntExpr(clause.getIntExpr()));
497+
operation.getDeviceNumMutable().append(emitIntExpr(clause.getIntExpr()));
504498
} else {
505499
llvm_unreachable(
506500
"init, shutdown, set, are only valid device_num constructs");
@@ -512,7 +506,7 @@ class OpenACCClauseCIREmitter final
512506
mlir::acc::KernelsOp>) {
513507
llvm::SmallVector<mlir::Value> values;
514508
for (const Expr *E : clause.getIntExprs())
515-
values.push_back(emitOpenACCIntExpr(E));
509+
values.push_back(emitIntExpr(E));
516510

517511
operation.addNumGangsOperands(builder.getContext(), values,
518512
lastDeviceTypeValues);
@@ -531,9 +525,9 @@ class OpenACCClauseCIREmitter final
531525
} else {
532526
llvm::SmallVector<mlir::Value> values;
533527
if (clause.hasDevNumExpr())
534-
values.push_back(emitOpenACCIntExpr(clause.getDevNumExpr()));
528+
values.push_back(emitIntExpr(clause.getDevNumExpr()));
535529
for (const Expr *E : clause.getQueueIdExprs())
536-
values.push_back(emitOpenACCIntExpr(E));
530+
values.push_back(emitIntExpr(E));
537531
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
538532
values, lastDeviceTypeValues);
539533
}
@@ -549,7 +543,7 @@ class OpenACCClauseCIREmitter final
549543
void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
550544
if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
551545
operation.getDefaultAsyncMutable().append(
552-
emitOpenACCIntExpr(clause.getIntExpr()));
546+
emitIntExpr(clause.getIntExpr()));
553547
} else {
554548
llvm_unreachable("set, is only valid device_num constructs");
555549
}
@@ -639,7 +633,7 @@ class OpenACCClauseCIREmitter final
639633
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
640634
if (clause.hasIntExpr())
641635
operation.addWorkerNumOperand(builder.getContext(),
642-
emitOpenACCIntExpr(clause.getIntExpr()),
636+
emitIntExpr(clause.getIntExpr()),
643637
lastDeviceTypeValues);
644638
else
645639
operation.addEmptyWorker(builder.getContext(), lastDeviceTypeValues);
@@ -657,7 +651,7 @@ class OpenACCClauseCIREmitter final
657651
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
658652
if (clause.hasIntExpr())
659653
operation.addVectorOperand(builder.getContext(),
660-
emitOpenACCIntExpr(clause.getIntExpr()),
654+
emitIntExpr(clause.getIntExpr()),
661655
lastDeviceTypeValues);
662656
else
663657
operation.addEmptyVector(builder.getContext(), lastDeviceTypeValues);
@@ -693,7 +687,7 @@ class OpenACCClauseCIREmitter final
693687
} else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
694688
values.push_back(createConstantInt(exprLoc, 64, -1));
695689
} else {
696-
values.push_back(emitOpenACCIntExpr(expr));
690+
values.push_back(emitIntExpr(expr));
697691
}
698692
}
699693

@@ -728,5 +722,54 @@ auto makeClauseEmitter(OpTy &op, CIRGen::CIRGenFunction &cgf,
728722
OpenACCDirectiveKind dirKind, SourceLocation dirLoc) {
729723
return OpenACCClauseCIREmitter<OpTy>(op, cgf, builder, dirKind, dirLoc);
730724
}
725+
} // namespace
726+
727+
template <typename Op>
728+
void CIRGenFunction::emitOpenACCClauses(
729+
Op &op, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
730+
ArrayRef<const OpenACCClause *> clauses) {
731+
mlir::OpBuilder::InsertionGuard guardCase(builder);
732+
733+
// Sets insertion point before the 'op', since every new expression needs to
734+
// be before the operation.
735+
builder.setInsertionPoint(op);
736+
makeClauseEmitter(op, *this, builder, dirKind, dirLoc).emitClauses(clauses);
737+
}
738+
739+
#define EXPL_SPEC(N) \
740+
template void CIRGenFunction::emitOpenACCClauses<N>( \
741+
N &, OpenACCDirectiveKind, SourceLocation, \
742+
ArrayRef<const OpenACCClause *>);
743+
EXPL_SPEC(mlir::acc::ParallelOp)
744+
EXPL_SPEC(mlir::acc::SerialOp)
745+
EXPL_SPEC(mlir::acc::KernelsOp)
746+
EXPL_SPEC(mlir::acc::LoopOp)
747+
EXPL_SPEC(mlir::acc::DataOp)
748+
EXPL_SPEC(mlir::acc::InitOp)
749+
EXPL_SPEC(mlir::acc::ShutdownOp)
750+
EXPL_SPEC(mlir::acc::SetOp)
751+
EXPL_SPEC(mlir::acc::WaitOp)
752+
#undef EXPL_SPEC
753+
754+
template <typename ComputeOp, typename LoopOp>
755+
void CIRGenFunction::emitOpenACCClauses(
756+
ComputeOp &op, LoopOp &loopOp, OpenACCDirectiveKind dirKind,
757+
SourceLocation dirLoc, ArrayRef<const OpenACCClause *> clauses) {
758+
static_assert(std::is_same_v<mlir::acc::LoopOp, LoopOp>);
759+
760+
CombinedConstructClauseInfo<ComputeOp> inf{op, loopOp};
761+
// We cannot set the insertion point here and do so in the emitter, but make
762+
// sure we reset it with the 'guard' anyway.
763+
mlir::OpBuilder::InsertionGuard guardCase(builder);
764+
makeClauseEmitter(inf, *this, builder, dirKind, dirLoc).emitClauses(clauses);
765+
}
766+
767+
#define EXPL_SPEC(N) \
768+
template void CIRGenFunction::emitOpenACCClauses<N, mlir::acc::LoopOp>( \
769+
N &, mlir::acc::LoopOp &, OpenACCDirectiveKind, SourceLocation, \
770+
ArrayRef<const OpenACCClause *>);
731771

732-
} // namespace clang
772+
EXPL_SPEC(mlir::acc::ParallelOp)
773+
EXPL_SPEC(mlir::acc::SerialOp)
774+
EXPL_SPEC(mlir::acc::KernelsOp)
775+
#undef EXPL_SPEC

0 commit comments

Comments
 (0)