Skip to content

Commit b45ea44

Browse files
committed
[flang][openacc] Lower update directive
This patch upstream the lowering of Update directive that was initially done in flang-compiler#528 Reviewed By: schweitz Differential Revision: https://reviews.llvm.org/D90472
1 parent e6cd3ef commit b45ea44

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,112 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
725725
exitDataOp.finalizeAttr(firOpBuilder.getUnitAttr());
726726
}
727727

728+
static void
729+
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
730+
const Fortran::parser::AccClauseList &accClauseList) {
731+
mlir::Value ifCond, async, waitDevnum;
732+
SmallVector<Value, 2> hostOperands, deviceOperands, waitOperands,
733+
deviceTypeOperands;
734+
735+
// Async and wait clause have optional values but can be present with
736+
// no value as well. When there is no value, the op has an attribute to
737+
// represent the clause.
738+
bool addAsyncAttr = false;
739+
bool addWaitAttr = false;
740+
bool addIfPresentAttr = false;
741+
742+
auto &firOpBuilder = converter.getFirOpBuilder();
743+
auto currentLocation = converter.getCurrentLocation();
744+
745+
// Lower clauses values mapped to operands.
746+
// Keep track of each group of operands separatly as clauses can appear
747+
// more than once.
748+
for (const auto &clause : accClauseList.v) {
749+
if (const auto *ifClause =
750+
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
751+
mlir::Value cond = fir::getBase(
752+
converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v)));
753+
ifCond = firOpBuilder.createConvert(currentLocation,
754+
firOpBuilder.getI1Type(), cond);
755+
} else if (const auto *asyncClause =
756+
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
757+
const auto &asyncClauseValue = asyncClause->v;
758+
if (asyncClauseValue) { // async has a value.
759+
async = fir::getBase(converter.genExprValue(
760+
*Fortran::semantics::GetExpr(*asyncClauseValue)));
761+
} else {
762+
addAsyncAttr = true;
763+
}
764+
} else if (const auto *waitClause =
765+
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
766+
const auto &waitClauseValue = waitClause->v;
767+
if (waitClauseValue) { // wait has a value.
768+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
769+
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
770+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
771+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
772+
mlir::Value v = fir::getBase(
773+
converter.genExprValue(*Fortran::semantics::GetExpr(value)));
774+
waitOperands.push_back(v);
775+
}
776+
777+
const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
778+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
779+
if (waitDevnumValue)
780+
waitDevnum = fir::getBase(converter.genExprValue(
781+
*Fortran::semantics::GetExpr(*waitDevnumValue)));
782+
} else {
783+
addWaitAttr = true;
784+
}
785+
} else if (const auto *deviceTypeClause =
786+
std::get_if<Fortran::parser::AccClause::DeviceType>(
787+
&clause.u)) {
788+
789+
const auto &deviceTypeValue = deviceTypeClause->v;
790+
if (deviceTypeValue) {
791+
for (const auto &scalarIntExpr : *deviceTypeValue) {
792+
mlir::Value expr = fir::getBase(converter.genExprValue(
793+
*Fortran::semantics::GetExpr(scalarIntExpr)));
794+
deviceTypeOperands.push_back(expr);
795+
}
796+
} else {
797+
// * was passed as value and will be represented as a -1 constant
798+
// integer.
799+
mlir::Value star = firOpBuilder.createIntegerConstant(
800+
currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1);
801+
deviceTypeOperands.push_back(star);
802+
}
803+
} else if (const auto *hostClause =
804+
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
805+
genObjectList(hostClause->v, converter, hostOperands);
806+
} else if (const auto *deviceClause =
807+
std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
808+
genObjectList(deviceClause->v, converter, deviceOperands);
809+
}
810+
}
811+
812+
// Prepare the operand segement size attribute and the operands value range.
813+
SmallVector<mlir::Value, 14> operands;
814+
SmallVector<int32_t, 7> operandSegments;
815+
addOperand(operands, operandSegments, async);
816+
addOperand(operands, operandSegments, waitDevnum);
817+
addOperands(operands, operandSegments, waitOperands);
818+
addOperands(operands, operandSegments, deviceTypeOperands);
819+
addOperand(operands, operandSegments, ifCond);
820+
addOperands(operands, operandSegments, hostOperands);
821+
addOperands(operands, operandSegments, deviceOperands);
822+
823+
auto updateOp = createSimpleOp<mlir::acc::UpdateOp>(
824+
firOpBuilder, currentLocation, operands, operandSegments);
825+
826+
if (addAsyncAttr)
827+
updateOp.asyncAttr(firOpBuilder.getUnitAttr());
828+
if (addWaitAttr)
829+
updateOp.waitAttr(firOpBuilder.getUnitAttr());
830+
if (addIfPresentAttr)
831+
updateOp.ifPresentAttr(firOpBuilder.getUnitAttr());
832+
}
833+
728834
static void
729835
genACC(Fortran::lower::AbstractConverter &converter,
730836
Fortran::lower::pft::Evaluation &eval,
@@ -745,7 +851,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
745851
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
746852
TODO("OpenACC set directive not lowered yet!");
747853
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
748-
TODO("OpenACC update directive not lowered yet!");
854+
genACCUpdateOp(converter, accClauseList);
749855
}
750856
}
751857

0 commit comments

Comments
 (0)