-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Convert data transfer between scalar and arrays #110180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Convert data transfer between scalar and arrays #110180
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Add conversion of data transfer between scalars or between arrays. Scalar-to-array transfers are not handled yet. Full diff: https://github.com/llvm/llvm-project/pull/110180.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f8ace2dd96a0d8..faaadf01f3f2b6 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -317,11 +317,6 @@ struct CufDataTransferOpConversion
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
- // Only convert cuf.data_transfer with at least one descripor.
- if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
- !mlir::isa<fir::BaseBoxType>(dstTy))
- return failure();
-
unsigned mode;
if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
mode = kHostToDevice;
@@ -334,7 +329,62 @@ struct CufDataTransferOpConversion
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
+ fir::KindMapping kindMap{fir::getKindMapping(mod)};
+ mlir::Value modeValue =
+ builder.createIntegerConstant(loc, builder.getI32Type(), mode);
+
+ // Convert data transfer without any descriptor.
+ if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
+ !mlir::isa<fir::BaseBoxType>(dstTy)) {
+
+ if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
+ // TODO: scalar to array data transfer.
+ return mlir::failure();
+ }
+
+ mlir::Type i64Ty = builder.getI64Type();
+ mlir::Value nbElement;
+ if (op.getShape()) {
+ auto shapeOp =
+ mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
+ nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty,
+ shapeOp.getExtents()[0]);
+ for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
+ auto operand = rewriter.create<fir::ConvertOp>(
+ loc, i64Ty, shapeOp.getExtents()[i]);
+ nbElement =
+ rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
+ }
+ } else {
+ if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
+ nbElement = builder.createIntegerConstant(
+ loc, i64Ty, seqTy.getConstantArraySize());
+ }
+ int width = computeWidth(loc, dstTy, kindMap);
+ mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
+ loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
+ mlir::Value bytes =
+ nbElement
+ ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
+ : widthValue;
+
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
+ builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue,
+ sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, func, args);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+ // Conversion of data transfer involving at least one descriptor.
if (mlir::isa<fir::BaseBoxType>(srcTy) &&
mlir::isa<fir::BaseBoxType>(dstTy)) {
// Transfer between two descriptor.
@@ -343,8 +393,6 @@ struct CufDataTransferOpConversion
loc, builder);
auto fTy = func.getFunctionType();
- mlir::Value modeValue =
- builder.createIntegerConstant(loc, builder.getI32Type(), mode);
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
@@ -377,8 +425,6 @@ struct CufDataTransferOpConversion
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
- mlir::Value modeValue =
- builder.createIntegerConstant(loc, builder.getI32Type(), mode);
// Type used to compute the width.
mlir::Type computeType = dstTy;
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
@@ -388,7 +434,6 @@ struct CufDataTransferOpConversion
computeType = srcTy;
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
}
- fir::KindMapping kindMap{fir::getKindMapping(mod)};
int width = computeWidth(loc, computeType, kindMap);
mlir::Value nbElement;
@@ -466,13 +511,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
}
return true;
});
- target.addDynamicallyLegalOp<cuf::DataTransferOp>(
- [](::cuf::DataTransferOp op) {
- mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
- mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
- return !mlir::isa<fir::BaseBoxType>(srcTy) &&
- !mlir::isa<fir::BaseBoxType>(dstTy);
- });
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index f639a6c22b76d0..ed894aed5534a0 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -70,7 +70,6 @@ func.func @_QPsub4() {
cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
return
}
-
// CHECK-LABEL: func.func @_QPsub4()
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -137,4 +136,57 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+func.func @_QPsub6() {
+ %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>
+ %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %2 = fir.alloca i32 {bindc_name = "ihost", uniq_name = "_QFsub6Eihost"}
+ %3:2 = hlfir.declare %2 {uniq_name = "_QFsub6Eihost"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ cuf.data_transfer %1#0 to %3#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
+ %4 = fir.load %3#0 : !fir.ref<i32>
+ %5:3 = hlfir.associate %4 {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+ cuf.data_transfer %5#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
+ hlfir.end_associate %5#1, %5#2 : !fir.ref<i32>, i1
+ cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub6()
+// CHECK: %[[IDEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[IHOST:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub6Eihost"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[DST:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %c4{{.*}}, %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[LOAD:.*]] = fir.load %[[IHOST]]#0 : !fir.ref<i32>
+// CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[LOAD]] {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+// CHECK: %[[DST:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[ASSOC]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %c4{{.*}}, %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub7() {
+ %c10 = arith.constant 10 : index
+ %0 = cuf.alloc !fir.array<10xi32> {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub7Eidev"} -> !fir.ref<!fir.array<10xi32>>
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub7Eidev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+ %c10_0 = arith.constant 10 : index
+ %3 = fir.alloca !fir.array<10xi32> {bindc_name = "ihost", uniq_name = "_QFsub7Eihost"}
+ %4 = fir.shape %c10_0 : (index) -> !fir.shape<1>
+ %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFsub7Eihost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+ cuf.data_transfer %2#0 to %5#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+ cuf.data_transfer %5#0 to %2#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+ cuf.free %2#1 : !fir.ref<!fir.array<10xi32>> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub7()
+// CHECK: %[[IDEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub7Eidev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[IHOST:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub7Eihost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : i64
+// CHECK: %[[DST:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : i64
+// CHECK: %[[DST:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
} // end of module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, with one remark.
|
||
if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) { | ||
// TODO: scalar to array data transfer. | ||
return mlir::failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure, but wouldn't this leave the op unconverted, and wouldn't it then be illegal for the conversion target? If so, I would rather add an explicit TODO here than rely on a not-so-explicit dialect conversion failure message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are correct. The message would be a bit cryptic to understand. I don't expect it to stay like this for long, but I'm going to add a more intuitive error message for the time being.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/7488 Here is the relevant piece of the build log for reference:
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/9057 Here is the relevant piece of the build log for reference:
|
Add conversion of data transfer between scalars or between arrays.
Scalar-to-array transfers are not handled yet.