Skip to content

Commit a3ccaed

Browse files
authored
[flang][cuda] Allocate local descriptor in managed memory (#102060)
This patch adds entry points to the runtime for allocating and freeing descriptors in managed memory. These entry points currently only call `CUFAllocManaged` and `CUFFreeManaged`, but could become more complicated in the future. `cuf.alloc` and `cuf.free` operations on local descriptors are converted into calls to these runtime entry points.
1 parent f133dd9 commit a3ccaed

File tree

6 files changed

+186
-9
lines changed

6 files changed

+186
-9
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- include/flang/Runtime/CUDA/descriptor.h -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_RUNTIME_CUDA_DESCRIPTOR_H_
10+
#define FORTRAN_RUNTIME_CUDA_DESCRIPTOR_H_
11+
12+
#include "flang/Runtime/descriptor.h"
13+
#include "flang/Runtime/entry-names.h"
14+
#include <cstddef>
15+
16+
// Entry points for allocating and freeing descriptors. Both are declared
// with C linkage inside the extern "C" block below.
namespace Fortran::runtime::cuf {
17+
18+
extern "C" {
19+
20+
// Allocate a descriptor in managed memory. The std::size_t argument is the
// number of bytes to allocate. sourceFile and sourceLine default to nullptr
// and 0 when the caller does not provide them.
21+
Descriptor *RTDECL(CUFAllocDesciptor)(
22+
std::size_t, const char *sourceFile = nullptr, int sourceLine = 0);
23+
24+
// Deallocate a descriptor allocated in managed or unified memory.
// sourceFile and sourceLine default to nullptr and 0 when the caller does
// not provide them.
25+
void RTDECL(CUFFreeDesciptor)(
26+
Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
27+
28+
} // extern "C"
29+
} // namespace Fortran::runtime::cuf
30+
#endif // FORTRAN_RUNTIME_CUDA_DESCRIPTOR_H_

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88

99
#include "flang/Common/Fortran.h"
1010
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
11+
#include "flang/Optimizer/CodeGen/TypeConverter.h"
1112
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
1213
#include "flang/Optimizer/Dialect/FIRDialect.h"
1314
#include "flang/Optimizer/Dialect/FIROps.h"
1415
#include "flang/Optimizer/HLFIR/HLFIROps.h"
16+
#include "flang/Optimizer/Support/DataLayout.h"
17+
#include "flang/Runtime/CUDA/descriptor.h"
1518
#include "flang/Runtime/allocatable.h"
1619
#include "mlir/Pass/Pass.h"
1720
#include "mlir/Transforms/DialectConversion.h"
@@ -25,6 +28,7 @@ namespace fir {
2528
using namespace fir;
2629
using namespace mlir;
2730
using namespace Fortran::runtime;
31+
using namespace Fortran::runtime::cuf;
2832

2933
namespace {
3034

@@ -75,11 +79,11 @@ static mlir::LogicalResult convertOpToCall(OpTy op,
7579
}
7680

7781
struct CufAllocateOpConversion
78-
: public mlir::OpRewritePattern<cuf::AllocateOp> {
82+
: public mlir::OpRewritePattern<::cuf::AllocateOp> {
7983
using OpRewritePattern::OpRewritePattern;
8084

8185
mlir::LogicalResult
82-
matchAndRewrite(cuf::AllocateOp op,
86+
matchAndRewrite(::cuf::AllocateOp op,
8387
mlir::PatternRewriter &rewriter) const override {
8488
// TODO: Allocation with source will need a new entry point in the runtime.
8589
if (op.getSource())
@@ -108,16 +112,16 @@ struct CufAllocateOpConversion
108112
mlir::func::FuncOp func =
109113
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(loc,
110114
builder);
111-
return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
115+
return convertOpToCall<::cuf::AllocateOp>(op, rewriter, func);
112116
}
113117
};
114118

115119
struct CufDeallocateOpConversion
116-
: public mlir::OpRewritePattern<cuf::DeallocateOp> {
120+
: public mlir::OpRewritePattern<::cuf::DeallocateOp> {
117121
using OpRewritePattern::OpRewritePattern;
118122

119123
mlir::LogicalResult
120-
matchAndRewrite(cuf::DeallocateOp op,
124+
matchAndRewrite(::cuf::DeallocateOp op,
121125
mlir::PatternRewriter &rewriter) const override {
122126
// TODO: Allocation of module variable will need more work as the descriptor
123127
// will be duplicated and needs to be synced after allocation.
@@ -133,7 +137,84 @@ struct CufDeallocateOpConversion
133137
mlir::func::FuncOp func =
134138
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
135139
builder);
136-
return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
140+
return convertOpToCall<::cuf::DeallocateOp>(op, rewriter, func);
141+
}
142+
};
143+
144+
// Rewrites a cuf.alloc whose allocated type is a descriptor
// (fir::BaseBoxType) into a fir.call to the CUFAllocDesciptor runtime entry
// point, followed by a conversion of the call result to the type of the
// original op result. A cuf.alloc of any other type is not matched.
struct CufAllocOpConversion : public mlir::OpRewritePattern<::cuf::AllocOp> {
145+
// NOTE(review): the inherited constructors do not initialize dl or
// typeConverter, which matchAndRewrite dereferences; build this pattern
// through the three-argument constructor below.
using OpRewritePattern::OpRewritePattern;
146+
147+
// Stores the data layout and type converter used to compute the descriptor
// size. Both are kept as raw, non-owning pointers.
// NOTE(review): presumably both must outlive the pattern - confirm against
// the pass that registers it.
CufAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
148+
fir::LLVMTypeConverter *typeConverter)
149+
: OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
150+
151+
// Returns failure() when the allocated type is not a descriptor; otherwise
// replaces op with the converted result of the runtime call and returns
// success().
mlir::LogicalResult
152+
matchAndRewrite(::cuf::AllocOp op,
153+
mlir::PatternRewriter &rewriter) const override {
154+
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
155+
156+
// Only convert cuf.alloc that allocates a descriptor.
157+
if (!boxTy)
158+
return failure();
159+
160+
auto mod = op->getParentOfType<mlir::ModuleOp>();
161+
fir::FirOpBuilder builder(rewriter, mod);
162+
mlir::Location loc = op.getLoc();
163+
mlir::func::FuncOp func =
164+
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
165+
166+
// The source file and line arguments are derived from the location of the
// op being converted; the line number is created with the type of the
// runtime function's third input.
auto fTy = func.getFunctionType();
167+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
168+
mlir::Value sourceLine =
169+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
170+
171+
// The allocation size is the size of the descriptor's converted struct
// type: the data layout reports it in bits, so it is divided by 8 to get
// bytes and materialized as an index-typed constant.
mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
172+
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
173+
mlir::Value sizeInBytes =
174+
builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
175+
176+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
177+
builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
178+
auto callOp = builder.create<fir::CallOp>(loc, func, args);
179+
// Convert the call result to the result type of the original op before
// replacing it.
auto convOp = builder.createConvert(loc, op.getResult().getType(),
180+
callOp.getResult(0));
181+
rewriter.replaceOp(op, convOp);
182+
return mlir::success();
183+
}
184+
185+
private:
186+
// Data layout queried for the size in bits of the descriptor struct type.
mlir::DataLayout *dl;
187+
// Type converter used to turn the box type into its struct type.
fir::LLVMTypeConverter *typeConverter;
188+
};
189+
190+
// Rewrites a cuf.free whose operand is a reference to a descriptor
// (fir::ReferenceType with a fir::BaseBoxType element type) into a fir.call
// to the CUFFreeDesciptor runtime entry point and erases the original op.
// A cuf.free on any other operand type is not matched.
struct CufFreeOpConversion : public mlir::OpRewritePattern<::cuf::FreeOp> {
191+
using OpRewritePattern::OpRewritePattern;
192+
193+
// Returns failure() when the operand is not a reference to a descriptor;
// otherwise emits the runtime call, erases op and returns success().
mlir::LogicalResult
194+
matchAndRewrite(::cuf::FreeOp op,
195+
mlir::PatternRewriter &rewriter) const override {
196+
// Only convert cuf.free on a reference to a descriptor.
197+
if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
198+
return failure();
199+
// The isa check above already established that the operand type is a
// fir::ReferenceType, so this cast is not expected to yield null.
auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
200+
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
201+
return failure();
202+
203+
auto mod = op->getParentOfType<mlir::ModuleOp>();
204+
fir::FirOpBuilder builder(rewriter, mod);
205+
mlir::Location loc = op.getLoc();
206+
mlir::func::FuncOp func =
207+
fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
208+
209+
// The source file and line arguments are derived from the location of the
// op being converted; the line number is created with the type of the
// runtime function's third input.
auto fTy = func.getFunctionType();
210+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
211+
mlir::Value sourceLine =
212+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
213+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
214+
builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
215+
// The call result is not used; the original op is simply erased.
builder.create<fir::CallOp>(loc, func, args);
216+
rewriter.eraseOp(op);
217+
return mlir::success();
137218
}
138219
};
139220

@@ -143,8 +224,22 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
143224
auto *ctx = &getContext();
144225
mlir::RewritePatternSet patterns(ctx);
145226
mlir::ConversionTarget target(*ctx);
146-
target.addIllegalOp<cuf::AllocateOp, cuf::DeallocateOp>();
147-
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion>(ctx);
227+
228+
mlir::Operation *op = getOperation();
229+
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
230+
if (!module)
231+
return signalPassFailure();
232+
233+
std::optional<mlir::DataLayout> dl =
234+
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
235+
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
236+
/*forceUnifiedTBAATree=*/false, *dl);
237+
238+
target.addIllegalOp<::cuf::AllocOp, ::cuf::AllocateOp, ::cuf::DeallocateOp,
239+
::cuf::FreeOp>();
240+
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
241+
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
242+
CufFreeOpConversion>(ctx);
148243
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
149244
std::move(patterns)))) {
150245
mlir::emitError(mlir::UnknownLoc::get(ctx),

flang/runtime/CUDA/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ find_library(CUDA_RUNTIME_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTOR
1111

1212
add_flang_library(CufRuntime
1313
allocator.cpp
14+
descriptor.cpp
1415
)
1516
target_link_libraries(CufRuntime
1617
PRIVATE

flang/runtime/CUDA/descriptor.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===-- runtime/CUDA/descriptor.cpp ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Runtime/CUDA/descriptor.h"
10+
#include "flang/Runtime/CUDA/allocator.h"
11+
12+
// Definitions of the descriptor allocation entry points. Each one currently
// only forwards to CUFAllocManaged / CUFFreeManaged.
namespace Fortran::runtime::cuf {
13+
extern "C" {
14+
RT_EXT_API_GROUP_BEGIN
15+
16+
// Allocates sizeInBytes bytes through CUFAllocManaged and returns the result
// reinterpreted as a Descriptor pointer. sourceFile and sourceLine are
// accepted but not used by this implementation.
Descriptor *RTDEF(CUFAllocDesciptor)(
17+
std::size_t sizeInBytes, const char *sourceFile, int sourceLine) {
18+
return reinterpret_cast<Descriptor *>(CUFAllocManaged(sizeInBytes));
19+
}
20+
21+
// Releases desc by passing it, as a void pointer, to CUFFreeManaged.
// sourceFile and sourceLine are accepted but not used by this
// implementation.
void RTDEF(CUFFreeDesciptor)(
22+
Descriptor *desc, const char *sourceFile, int sourceLine) {
23+
CUFFreeManaged(reinterpret_cast<void *>(desc));
24+
}
25+
26+
RT_EXT_API_GROUP_END
27+
}
28+
} // namespace Fortran::runtime::cuf

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: fir-opt --cuf-convert %s | FileCheck %s
22

3+
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
4+
35
func.func @_QPsub1() {
46
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
57
%4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
@@ -8,14 +10,21 @@ func.func @_QPsub1() {
810
%c0_i32 = arith.constant 0 : i32
911
%9 = cuf.allocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
1012
%10 = cuf.deallocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
13+
cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
1114
return
1215
}
1316

17+
1418
// CHECK-LABEL: func.func @_QPsub1()
15-
// CHECK: %[[DESC:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
19+
// CHECK: %[[DESC_RT_CALL:.*]] = fir.call @_FortranACUFAllocDesciptor(%{{.*}}, %{{.*}}, %{{.*}}) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
20+
// CHECK: %[[DESC:.*]] = fir.convert %[[DESC_RT_CALL]] : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
1621
// CHECK: %[[DECL_DESC:.*]]:2 = hlfir.declare %[[DESC]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
1722
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
1823
// CHECK: %{{.*}} = fir.call @_FortranAAllocatableAllocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
1924

2025
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
2126
// CHECK: %{{.*}} = fir.call @_FortranAAllocatableDeallocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
27+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
28+
// CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
29+
30+
}

flang/unittests/Runtime/CUDA/AllocatorCUF.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
#include "../../../runtime/terminator.h"
1111
#include "flang/Common/Fortran.h"
1212
#include "flang/Runtime/CUDA/allocator.h"
13+
#include "flang/Runtime/CUDA/descriptor.h"
1314
#include "flang/Runtime/allocatable.h"
1415
#include "flang/Runtime/allocator-registry.h"
1516

1617
#include "cuda.h"
1718

1819
using namespace Fortran::runtime;
20+
using namespace Fortran::runtime::cuf;
1921

2022
static OwningPtr<Descriptor> createAllocatable(
2123
Fortran::common::TypeCategory tc, int kind, int rank = 1) {
@@ -87,3 +89,15 @@ TEST(AllocatableCUFTest, SimplePinnedAllocate) {
8789
(*a, /*hasStat=*/false, /*errMsg=*/nullptr, __FILE__, __LINE__);
8890
EXPECT_FALSE(a->IsAllocated());
8991
}
92+
93+
// Checks that CUFAllocDesciptor returns a non-null pointer when asked for as
// many bytes as an existing allocatable's descriptor occupies, then releases
// that pointer with CUFFreeDesciptor.
TEST(AllocatableCUFTest, DescriptorAllocationTest) {
94+
using Fortran::common::TypeCategory;
95+
Fortran::runtime::cuf::CUFRegisterAllocator();
96+
ScopedContext ctx;
97+
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
// This descriptor is only used as the source of the size to request.
98+
auto a{createAllocatable(TypeCategory::Real, 4)};
99+
Descriptor *desc = nullptr;
100+
// The source file and line arguments are left at their defaults.
desc = RTNAME(CUFAllocDesciptor)(a->SizeInBytes());
101+
EXPECT_TRUE(desc != nullptr);
102+
RTNAME(CUFFreeDesciptor)(desc);
103+
}

0 commit comments

Comments
 (0)