Skip to content

Commit 26060de

Browse files
authored
[flang][cuda] Lower device/managed/unified allocation to cuda ops (llvm#90623)
Lower the allocation of local CUDA device, managed and unified variables to fir.cuda_alloc. Add fir.cuda_free in the function context finalization. @vzakhari For some reason PR llvm#90526 was closed when I merged PR llvm#90525, so I am reopening it as a new one.
1 parent d129ea8 commit 26060de

File tree

6 files changed

+107
-11
lines changed

6 files changed

+107
-11
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,13 @@ mlir::Value createNullBoxProc(fir::FirOpBuilder &builder, mlir::Location loc,
708708

709709
/// Set internal linkage attribute on a function.
710710
void setInternalLinkage(mlir::func::FuncOp);
711+
712+
/// Return the extents from \p shape that should accompany \p type. The result
/// is empty when \p shape is empty or \p type is not a fir::SequenceType.
/// NOTE(review): presumably only the extents that are not constant in \p type
/// are kept -- confirm against the definition in FIRBuilder.cpp.
llvm::SmallVector<mlir::Value>
713+
elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape);
714+
715+
/// Return the length parameters from \p lenParams that should accompany
/// \p type. The result is empty when \p lenParams is empty.
/// NOTE(review): presumably length parameters already encoded in \p type are
/// dropped -- confirm against the definition in FIRBuilder.cpp.
llvm::SmallVector<mlir::Value>
716+
elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
717+
711718
} // namespace fir::factory
712719

713720
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

flang/include/flang/Semantics/tools.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,23 @@ inline bool HasCUDAAttr(const Symbol &sym) {
222222
return false;
223223
}
224224

225+
inline bool NeedCUDAAlloc(const Symbol &sym) {
226+
bool inDeviceSubprogram{IsCUDADeviceContext(&sym.owner())};
227+
if (const auto *details{
228+
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
229+
if (details->cudaDataAttr() &&
230+
(*details->cudaDataAttr() == common::CUDADataAttr::Device ||
231+
*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
232+
*details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
233+
// Descriptor is allocated on host when in host context.
234+
if (Fortran::semantics::IsAllocatable(sym))
235+
return inDeviceSubprogram;
236+
return true;
237+
}
238+
}
239+
return false;
240+
}
241+
225242
const Scope *FindCUDADeviceContext(const Scope *);
226243
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *);
227244

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,22 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
693693
if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
694694
return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
695695

696+
// Symbols for which NeedCUDAAlloc holds are allocated with fir.cuda_alloc
// rather than through the generic local allocation below.
if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
  fir::CUDADataAttributeAttr cudaAttr =
      Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
                                                       ultimateSymbol);
  // Filter the extents and length parameters against what `ty` already
  // carries before handing them to the allocation op.
  llvm::SmallVector<mlir::Value> elidedShape =
      fir::factory::elideExtentsAlreadyInType(ty, shape);
  llvm::SmallVector<mlir::Value> elidedLenParams =
      fir::factory::elideLengthsAlreadyInType(ty, lenParams);
  // Convert the remaining shape extents to `index`, as needed.
  llvm::SmallVector<mlir::Value> indices;
  indices.reserve(elidedShape.size());
  auto idxTy = builder.getIndexType();
  for (mlir::Value sh : elidedShape)
    indices.push_back(builder.createConvert(loc, idxTy, sh));
  // Pass the elided length parameters: the unfiltered `lenParams` was passed
  // previously, which left `elidedLenParams` computed but unused.
  return builder.create<fir::CUDAAllocOp>(loc, ty, nm, symNm, cudaAttr,
                                          elidedLenParams, indices);
}
711+
696712
// Let the builder do all the heavy lifting.
697713
if (!Fortran::semantics::IsProcedurePointer(ultimateSymbol))
698714
return builder.allocateLocal(loc, ty, nm, symNm, shape, lenParams, isTarg);
@@ -927,6 +943,19 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
927943
});
928944
}
929945
}
946+
if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
947+
auto *builder = &converter.getFirOpBuilder();
948+
mlir::Location loc = converter.getCurrentLocation();
949+
fir::ExtendedValue exv =
950+
converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
951+
auto *sym = &var.getSymbol();
952+
converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
953+
fir::CUDADataAttributeAttr cudaAttr =
954+
Fortran::lower::translateSymbolCUDADataAttribute(
955+
builder->getContext(), *sym);
956+
builder->create<fir::CUDAFreeOp>(loc, fir::getBase(exv), cudaAttr);
957+
});
958+
}
930959
}
931960

932961
//===----------------------------------------------------------------===//

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,9 @@ mlir::Value fir::FirOpBuilder::createRealConstant(mlir::Location loc,
176176
llvm_unreachable("should use builtin floating-point type");
177177
}
178178

179-
static llvm::SmallVector<mlir::Value>
180-
elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
179+
llvm::SmallVector<mlir::Value>
180+
fir::factory::elideExtentsAlreadyInType(mlir::Type type,
181+
mlir::ValueRange shape) {
181182
auto arrTy = mlir::dyn_cast<fir::SequenceType>(type);
182183
if (shape.empty() || !arrTy)
183184
return {};
@@ -191,8 +192,9 @@ elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
191192
return dynamicShape;
192193
}
193194

194-
static llvm::SmallVector<mlir::Value>
195-
elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams) {
195+
llvm::SmallVector<mlir::Value>
196+
fir::factory::elideLengthsAlreadyInType(mlir::Type type,
197+
mlir::ValueRange lenParams) {
196198
if (lenParams.empty())
197199
return {};
198200
if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(type))
@@ -211,9 +213,9 @@ mlir::Value fir::FirOpBuilder::allocateLocal(
211213
// Convert the shape extents to `index`, as needed.
212214
llvm::SmallVector<mlir::Value> indices;
213215
llvm::SmallVector<mlir::Value> elidedShape =
214-
elideExtentsAlreadyInType(ty, shape);
216+
fir::factory::elideExtentsAlreadyInType(ty, shape);
215217
llvm::SmallVector<mlir::Value> elidedLenParams =
216-
elideLengthsAlreadyInType(ty, lenParams);
218+
fir::factory::elideLengthsAlreadyInType(ty, lenParams);
217219
auto idxTy = getIndexType();
218220
for (mlir::Value sh : elidedShape)
219221
indices.push_back(createConvert(loc, idxTy, sh));
@@ -283,9 +285,9 @@ fir::FirOpBuilder::createTemporary(mlir::Location loc, mlir::Type type,
283285
mlir::ValueRange lenParams,
284286
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
285287
llvm::SmallVector<mlir::Value> dynamicShape =
286-
elideExtentsAlreadyInType(type, shape);
288+
fir::factory::elideExtentsAlreadyInType(type, shape);
287289
llvm::SmallVector<mlir::Value> dynamicLength =
288-
elideLengthsAlreadyInType(type, lenParams);
290+
fir::factory::elideLengthsAlreadyInType(type, lenParams);
289291
InsertPoint insPt;
290292
const bool hoistAlloc = dynamicShape.empty() && dynamicLength.empty();
291293
if (hoistAlloc) {
@@ -306,9 +308,9 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
306308
mlir::ValueRange shape, mlir::ValueRange lenParams,
307309
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
308310
llvm::SmallVector<mlir::Value> dynamicShape =
309-
elideExtentsAlreadyInType(type, shape);
311+
fir::factory::elideExtentsAlreadyInType(type, shape);
310312
llvm::SmallVector<mlir::Value> dynamicLength =
311-
elideLengthsAlreadyInType(type, lenParams);
313+
fir::factory::elideLengthsAlreadyInType(type, lenParams);
312314

313315
assert(!mlir::isa<fir::ReferenceType>(type) && "cannot be a reference");
314316
return create<fir::AllocMemOp>(loc, type, /*unique_name=*/llvm::StringRef{},
@@ -660,7 +662,8 @@ mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, mlir::Type boxType,
660662
mlir::Type valueOrSequenceType = fir::unwrapPassByRefType(boxType);
661663
return create<fir::EmboxOp>(
662664
loc, boxType, addr, shape, slice,
663-
elideLengthsAlreadyInType(valueOrSequenceType, lengths), tdesc);
665+
fir::factory::elideLengthsAlreadyInType(valueOrSequenceType, lengths),
666+
tdesc);
664667
}
665668

666669
void fir::FirOpBuilder::dumpFunc() { getFunction().dump(); }

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4033,6 +4033,21 @@ mlir::LogicalResult fir::CUDADeallocateOp::verify() {
40334033
return mlir::success();
40344034
}
40354035

4036+
/// Convenience builder for fir.cuda_alloc taking plain string names.
/// \p uniqName and \p bindcName are turned into string attributes (an empty
/// name gives a default-constructed mlir::StringAttr), the result type is
/// derived from \p inType with wrapAllocaResultType, and the extra
/// \p attributes are appended to \p result after delegating to the builder
/// that takes attribute arguments.
void fir::CUDAAllocOp::build(
    mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType,
    llvm::StringRef uniqName, llvm::StringRef bindcName,
    fir::CUDADataAttributeAttr cudaAttr, mlir::ValueRange typeparams,
    mlir::ValueRange shape, llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  // Empty names map to a default-constructed attribute instead of "".
  auto asOptionalStringAttr = [&builder](llvm::StringRef text) {
    if (text.empty())
      return mlir::StringAttr{};
    return builder.getStringAttr(text);
  };
  mlir::StringAttr uniqNameAttr = asOptionalStringAttr(uniqName);
  mlir::StringAttr bindcNameAttr = asOptionalStringAttr(bindcName);
  mlir::Type resultType = wrapAllocaResultType(inType);
  build(builder, result, resultType, mlir::TypeAttr::get(inType), uniqNameAttr,
        bindcNameAttr, typeparams, shape, cudaAttr);
  result.addAttributes(attributes);
}
4050+
40364051
//===----------------------------------------------------------------------===//
40374052
// FIROpsDialect
40384053
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-data-attribute.cuf

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,29 @@ end subroutine
6262
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "du", fir.cuda_attr = #fir.cuda<unified>})
6363
! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFdummy_arg_unifiedEdu"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
6464

65+
! Declares one local each with the device, unified and managed attributes:
! a constant-size array, a scalar, and an array sized by the dummy argument n.
! The FileCheck directives after this subroutine expect a fir.cuda_alloc and
! a fir.cuda_free for each of them.
subroutine cuda_alloc_free(n)
66+
integer :: n
67+
real, device :: a(10)
68+
integer, unified :: u
69+
real, managed :: b(n)
70+
end
71+
72+
! CHECK-LABEL: func.func @_QMcuda_varPcuda_alloc_free
73+
! CHECK: %[[ALLOC_A:.*]] = fir.cuda_alloc !fir.array<10xf32> {bindc_name = "a", cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
74+
! CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
75+
! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOC_A]](%[[SHAPE]]) {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
76+
77+
! CHECK: %[[ALLOC_U:.*]] = fir.cuda_alloc i32 {bindc_name = "u", cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} -> !fir.ref<i32>
78+
! CHECK: %[[DECL_U:.*]]:2 = hlfir.declare %[[ALLOC_U]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
79+
80+
! CHECK: %[[ALLOC_B:.*]] = fir.cuda_alloc !fir.array<?xf32>, %{{.*}} : index {bindc_name = "b", cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} -> !fir.ref<!fir.array<?xf32>>
81+
! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
82+
! CHECK: %[[DECL_B:.*]]:2 = hlfir.declare %[[ALLOC_B]](%[[SHAPE]]) {cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
83+
84+
! CHECK: fir.cuda_free %[[DECL_B]]#1 : !fir.ref<!fir.array<?xf32>> {cuda_attr = #fir.cuda<managed>}
85+
! CHECK: fir.cuda_free %[[DECL_U]]#1 : !fir.ref<i32> {cuda_attr = #fir.cuda<unified>}
86+
! CHECK: fir.cuda_free %[[DECL_A]]#1 : !fir.ref<!fir.array<10xf32>> {cuda_attr = #fir.cuda<device>}
87+
6588
end module
89+
90+

0 commit comments

Comments
 (0)