Skip to content

Commit e4e9fea

Browse files
authored
[flang][cuda] Pass descriptor by reference for CUFMemsetDescriptor (llvm#114338)
1 parent f4af60d commit e4e9fea

File tree

4 files changed

+8
-11
lines changed

4 files changed

+8
-11
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void RTDECL(CUFMemFree)(void *devicePtr, unsigned type,
2828
/// Set value to the data held by a descriptor. The \p value pointer must be
2929
/// addressable for the number of bytes specified by the element size of
3030
/// the descriptor \p desc.
31-
void RTDECL(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
31+
void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
3232
const char *sourceFile = nullptr, int sourceLine = 0);
3333

3434
/// Data transfer from a pointer to a pointer.

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,8 @@ struct CUFDataTransferOpConversion
552552
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
553553
mlir::Value sourceLine =
554554
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
555-
mlir::Value dst = builder.loadIfRef(loc, op.getDst());
556555
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
557-
builder, loc, fTy, dst, val, sourceFile, sourceLine)};
556+
builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
558557
builder.create<fir::CallOp>(loc, func, args);
559558
rewriter.eraseOp(op);
560559
} else {

flang/runtime/CUDA/memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ void RTDEF(CUFMemFree)(
4949
}
5050
}
5151

52-
void RTDEF(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
53-
const char *sourceFile, int sourceLine) {
52+
void RTDEF(CUFMemsetDescriptor)(
53+
Descriptor *desc, void *value, const char *sourceFile, int sourceLine) {
5454
Terminator terminator{sourceFile, sourceLine};
5555
terminator.Crash("not yet implemented: CUDA data transfer from a scalar "
5656
"value to a descriptor");

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ func.func @_QPsub2() {
3333
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
3434
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
3535
// CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
36-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
37-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
36+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
3837
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
39-
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
38+
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
4039

4140
func.func @_QPsub3() {
4241
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -51,10 +50,9 @@ func.func @_QPsub3() {
5150
// CHECK-LABEL: func.func @_QPsub3()
5251
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
5352
// CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
54-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
55-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
53+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
5654
// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
57-
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
55+
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
5856

5957
func.func @_QPsub4() {
6058
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>

0 commit comments

Comments
 (0)