Skip to content

Commit a19919f

Browse files
authored
[flang][cuda] Add cuf.device_address operation (#122975)
Introduce a new op to get the device address from a host symbol. This simplify the current conversion and this is also in preparation for some legalization work that need to be done in cuf kernel and cuf kernel launch similar to #122802
1 parent ebef440 commit a19919f

File tree

5 files changed

+71
-27
lines changed

5 files changed

+71
-27
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,4 +335,16 @@ def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
335335
}];
336336
}
337337

338+
def cuf_DeviceAddressOp : cuf_Op<"device_address", []> {
339+
let summary = "Get the device address from a host symbol";
340+
341+
let arguments = (ins SymbolRefAttr:$hostSymbol);
342+
343+
let assemblyFormat = [{
344+
$hostSymbol attr-dict `->` type($addr)
345+
}];
346+
347+
let results = (outs fir_ReferenceType:$addr);
348+
}
349+
338350
#endif // FORTRAN_DIALECT_CUF_CUF_OPS

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366366
const fir::LLVMTypeConverter *typeConverter;
367367
};
368368

369-
static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
370-
mlir::ModuleOp mod, mlir::Location loc,
371-
mlir::Value inputArg) {
372-
fir::FirOpBuilder builder(rewriter, mod);
373-
mlir::func::FuncOp callee =
374-
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
375-
auto fTy = callee.getFunctionType();
376-
mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
377-
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
378-
mlir::Value sourceLine =
379-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
380-
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
381-
builder, loc, fTy, conv, sourceFile, sourceLine)};
382-
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
383-
return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
384-
}
369+
struct CUFDeviceAddressOpConversion
370+
: public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
371+
using OpRewritePattern::OpRewritePattern;
372+
373+
CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
374+
const mlir::SymbolTable &symtab)
375+
: OpRewritePattern(context), symTab{symtab} {}
376+
377+
mlir::LogicalResult
378+
matchAndRewrite(cuf::DeviceAddressOp op,
379+
mlir::PatternRewriter &rewriter) const override {
380+
if (auto global = symTab.lookup<fir::GlobalOp>(
381+
op.getHostSymbol().getRootReference().getValue())) {
382+
auto mod = op->getParentOfType<mlir::ModuleOp>();
383+
mlir::Location loc = op.getLoc();
384+
auto hostAddr = rewriter.create<fir::AddrOfOp>(
385+
loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
386+
fir::FirOpBuilder builder(rewriter, mod);
387+
mlir::func::FuncOp callee =
388+
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
389+
builder);
390+
auto fTy = callee.getFunctionType();
391+
mlir::Value conv =
392+
createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
393+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
394+
mlir::Value sourceLine =
395+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
396+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
397+
builder, loc, fTy, conv, sourceFile, sourceLine)};
398+
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
399+
mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
400+
call->getResult(0));
401+
rewriter.replaceOp(op, addr.getDefiningOp());
402+
return success();
403+
}
404+
return failure();
405+
}
406+
407+
private:
408+
const mlir::SymbolTable &symTab;
409+
};
385410

386411
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
387412
using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
398423
addrOfOp.getSymbol().getRootReference().getValue())) {
399424
if (cuf::isRegisteredDeviceGlobal(global)) {
400425
rewriter.setInsertionPointAfter(addrOfOp);
401-
auto mod = op->getParentOfType<mlir::ModuleOp>();
402-
mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
403-
addrOfOp.getResult());
426+
mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
427+
op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
404428
rewriter.startOpModification(op);
405429
op.getMemrefMutable().assign(devAddr);
406430
rewriter.finalizeOpModification(op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
773797
}
774798
}
775799
llvm::SmallVector<mlir::Value> args;
776-
auto mod = op->getParentOfType<mlir::ModuleOp>();
777800
for (mlir::Value arg : op.getArgs()) {
778801
// If the argument is a global descriptor, make sure we pass the device
779802
// copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
785808
if (auto global = symTab.lookup<fir::GlobalOp>(
786809
addrOfOp.getSymbol().getRootReference().getValue())) {
787810
if (cuf::isRegisteredDeviceGlobal(global)) {
788-
arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
789-
declareOp.getResult());
811+
arg = rewriter
812+
.create<cuf::DeviceAddressOp>(op.getLoc(),
813+
addrOfOp.getType(),
814+
addrOfOp.getSymbol())
815+
.getResult();
790816
}
791817
}
792818
}
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
907933
patterns.getContext());
908934
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
909935
&dl, &converter);
910-
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
936+
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
937+
patterns.getContext(), symtab);
911938
}
912939

913940
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
914941
mlir::RewritePatternSet &patterns) {
915-
patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
942+
patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
943+
patterns.getContext(), symtab);
916944
}

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
198198
// CHECK-LABEL: func.func @_QPsub8()
199199
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
200200
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
201+
// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
201202
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
202203
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
203204
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -222,6 +223,7 @@ func.func @_QPsub9() {
222223
// CHECK-LABEL: func.func @_QPsub9()
223224
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
224225
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
226+
// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
225227
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
226228
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
227229
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -380,6 +382,7 @@ func.func @_QPdevice_addr_conv() {
380382
}
381383

382384
// CHECK-LABEL: func.func @_QPdevice_addr_conv()
385+
// CHECK: fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
383386
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
384387
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
385388
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>

flang/test/Fir/CUDA/cuda-global-addr.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
2626
}
2727

2828
// CHECK-LABEL: func.func @_QQmain()
29+
// CHECK: fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
2930
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
3031
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
3132
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>

flang/test/Fir/CUDA/cuda-launch.fir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
9898
}
9999

100100
// CHECK-LABEL: func.func @_QQmain()
101+
// CHECK: _FortranACUFSyncGlobalDescriptor
101102
// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
102-
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
103-
// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
104-
// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
103+
// CHECK: %[[CONV_ADDR:.*]] = fir.convert %[[ADDROF]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
104+
// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_ADDR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
105105
// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
106106
// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)

0 commit comments

Comments
 (0)