Skip to content

Revert "[flang][cuda] Specialize entry point for scalar to desc data transfer" #116458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions flang/include/flang/Runtime/CUDA/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a scalar descriptor to a descriptor.
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a descriptor to a descriptor.
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
Expand Down
8 changes: 2 additions & 6 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,8 @@ struct CUFDataTransferOpConversion
// until we have more infrastructure.
mlir::Value src = emboxSrc(rewriter, op, symtab);
mlir::Value dst = emboxDst(rewriter, op, symtab);
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferDescDescNoRealloc)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
Expand Down Expand Up @@ -649,9 +648,6 @@ struct CUFDataTransferOpConversion
mlir::Value src = op.getSrc();
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
src = emboxSrc(rewriter, op, symtab);
if (fir::isa_trivial(srcTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
}
auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
if (mlir::isa<fir::EmboxOp>(val.getDefiningOp())) {
Expand Down
19 changes: 0 additions & 19 deletions flang/runtime/CUDA/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//

#include "flang/Runtime/CUDA/memory.h"
#include "../assign-impl.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
Expand Down Expand Up @@ -121,24 +120,6 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}

void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
unsigned mode, const char *sourceFile, int sourceLine) {
MemmoveFct memmoveFct;
Terminator terminator{sourceFile, sourceLine};
if (mode == kHostToDevice) {
memmoveFct = &MemmoveHostToDevice;
} else if (mode == kDeviceToHost) {
memmoveFct = &MemmoveDeviceToHost;
} else if (mode == kDeviceToDevice) {
memmoveFct = &MemmoveDeviceToDevice;
} else {
terminator.Crash("host to host copy not supported");
}

Fortran::runtime::DoFromSourceAssign(
*dstDesc, *srcDesc, terminator, memmoveFct);
}

void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
Expand Down
17 changes: 2 additions & 15 deletions flang/runtime/assign-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,16 @@
#ifndef FORTRAN_RUNTIME_ASSIGN_IMPL_H_
#define FORTRAN_RUNTIME_ASSIGN_IMPL_H_

#include "flang/Runtime/freestanding-tools.h"

namespace Fortran::runtime {
class Descriptor;
class Terminator;

using MemmoveFct = void *(*)(void *, const void *, std::size_t);

// Assign one object to another via allocate statement from source specifier.
// Note that if allocate object and source expression have the same rank, the
// value of the allocate object becomes the value provided; otherwise the value
// of each element of allocate object becomes the value provided (9.7.1.2(7)).
#ifdef RT_DEVICE_COMPILATION
static RT_API_ATTRS void *MemmoveWrapper(
void *dest, const void *src, std::size_t count) {
return Fortran::runtime::memmove(dest, src, count);
}
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
Terminator &, MemmoveFct memmoveFct = &MemmoveWrapper);
#else
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
Terminator &, MemmoveFct memmoveFct = &Fortran::runtime::memmove);
#endif
RT_API_ATTRS void DoFromSourceAssign(
Descriptor &, const Descriptor &, Terminator &);

} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_ASSIGN_IMPL_H_
12 changes: 6 additions & 6 deletions flang/runtime/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,

RT_OFFLOAD_API_GROUP_BEGIN

RT_API_ATTRS void DoFromSourceAssign(Descriptor &alloc,
const Descriptor &source, Terminator &terminator, MemmoveFct memmoveFct) {
RT_API_ATTRS void DoFromSourceAssign(
Descriptor &alloc, const Descriptor &source, Terminator &terminator) {
if (alloc.rank() > 0 && source.rank() == 0) {
// The value of each element of allocate object becomes the value of source.
DescriptorAddendum *allocAddendum{alloc.Addendum()};
Expand All @@ -523,17 +523,17 @@ RT_API_ATTRS void DoFromSourceAssign(Descriptor &alloc,
alloc.IncrementSubscripts(allocAt)) {
Descriptor allocElement{*Descriptor::Create(*allocDerived,
reinterpret_cast<void *>(alloc.Element<char>(allocAt)), 0)};
Assign(allocElement, source, terminator, NoAssignFlags, memmoveFct);
Assign(allocElement, source, terminator, NoAssignFlags);
}
} else { // intrinsic type
for (std::size_t n{alloc.Elements()}; n-- > 0;
alloc.IncrementSubscripts(allocAt)) {
memmoveFct(alloc.Element<char>(allocAt), source.raw().base_addr,
alloc.ElementBytes());
Fortran::runtime::memmove(alloc.Element<char>(allocAt),
source.raw().base_addr, alloc.ElementBytes());
}
}
} else {
Assign(alloc, source, terminator, NoAssignFlags, memmoveFct);
Assign(alloc, source, terminator, NoAssignFlags);
}
}

Expand Down
12 changes: 6 additions & 6 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func.func @_QPsub2() {
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub3() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
Expand All @@ -58,7 +58,7 @@ func.func @_QPsub3() {
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub4() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
Expand Down Expand Up @@ -297,7 +297,7 @@ func.func @_QPscalar_to_array() {
}

// CHECK-LABEL: func.func @_QPscalar_to_array()
// CHECK: _FortranACUFDataTransferCstDesc
// CHECK: _FortranACUFDataTransferDescDescNoRealloc

func.func @_QPtest_type() {
%0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_typeEa"} -> !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>
Expand Down Expand Up @@ -344,7 +344,7 @@ func.func @_QPshape_shift() {
}

// CHECK-LABEL: func.func @_QPshape_shift()
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc

func.func @_QPshape_shift2() {
%c11 = arith.constant 11 : index
Expand Down Expand Up @@ -383,7 +383,7 @@ func.func @_QPdevice_addr_conv() {
// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc

func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -464,6 +464,6 @@ func.func @_QPlogical_cst() {
// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<4>>>
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<4>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

} // end of module
Loading