Skip to content

Commit d21580c

Browse files
committed
[MLIR][OpenMP]Add Flang lowering support for device_ptr and device_addr clauses
Add lowering support for the use_device_ptr and use_Device_addr clauses for the Target Data directive. Depends on D152822 Differential Revision: https://reviews.llvm.org/D152824
1 parent 0657ae3 commit d21580c

File tree

2 files changed

+131
-22
lines changed

2 files changed

+131
-22
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,48 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
723723
}
724724
}
725725

726+
static void createBodyOfTargetOp(
727+
Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
728+
const llvm::SmallVector<mlir::Type> &useDeviceTypes,
729+
const llvm::SmallVector<mlir::Location> &useDeviceLocs,
730+
const SmallVector<const Fortran::semantics::Symbol *> &useDeviceSymbols,
731+
const mlir::Location &currentLocation) {
732+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
733+
mlir::Region &region = dataOp.getRegion();
734+
735+
firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
736+
firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
737+
firOpBuilder.setInsertionPointToStart(&region.front());
738+
739+
unsigned argIndex = 0;
740+
for (auto *sym : useDeviceSymbols) {
741+
const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
742+
mlir::Value val = fir::getBase(arg);
743+
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
744+
if (auto refType = val.getType().dyn_cast<fir::ReferenceType>()) {
745+
if (fir::isa_builtin_cptr_type(refType.getElementType())) {
746+
converter.bindSymbol(*sym, val);
747+
} else {
748+
extVal.match(
749+
[&](const fir::MutableBoxValue &mbv) {
750+
converter.bindSymbol(
751+
*sym,
752+
fir::MutableBoxValue(
753+
val, fir::factory::getNonDeferredLenParams(extVal), {}));
754+
},
755+
[&](const auto &) {
756+
TODO(converter.getCurrentLocation(),
757+
"use_device clause operand unsupported type");
758+
});
759+
}
760+
} else {
761+
TODO(converter.getCurrentLocation(),
762+
"use_device clause operand unsupported type");
763+
}
764+
argIndex++;
765+
}
766+
}
767+
726768
static void createTargetOp(Fortran::lower::AbstractConverter &converter,
727769
const Fortran::parser::OmpClauseList &opClauseList,
728770
const llvm::omp::Directive &directive,
@@ -732,13 +774,24 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
732774

733775
mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
734776
mlir::UnitAttr nowaitAttr;
735-
llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
736-
mapOperands;
777+
llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
778+
deviceAddrOperands;
737779
llvm::SmallVector<mlir::IntegerAttr> mapTypes;
780+
llvm::SmallVector<mlir::Type> useDeviceTypes;
781+
llvm::SmallVector<mlir::Location> useDeviceLocs;
782+
SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
783+
784+
/// Check for unsupported map operand types.
785+
auto checkType = [](auto currentLocation, mlir::Type type) {
786+
if (auto refType = type.dyn_cast<fir::ReferenceType>())
787+
type = refType.getElementType();
788+
if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
789+
if (!boxType.getElementType().isa<fir::PointerType>())
790+
TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
791+
};
738792

739-
auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
740-
&mapTypes](const auto &mapClause,
741-
mlir::Location &currentLocation) {
793+
auto addMapClause = [&](const auto &mapClause,
794+
mlir::Location &currentLocation) {
742795
auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
743796
std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
744797
->t);
@@ -793,18 +846,25 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
793846
converter, mapOperand);
794847

795848
for (mlir::Value mapOp : mapOperand) {
796-
/// Check for unsupported map operand types.
797-
mlir::Type checkType = mapOp.getType();
798-
if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
799-
checkType = refType.getElementType();
800-
if (checkType.isa<fir::BoxType>())
801-
TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
802-
849+
checkType(mapOp.getLoc(), mapOp.getType());
803850
mapOperands.push_back(mapOp);
804851
mapTypes.push_back(mapTypeAttr);
805852
}
806853
};
807854

855+
auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
856+
genObjectList(useDeviceClause, converter, operands);
857+
for (auto &operand : operands) {
858+
checkType(operand.getLoc(), operand.getType());
859+
useDeviceTypes.push_back(operand.getType());
860+
useDeviceLocs.push_back(operand.getLoc());
861+
}
862+
for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
863+
Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
864+
useDeviceSymbols.push_back(sym);
865+
}
866+
};
867+
808868
for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
809869
mlir::Location currentLocation = converter.genLocation(clause.source);
810870
if (const auto &ifClause =
@@ -825,19 +885,21 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
825885
deviceOperand =
826886
fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
827887
}
828-
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
829-
&clause.u)) {
830-
TODO(currentLocation, "OMPD_target Use Device Ptr");
831-
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
832-
&clause.u)) {
833-
TODO(currentLocation, "OMPD_target Use Device Addr");
834888
} else if (const auto &threadLmtClause =
835889
std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
836890
&clause.u)) {
837891
threadLmtOperand = fir::getBase(converter.genExprValue(
838892
*Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
839893
} else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
840894
nowaitAttr = firOpBuilder.getUnitAttr();
895+
} else if (const auto &devPtrClause =
896+
std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
897+
&clause.u)) {
898+
addUseDeviceClause(devPtrClause->v, devicePtrOperands);
899+
} else if (const auto &devAddrClause =
900+
std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
901+
&clause.u)) {
902+
addUseDeviceClause(devAddrClause->v, deviceAddrOperands);
841903
} else if (const auto &mapClause =
842904
std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
843905
addMapClause(mapClause, currentLocation);
@@ -859,9 +921,10 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
859921
createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
860922
} else if (directive == llvm::omp::Directive::OMPD_target_data) {
861923
auto dataOp = firOpBuilder.create<omp::DataOp>(
862-
currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
863-
useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
864-
createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
924+
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
925+
deviceAddrOperands, mapOperands, mapTypesArrayAttr);
926+
createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
927+
useDeviceSymbols, currentLocation);
865928
} else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
866929
firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
867930
deviceOperand, nowaitAttr,
@@ -1157,7 +1220,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
11571220
continue;
11581221
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
11591222
// Map clause is exclusive to Target Data directives. It is handled
1160-
// as part of the DataOp creation.
1223+
// as part of the TargetOp creation.
1224+
continue;
1225+
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
1226+
&clause.u)) {
1227+
// UseDevicePtr clause is exclusive to Target Data directives. It is
1228+
// handled as part of the TargetOp creation.
1229+
continue;
1230+
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
1231+
&clause.u)) {
1232+
// UseDeviceAddr clause is exclusive to Target Data directives. It is
1233+
// handled as part of the TargetOp creation.
11611234
continue;
11621235
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
11631236
&clause.u)) {

flang/test/Lower/OpenMP/target.f90

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,39 @@ subroutine omp_target_thread_limit
162162
!$omp end target
163163
!CHECK: }
164164
end subroutine omp_target_thread_limit
165+
166+
!===============================================================================
167+
! Target `use_device_ptr` clause
168+
!===============================================================================
169+
170+
!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
171+
subroutine omp_target_device_ptr
172+
use iso_c_binding, only : c_ptr, c_loc
173+
type(c_ptr) :: a
174+
integer, target :: b
175+
!CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)) use_device_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
176+
!$omp target data map(tofrom: a) use_device_ptr(a)
177+
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
178+
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
179+
a = c_loc(b)
180+
!CHECK: omp.terminator
181+
!$omp end target data
182+
!CHECK: }
183+
end subroutine omp_target_device_ptr
184+
185+
!===============================================================================
186+
! Target `use_device_addr` clause
187+
!===============================================================================
188+
189+
!CHECK-LABEL: func.func @_QPomp_target_device_addr() {
190+
subroutine omp_target_device_addr
191+
integer, pointer :: a
192+
!CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
193+
!$omp target data map(tofrom: a) use_device_addr(a)
194+
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
195+
!CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
196+
a = 10
197+
!CHECK: omp.terminator
198+
!$omp end target data
199+
!CHECK: }
200+
end subroutine omp_target_device_addr

0 commit comments

Comments
 (0)