@@ -370,11 +370,6 @@ struct CufDataTransferOpConversion
370
370
mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
371
371
mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
372
372
373
- // Only convert cuf.data_transfer with at least one descriptor.
374
- if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
375
- !mlir::isa<fir::BaseBoxType>(dstTy))
376
- return failure ();
377
-
378
373
unsigned mode;
379
374
if (op.getTransferKind () == cuf::DataTransferKind::HostDevice) {
380
375
mode = kHostToDevice ;
@@ -387,7 +382,64 @@ struct CufDataTransferOpConversion
387
382
auto mod = op->getParentOfType <mlir::ModuleOp>();
388
383
fir::FirOpBuilder builder (rewriter, mod);
389
384
mlir::Location loc = op.getLoc ();
385
+ fir::KindMapping kindMap{fir::getKindMapping (mod)};
386
+ mlir::Value modeValue =
387
+ builder.createIntegerConstant (loc, builder.getI32Type (), mode);
388
+
389
+ // Convert data transfer without any descriptor.
390
+ if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
391
+ !mlir::isa<fir::BaseBoxType>(dstTy)) {
392
+
393
+ if (fir::isa_trivial (srcTy) && !fir::isa_trivial (dstTy)) {
394
+ // TODO: scalar to array data transfer.
395
+ mlir::emitError (loc,
396
+ " not yet implemented: scalar to array data transfer\n " );
397
+ return mlir::failure ();
398
+ }
399
+
400
+ mlir::Type i64Ty = builder.getI64Type ();
401
+ mlir::Value nbElement;
402
+ if (op.getShape ()) {
403
+ auto shapeOp =
404
+ mlir::dyn_cast<fir::ShapeOp>(op.getShape ().getDefiningOp ());
405
+ nbElement = rewriter.create <fir::ConvertOp>(loc, i64Ty,
406
+ shapeOp.getExtents ()[0 ]);
407
+ for (unsigned i = 1 ; i < shapeOp.getExtents ().size (); ++i) {
408
+ auto operand = rewriter.create <fir::ConvertOp>(
409
+ loc, i64Ty, shapeOp.getExtents ()[i]);
410
+ nbElement =
411
+ rewriter.create <mlir::arith::MulIOp>(loc, nbElement, operand);
412
+ }
413
+ } else {
414
+ if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
415
+ nbElement = builder.createIntegerConstant (
416
+ loc, i64Ty, seqTy.getConstantArraySize ());
417
+ }
418
+ int width = computeWidth (loc, dstTy, kindMap);
419
+ mlir::Value widthValue = rewriter.create <mlir::arith::ConstantOp>(
420
+ loc, i64Ty, rewriter.getIntegerAttr (i64Ty, width));
421
+ mlir::Value bytes =
422
+ nbElement
423
+ ? rewriter.create <mlir::arith::MulIOp>(loc, nbElement, widthValue)
424
+ : widthValue;
425
+
426
+ mlir::func::FuncOp func =
427
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFDataTransferPtrPtr)>(loc,
428
+ builder);
429
+ auto fTy = func.getFunctionType ();
430
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
431
+ mlir::Value sourceLine =
432
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (5 ));
433
+
434
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
435
+ builder, loc, fTy , op.getDst (), op.getSrc (), bytes, modeValue,
436
+ sourceFile, sourceLine)};
437
+ builder.create <fir::CallOp>(loc, func, args);
438
+ rewriter.eraseOp (op);
439
+ return mlir::success ();
440
+ }
390
441
442
+ // Conversion of data transfer involving at least one descriptor.
391
443
if (mlir::isa<fir::BaseBoxType>(srcTy) &&
392
444
mlir::isa<fir::BaseBoxType>(dstTy)) {
393
445
// Transfer between two descriptors.
@@ -396,8 +448,6 @@ struct CufDataTransferOpConversion
396
448
loc, builder);
397
449
398
450
auto fTy = func.getFunctionType ();
399
- mlir::Value modeValue =
400
- builder.createIntegerConstant (loc, builder.getI32Type (), mode);
401
451
mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
402
452
mlir::Value sourceLine =
403
453
fir::factory::locationToLineNo (builder, loc, fTy .getInput (4 ));
@@ -430,8 +480,6 @@ struct CufDataTransferOpConversion
430
480
builder.create <fir::CallOp>(loc, func, args);
431
481
rewriter.eraseOp (op);
432
482
} else {
433
- mlir::Value modeValue =
434
- builder.createIntegerConstant (loc, builder.getI32Type (), mode);
435
483
// Type used to compute the width.
436
484
mlir::Type computeType = dstTy;
437
485
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
@@ -441,7 +489,6 @@ struct CufDataTransferOpConversion
441
489
computeType = srcTy;
442
490
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
443
491
}
444
- fir::KindMapping kindMap{fir::getKindMapping (mod)};
445
492
int width = computeWidth (loc, computeType, kindMap);
446
493
447
494
mlir::Value nbElement;
@@ -509,13 +556,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
509
556
fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
510
557
fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
511
558
/* forceUnifiedTBAATree=*/ false , *dl);
512
- target.addDynamicallyLegalOp <cuf::DataTransferOp>(
513
- [](::cuf::DataTransferOp op) {
514
- mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
515
- mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
516
- return !mlir::isa<fir::BaseBoxType>(srcTy) &&
517
- !mlir::isa<fir::BaseBoxType>(dstTy);
518
- });
519
559
target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect>();
520
560
cuf::populateCUFToFIRConversionPatterns (typeConverter, *dl, patterns);
521
561
if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
0 commit comments