Skip to content

Commit 0bc710f

Browse files
authored
[flang][cuda] Accept constant as src for cuf.data_transfer (#92951)
Assignment of a constant (host) to a device variable is a special case that can be further lowered to `cudaMemset` or similar functions. This patch updates the lowering to avoid creating a temporary when a constant is assigned to a device variable.
1 parent 57a5079 commit 0bc710f

File tree

4 files changed

+24
-11
lines changed

4 files changed

+24
-11
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
158158
updated.
159159
}];
160160

161-
let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,
161+
let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
162162
Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
163163
cuf_DataTransferKindAttr:$transfer_kind);
164164

flang/lib/Lower/Bridge.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "flang/Semantics/symbol.h"
5858
#include "flang/Semantics/tools.h"
5959
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
60+
#include "mlir/IR/Matchers.h"
6061
#include "mlir/IR/PatternMatch.h"
6162
#include "mlir/Parser/Parser.h"
6263
#include "mlir/Transforms/RegionUtils.h"
@@ -3798,11 +3799,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
37983799
auto transferKindAttr = cuf::DataTransferKindAttr::get(
37993800
builder.getContext(), cuf::DataTransferKind::HostDevice);
38003801
if (!rhs.isVariable()) {
3801-
auto associate = hlfir::genAssociateExpr(
3802-
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
3803-
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
3804-
transferKindAttr);
3805-
builder.create<hlfir::EndAssociateOp>(loc, associate);
3802+
// Special case if the rhs is a constant.
3803+
if (matchPattern(rhs.getDefiningOp(), mlir::m_Constant())) {
3804+
builder.create<cuf::DataTransferOp>(loc, rhs, lhsVal,
3805+
transferKindAttr);
3806+
} else {
3807+
auto associate = hlfir::genAssociateExpr(
3808+
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
3809+
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
3810+
transferKindAttr);
3811+
builder.create<hlfir::EndAssociateOp>(loc, associate);
3812+
}
38063813
} else {
38073814
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
38083815
transferKindAttr);

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ mlir::LogicalResult cuf::DataTransferOp::verify() {
9999
if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
100100
(fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)))
101101
return mlir::success();
102-
return emitOpError("expect src and dst to be both references or descriptors");
102+
if (fir::isa_trivial(srcTy) &&
103+
matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
104+
return mlir::success();
105+
return emitOpError()
106+
<< "expect src and dst to be both references or descriptors or src to "
107+
"be a constant";
103108
}
104109

105110
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ subroutine sub1()
2525

2626
adev = ahost + bhost
2727

28+
adev = 10
29+
2830
end
2931

3032
! CHECK-LABEL: func.func @_QPsub1()
@@ -41,10 +43,7 @@ end
4143
! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
4244
! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
4345

44-
! CHECK: %[[C1:.*]] = arith.constant 1 : i32
45-
! CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[C1]] {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
46-
! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
47-
! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
46+
! CHECK: cuf.data_transfer %c1{{.*}} to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
4847

4948
! CHECK: cuf.data_transfer %[[AHOST]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
5049

@@ -62,6 +61,8 @@ end
6261
! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
6362
! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<!fir.array<10xi32>>, i1
6463

64+
! CHECK: cuf.data_transfer %c10{{.*}} to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
65+
6566
subroutine sub2()
6667
integer, device :: m
6768
integer, device :: adev(10), bdev(10)

0 commit comments

Comments
 (0)