-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[OpenACC][CIR] Implement 'device_type' clause lowering for 'init'/'sh… #135102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
// Emit OpenACC Stmt nodes as CIR code. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#include <type_traits> | ||
|
||
#include "CIRGenBuilder.h" | ||
#include "CIRGenFunction.h" | ||
|
@@ -23,22 +24,39 @@ using namespace cir; | |
using namespace mlir::acc; | ||
|
||
namespace { | ||
// Simple type-trait to see if the first template arg is one of the list, so we | ||
// can tell whether to `if-constexpr` a bunch of stuff. | ||
template <typename ToTest, typename T, typename... Tys> | ||
constexpr bool isOneOfTypes = | ||
std::is_same_v<ToTest, T> || isOneOfTypes<ToTest, Tys...>; | ||
template <typename ToTest, typename T> | ||
constexpr bool isOneOfTypes<ToTest, T> = std::is_same_v<ToTest, T>; | ||
|
||
class OpenACCClauseCIREmitter final | ||
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter> { | ||
CIRGenModule &cgm; | ||
// This is necessary since a few of the clauses emit differently based on the | ||
// directive kind they are attached to. | ||
OpenACCDirectiveKind dirKind; | ||
SourceLocation dirLoc; | ||
|
||
struct AttributeData { | ||
// Value of the 'default' attribute, added on 'data' and 'compute'/etc | ||
// constructs as a 'default-attr'. | ||
std::optional<ClauseDefaultValue> defaultVal = std::nullopt; | ||
// For directives that have their device type architectures listed in | ||
// attributes (init/shutdown/etc), the list of architectures to be emitted. | ||
llvm::SmallVector<mlir::acc::DeviceType> deviceTypeArchs{}; | ||
} attrData; | ||
|
||
void clauseNotImplemented(const OpenACCClause &c) { | ||
cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind()); | ||
} | ||
|
||
public: | ||
OpenACCClauseCIREmitter(CIRGenModule &cgm) : cgm(cgm) {} | ||
OpenACCClauseCIREmitter(CIRGenModule &cgm, OpenACCDirectiveKind dirKind, | ||
SourceLocation dirLoc) | ||
: cgm(cgm), dirKind(dirKind), dirLoc(dirLoc) {} | ||
|
||
void VisitClause(const OpenACCClause &clause) { | ||
clauseNotImplemented(clause); | ||
|
@@ -57,31 +75,92 @@ class OpenACCClauseCIREmitter final | |
} | ||
} | ||
|
||
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) { | ||
// '*' case leaves no identifier-info, just a nullptr. | ||
if (!ii) | ||
return mlir::acc::DeviceType::Star; | ||
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName()) | ||
.CaseLower("default", mlir::acc::DeviceType::Default) | ||
.CaseLower("host", mlir::acc::DeviceType::Host) | ||
.CaseLower("multicore", mlir::acc::DeviceType::Multicore) | ||
.CasesLower("nvidia", "acc_device_nvidia", | ||
mlir::acc::DeviceType::Nvidia) | ||
.CaseLower("radeon", mlir::acc::DeviceType::Radeon); | ||
} | ||
|
||
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) { | ||
|
||
switch (dirKind) { | ||
case OpenACCDirectiveKind::Init: | ||
case OpenACCDirectiveKind::Shutdown: { | ||
// Device type has a list that is either a 'star' (emitted as 'star'), | ||
// or an identifer list, all of which get added for attributes. | ||
|
||
for (const DeviceTypeArgument &arg : clause.getArchitectures()) | ||
attrData.deviceTypeArchs.push_back(decodeDeviceType(arg.first)); | ||
break; | ||
} | ||
default: | ||
return clauseNotImplemented(clause); | ||
} | ||
} | ||
|
||
// Apply any of the clauses that resulted in an 'attribute'. | ||
template <typename Op> void applyAttributes(Op &op) { | ||
if (attrData.defaultVal.has_value()) | ||
op.setDefaultAttr(*attrData.defaultVal); | ||
template <typename Op> | ||
void applyAttributes(CIRGenBuilderTy &builder, Op &op) { | ||
|
||
if (attrData.defaultVal.has_value()) { | ||
// FIXME: OpenACC: as we implement this for other directive kinds, we have | ||
// to expand this list. | ||
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc | ||
// operations listed in the rest of the arguments. | ||
if constexpr (isOneOfTypes<Op, ParallelOp, SerialOp, KernelsOp, DataOp>) | ||
op.setDefaultAttr(*attrData.defaultVal); | ||
else | ||
cgm.errorNYI(dirLoc, "OpenACC 'default' clause lowering for ", dirKind); | ||
} | ||
|
||
if (!attrData.deviceTypeArchs.empty()) { | ||
// FIXME: OpenACC: as we implement this for other directive kinds, we have | ||
// to expand this list, or more likely, have a 'noop' branch as most other | ||
// uses of this apply to the operands instead. | ||
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc | ||
if constexpr (isOneOfTypes<Op, InitOp, ShutdownOp>) { | ||
llvm::SmallVector<mlir::Attribute> deviceTypes; | ||
for (mlir::acc::DeviceType DT : attrData.deviceTypeArchs) | ||
deviceTypes.push_back( | ||
mlir::acc::DeviceTypeAttr::get(builder.getContext(), DT)); | ||
|
||
op.setDeviceTypesAttr( | ||
mlir::ArrayAttr::get(builder.getContext(), deviceTypes)); | ||
} else { | ||
cgm.errorNYI(dirLoc, "OpenACC 'device_type' clause lowering for ", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this really "NYI"? It feels more like we shouldn't get here. In either case, since the condition is static, can this be a static compile error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is, there are entries for Device_type that aren't implemented yet, the list on 126 is incomplete. Eventually, this will become a no-op/unreachable. It can't be a |
||
dirKind); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
template <typename Op, typename TermOp> | ||
mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( | ||
mlir::Location start, mlir::Location end, | ||
llvm::ArrayRef<const OpenACCClause *> clauses, const Stmt *associatedStmt) { | ||
mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind, | ||
SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses, | ||
const Stmt *associatedStmt) { | ||
mlir::LogicalResult res = mlir::success(); | ||
|
||
llvm::SmallVector<mlir::Type> retTy; | ||
llvm::SmallVector<mlir::Value> operands; | ||
|
||
// Clause-emitter must be here because it might modify operands. | ||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule()); | ||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc); | ||
clauseEmitter.VisitClauseList(clauses); | ||
|
||
auto op = builder.create<Op>(start, retTy, operands); | ||
|
||
// Apply the attributes derived from the clauses. | ||
clauseEmitter.applyAttributes(op); | ||
clauseEmitter.applyAttributes(builder, op); | ||
|
||
mlir::Block &block = op.getRegion().emplaceBlock(); | ||
mlir::OpBuilder::InsertionGuard guardCase(builder); | ||
|
@@ -95,19 +174,21 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( | |
} | ||
|
||
template <typename Op> | ||
mlir::LogicalResult | ||
CIRGenFunction::emitOpenACCOp(mlir::Location start, | ||
llvm::ArrayRef<const OpenACCClause *> clauses) { | ||
mlir::LogicalResult CIRGenFunction::emitOpenACCOp( | ||
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc, | ||
llvm::ArrayRef<const OpenACCClause *> clauses) { | ||
mlir::LogicalResult res = mlir::success(); | ||
|
||
llvm::SmallVector<mlir::Type> retTy; | ||
llvm::SmallVector<mlir::Value> operands; | ||
|
||
// Clause-emitter must be here because it might modify operands. | ||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule()); | ||
OpenACCClauseCIREmitter clauseEmitter(getCIRGenModule(), dirKind, dirLoc); | ||
clauseEmitter.VisitClauseList(clauses); | ||
|
||
builder.create<Op>(start, retTy, operands); | ||
auto op = builder.create<Op>(start, retTy, operands); | ||
// Apply the attributes derived from the clauses. | ||
clauseEmitter.applyAttributes(builder, op); | ||
return res; | ||
} | ||
|
||
|
@@ -119,13 +200,16 @@ CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) { | |
switch (s.getDirectiveKind()) { | ||
case OpenACCDirectiveKind::Parallel: | ||
return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>( | ||
start, end, s.clauses(), s.getStructuredBlock()); | ||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), | ||
s.getStructuredBlock()); | ||
case OpenACCDirectiveKind::Serial: | ||
return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>( | ||
start, end, s.clauses(), s.getStructuredBlock()); | ||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), | ||
s.getStructuredBlock()); | ||
case OpenACCDirectiveKind::Kernels: | ||
return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>( | ||
start, end, s.clauses(), s.getStructuredBlock()); | ||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), | ||
s.getStructuredBlock()); | ||
default: | ||
llvm_unreachable("invalid compute construct kind"); | ||
} | ||
|
@@ -137,18 +221,22 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) { | |
mlir::Location end = getLoc(s.getSourceRange().getEnd()); | ||
|
||
return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>( | ||
start, end, s.clauses(), s.getStructuredBlock()); | ||
start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), | ||
s.getStructuredBlock()); | ||
} | ||
|
||
mlir::LogicalResult | ||
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) { | ||
mlir::Location start = getLoc(s.getSourceRange().getEnd()); | ||
return emitOpenACCOp<InitOp>(start, s.clauses()); | ||
return emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), | ||
s.clauses()); | ||
} | ||
|
||
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct( | ||
const OpenACCShutdownConstruct &s) { | ||
mlir::Location start = getLoc(s.getSourceRange().getEnd()); | ||
return emitOpenACCOp<ShutdownOp>(start, s.clauses()); | ||
return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(), | ||
s.getDirectiveLoc(), s.clauses()); | ||
} | ||
|
||
mlir::LogicalResult | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,17 @@ void acc_init(void) { | |
// CHECK: cir.func @acc_init() { | ||
#pragma acc init | ||
// CHECK-NEXT: acc.init loc(#{{[a-zA-Z0-9]+}}){{$}} | ||
|
||
#pragma acc init device_type(*) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if you have this?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will faithfully represent them, like this:
There isn't really anything that prohibits it by standard, so it seems reasonable to do. |
||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<star>]} | ||
#pragma acc init device_type(nvidia) | ||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]} | ||
#pragma acc init device_type(host, multicore) | ||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]} | ||
#pragma acc init device_type(NVIDIA) | ||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<nvidia>]} | ||
#pragma acc init device_type(HoSt, MuLtIcORe) | ||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]} | ||
#pragma acc init device_type(HoSt) device_type(MuLtIcORe) | ||
// CHECK-NEXT: acc.init attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I right that
ParallelOp
isacc::ParallelOp
?I was confused about what was happening here because I didn't read the comment where
isOneOfType
was defined. A brief comment here explaining what you're looking for ('Is Op one of the expected types?") would be helpful.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep,
ParallelOp
ismlir::acc::ParallelOp
. We have ausing namespace mlir::acc
above. I'll add the comment.