[OpenACC][NFCI] Implement 'helpers' for all of the clauses I've used so far #137396

Merged
merged 4 commits into from
Apr 28, 2025
203 changes: 33 additions & 170 deletions clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
// diagnostics are gone.
SourceLocation dirLoc;

const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr;
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;

void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
lastDeviceTypeValues.clear();

llvm::for_each(clause.getArchitectures(),
[this](const DeviceTypeArgument &arg) {
lastDeviceTypeValues.push_back(
decodeDeviceType(arg.getIdentifierInfo()));
});
}

void clauseNotImplemented(const OpenACCClause &c) {
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
}

// Overload of this function that only returns the device-types list.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
mlir::ValueRange argument;
mlir::MutableOperandRange range{operation};

return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range);
}
// Overload of this function for when 'segments' aren't necessary.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
mlir::ValueRange argument,
mlir::MutableOperandRange argCollection) {
llvm::SmallVector<int32_t> segments;
assert(argument.size() <= 1 &&
"Overload only for cases where segments don't need to be added");
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
argCollection, segments);
}

// Handle a clause affected by the 'device_type' to the point that they need
// to have attributes added in the correct/corresponding order, such as
// 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
// a collection of operands that need to be appended to the `argCollection` as
// we're adding a 'device_type' entry. If there are more than 0 elements in
// the 'argument', the collection must be non-null, as it is needed to add to
// it.
// As some clauses, such as 'num_gangs' or 'wait', require a 'segments' list to
// be maintained, this takes a list of segments that will be updated with the
// proper counts as 'argument' elements are added.
//
// In MLIR, the 'operands' are stored as a large array, with a separate array
// of 'segments' that show which 'operand' applies to which 'operand-kind'.
// That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
//
// So the operands array might have 4 elements, but the 'segments' array will
// be something like:
//
// {0, 0, 0, 2, 0, 1, 1, 0, 0...}
//
// Where each position belongs to a specific 'operand-kind'. So that
// specifies that whichever operand-kind corresponds with index '3' has 2
// elements, and should take the 1st 2 operands off the list (since all
// preceding values are 0). operand-kinds corresponding to 5 and 6 each have
// 1 element.
//
// Fortunately, the `MutableOperandRange` append function actually takes care
// of that for us at the 'top level'.
//
// However, in cases like 'num_gangs' or 'wait', where each individual
// 'element' might be itself array-like, there is a separate 'segments' array
// for them. So in the case of:
//
// device_type(nvidia, radeon) num_gangs(1, 2, 3)
//
// We have to emit that as TWO arrays into the IR (where the device_type is an
// attribute), so they look like:
//
// num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
// {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
//
// When stored in the 'operands' list, the top-level 'segment' for
// 'num_gangs' just shows 6 elements. In order to get the array-like
// appearance, the 'numGangsSegments' list is kept as well. In the above case,
// we've inserted 6 operands, so the 'numGangsSegments' must contain 2
// elements, 1 per array, and each will have a value of 3. The verifier will
// ensure that the collections counts are correct.
mlir::ArrayAttr
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
mlir::ValueRange argument,
mlir::MutableOperandRange argCollection,
llvm::SmallVector<int32_t> &segments) {
llvm::SmallVector<mlir::Attribute> deviceTypes;

// Collect the 'existing' device-type attributes so we can re-create them
// and insert them.
if (existingDeviceTypes) {
for (const mlir::Attribute &Attr : existingDeviceTypes)
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(),
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
}

// Insert 1 version of the 'expr' to the NumWorkers list per-current
// device type.
if (lastDeviceTypeClause) {
for (const DeviceTypeArgument &arch :
lastDeviceTypeClause->getArchitectures()) {
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
if (!argument.empty()) {
argCollection.append(argument);
segments.push_back(argument.size());
}
}
} else {
// Else, we just add a single entry for 'none'.
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None));
if (!argument.empty()) {
argCollection.append(argument);
segments.push_back(argument.size());
}
}

return mlir::ArrayAttr::get(builder.getContext(), deviceTypes);
}
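
The top-level segment layout described in the comment above can be sketched in plain C++. This is a toy model only: `operandRangeFor` is a hypothetical helper that does not exist in MLIR, and `std::vector` stands in for the operand-segment attribute. It shows how a position in the segments array maps to a slice of the flat operand list.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper (not an MLIR API): returns the half-open range
// [begin, end) of the flat operand list that belongs to the operand-kind
// stored at position 'kindIndex' of the segments array.
std::pair<std::size_t, std::size_t>
operandRangeFor(const std::vector<std::int32_t> &segments,
                std::size_t kindIndex) {
  assert(kindIndex < segments.size() && "operand-kind index out of range");

  // Every preceding operand-kind consumes 'segments[i]' operands, so the
  // start of this kind's slice is the sum of the earlier entries.
  std::size_t begin = 0;
  for (std::size_t i = 0; i < kindIndex; ++i)
    begin += static_cast<std::size_t>(segments[i]);

  return {begin, begin + static_cast<std::size_t>(segments[kindIndex])};
}
```

With the segments `{0, 0, 0, 2, 0, 1, 1, 0, 0}` from the comment, index 3 owns operands 0 and 1, index 5 owns operand 2, and index 6 owns operand 3.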

public:
OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
}

void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
lastDeviceTypeClause = &clause;
setLastDeviceTypeClause(clause);

if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
llvm::for_each(
clause.getArchitectures(), [this](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
} else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
DataOp>) {
// Nothing to do here, these constructs don't have any IR for these, as
// they just modify the other clauses IR. So setting of `lastDeviceType`
// (done above) is all we need.
// they just modify the other clauses IR. So setting of
// `lastDeviceTypeValues` (done above) is all we need.
} else {
// TODO: When we've implemented this for everything, switch this to an
// unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final

void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
mlir::MutableOperandRange range = operation.getNumWorkersMutable();
operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getNumWorkersDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), range));
operation.addNumWorkersOperand(builder.getContext(),
createIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
llvm_unreachable("num_workers not valid on serial");
} else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final

void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
mlir::MutableOperandRange range = operation.getVectorLengthMutable();
operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getVectorLengthDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), range));
operation.addVectorLengthOperand(builder.getContext(),
createIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
llvm_unreachable("vector_length not valid on serial");
} else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final

void VisitAsyncClause(const OpenACCAsyncClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
if (!clause.hasIntExpr()) {
operation.setAsyncOnlyAttr(
handleDeviceTypeAffectedClause(operation.getAsyncOnlyAttr()));
} else {
mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getAsyncOperandsDeviceTypeAttr(),
createIntExpr(clause.getIntExpr()), range));
}
if (!clause.hasIntExpr())
operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
else
operation.addAsyncOperand(builder.getContext(),
createIntExpr(clause.getIntExpr()),
lastDeviceTypeValues);
} else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
// Wait doesn't have a device_type, so its handling here is slightly
// different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
llvm::SmallVector<mlir::Value> values;

for (const Expr *E : clause.getIntExprs())
values.push_back(createIntExpr(E));

llvm::SmallVector<int32_t> segments;
if (operation.getNumGangsSegments())
llvm::copy(*operation.getNumGangsSegments(),
std::back_inserter(segments));

mlir::MutableOperandRange range = operation.getNumGangsMutable();
operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getNumGangsDeviceTypeAttr(), values, range, segments));
operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
operation.addNumGangsOperands(builder.getContext(), values,
lastDeviceTypeValues);
} else {
// TODO: When we've implemented this for everything, switch this to an
// unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
void VisitWaitClause(const OpenACCWaitClause &clause) {
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
if (!clause.hasExprs()) {
operation.setWaitOnlyAttr(
handleDeviceTypeAffectedClause(operation.getWaitOnlyAttr()));
operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
} else {
llvm::SmallVector<mlir::Value> values;

if (clause.hasDevNumExpr())
values.push_back(createIntExpr(clause.getDevNumExpr()));
for (const Expr *E : clause.getQueueIdExprs())
values.push_back(createIntExpr(E));

llvm::SmallVector<int32_t> segments;
if (operation.getWaitOperandsSegments())
llvm::copy(*operation.getWaitOperandsSegments(),
std::back_inserter(segments));

unsigned beforeSegmentSize = segments.size();

mlir::MutableOperandRange range = operation.getWaitOperandsMutable();
operation.setWaitOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
operation.getWaitOperandsDeviceTypeAttr(), values, range,
segments));
operation.setWaitOperandsSegments(segments);

// In addition to having to set the 'segments', wait also has a list of
// bool attributes whether it is annotated with 'devnum'. We can use
// our knowledge of how much the 'segments' array grew to determine how
// many we need to add.
llvm::SmallVector<bool> hasDevNums;
if (operation.getHasWaitDevnumAttr())
for (mlir::Attribute A : operation.getHasWaitDevnumAttr())
hasDevNums.push_back(cast<mlir::BoolAttr>(A).getValue());

hasDevNums.insert(hasDevNums.end(), segments.size() - beforeSegmentSize,
clause.hasDevNumExpr());

operation.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasDevNums));
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
values, lastDeviceTypeValues);
}
} else {
// TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
if (s.hasDevNumExpr())
waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));

for (Expr *QueueExpr : s.getQueueIdExprs())
for (Expr *QueueExpr : s.getQueueIdExprs())
waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
}
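
The CIRGen changes above replace the hand-rolled device-type bookkeeping with calls such as `addNumGangsOperands`. As a rough illustration of what such a call has to maintain, here is a minimal sketch in plain C++, assuming the helper keeps the behavior of the deleted `handleDeviceTypeAffectedClause`. `NumGangsModel`, `addNumGangsModel`, and `numGangsFor` are hypothetical names, the `DeviceType` enum is a cut-down stand-in for `mlir::acc::DeviceType`, and integers stand in for `mlir::Value`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Cut-down stand-in for mlir::acc::DeviceType (assumption for this sketch).
enum class DeviceType { None, Nvidia, Radeon };

// Toy model of the num_gangs state on a compute op.
struct NumGangsModel {
  std::vector<int> operands;            // flat operand list
  std::vector<std::int32_t> segments;   // operand count per group
  std::vector<DeviceType> deviceTypes;  // device type per group
};

// Appends one group of 'values' per active device type, or a single group
// tagged 'None' when no device_type clause has been seen.
void addNumGangsModel(NumGangsModel &op, const std::vector<int> &values,
                      const std::vector<DeviceType> &lastDeviceTypes) {
  assert(!values.empty() && "sketch assumes num_gangs carries a value");

  auto appendGroup = [&](DeviceType dt) {
    op.operands.insert(op.operands.end(), values.begin(), values.end());
    op.segments.push_back(static_cast<std::int32_t>(values.size()));
    op.deviceTypes.push_back(dt);
  };

  if (lastDeviceTypes.empty()) {
    appendGroup(DeviceType::None);
    return;
  }
  for (DeviceType dt : lastDeviceTypes)
    appendGroup(dt);
}

// Walks the segments to recover the values recorded for one device type.
std::vector<int> numGangsFor(const NumGangsModel &op, DeviceType dt) {
  std::ptrdiff_t offset = 0;
  for (std::size_t i = 0; i < op.segments.size(); ++i) {
    const std::ptrdiff_t count = op.segments[i];
    if (op.deviceTypes[i] == dt)
      return std::vector<int>(op.operands.begin() + offset,
                              op.operands.begin() + offset + count);
    offset += count;
  }
  return {};
}
```

For `device_type(nvidia, radeon) num_gangs(1, 2, 3)` this model ends up with six operands and two segments of size 3, matching the layout described in the removed comment.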

80 changes: 80 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1408,6 +1408,31 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
static mlir::acc::Construct getConstructId() {
return mlir::acc::Construct::acc_construct_parallel;
}
/// Add a value to 'num_workers' with the current list of device types.
void addNumWorkersOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add a value to 'vector_length' with the current list of device types.
void addVectorLengthOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'async-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add a value to the 'async' with the current list of device types.
void addAsyncOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'num_gangs' with the current list of
/// device types.
void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'wait-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'wait' with the current list of device
/// types.
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
}];

let assemblyFormat = [{
@@ -1535,6 +1560,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
static mlir::acc::Construct getConstructId() {
return mlir::acc::Construct::acc_construct_serial;
}
/// Add an entry to the 'async-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add a value to the 'async' with the current list of device types.
void addAsyncOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'wait-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'wait' with the current list of device
/// types.
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
}];

let assemblyFormat = [{
@@ -1679,6 +1719,31 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
static mlir::acc::Construct getConstructId() {
return mlir::acc::Construct::acc_construct_kernels;
}
/// Add a value to 'num_workers' with the current list of device types.
void addNumWorkersOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add a value to 'vector_length' with the current list of device types.
void addVectorLengthOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'async-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add a value to the 'async' with the current list of device types.
void addAsyncOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'num_gangs' with the current list of
/// device types.
void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'wait-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'wait' with the current list of device
/// types.
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
}];

let assemblyFormat = [{
@@ -1785,6 +1850,21 @@ def OpenACC_DataOp : OpenACC_Op<"data",
/// Return the wait devnum value clause for the given device_type if
/// present.
mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
/// Add an entry to the 'async-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add a value to the 'async' with the current list of device types.
void addAsyncOperand(MLIRContext *, mlir::Value,
llvm::ArrayRef<DeviceType>);
/// Add an entry to the 'wait-only' attribute (clause spelled without
/// arguments) for each of the additional device types (or a none if it is
/// empty).
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
/// Add an array-like entry to the 'wait' with the current list of device
/// types.
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
llvm::ArrayRef<DeviceType>);
}];

let assemblyFormat = [{
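
The bodies of the helpers declared in OpenACCOps.td are not part of the visible diff. The sketch below is therefore only a guess at what `addWaitOperands` needs to track, based on the code this PR deletes from CIRGenStmtOpenACC.cpp: one operand group, one segment entry, and one devnum flag per device type. `WaitModel` and `addWaitOperandsModel` are hypothetical names, the `DeviceType` enum is a cut-down stand-in for `mlir::acc::DeviceType`, and integers stand in for `mlir::Value`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Cut-down stand-in for mlir::acc::DeviceType (assumption for this sketch).
enum class DeviceType { None, Nvidia, Radeon };

// Toy model of the wait-clause state on a compute or data op.
struct WaitModel {
  std::vector<int> operands;            // flat operand list
  std::vector<std::int32_t> segments;   // operand count per group
  std::vector<bool> hasDevnum;          // whether each group starts with devnum
  std::vector<DeviceType> deviceTypes;  // device type per group
};

// Appends one group per active device type, or a single 'None' group when
// the list is empty. Each group records the same 'hasDevnum' flag, which
// mirrors how the deleted code grew its bool list by the number of new
// segments.
void addWaitOperandsModel(WaitModel &op, bool hasDevnum,
                          const std::vector<int> &values,
                          const std::vector<DeviceType> &lastDeviceTypes) {
  assert(!values.empty() && "the argument-less form is a separate helper");

  auto appendGroup = [&](DeviceType dt) {
    op.operands.insert(op.operands.end(), values.begin(), values.end());
    op.segments.push_back(static_cast<std::int32_t>(values.size()));
    op.hasDevnum.push_back(hasDevnum);
    op.deviceTypes.push_back(dt);
  };

  if (lastDeviceTypes.empty()) {
    appendGroup(DeviceType::None);
    return;
  }
  for (DeviceType dt : lastDeviceTypes)
    appendGroup(dt);
}
```

Keeping the three per-group lists the same length is the invariant the deleted CIRGen code maintained by hand; moving that into the op is what lets each `Visit*Clause` shrink to a single call.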