@@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
448
448
return mlir::Value{};
449
449
}
450
450
451
+ static mlir::Value emboxSrc (mlir::PatternRewriter &rewriter,
452
+ cuf::DataTransferOp op,
453
+ const mlir::SymbolTable &symtab) {
454
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
455
+ mlir::Location loc = op.getLoc ();
456
+ fir::FirOpBuilder builder (rewriter, mod);
457
+ mlir::Value addr;
458
+ mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
459
+ if (fir::isa_trivial (srcTy) &&
460
+ mlir::matchPattern (op.getSrc ().getDefiningOp (), mlir::m_Constant ())) {
461
+ // Put constant in memory if it is not.
462
+ mlir::Value alloc = builder.createTemporary (loc, srcTy);
463
+ builder.create <fir::StoreOp>(loc, op.getSrc (), alloc);
464
+ addr = alloc;
465
+ } else {
466
+ addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
467
+ }
468
+ llvm::SmallVector<mlir::Value> lenParams;
469
+ mlir::Type boxTy = fir::BoxType::get (srcTy);
470
+ mlir::Value box =
471
+ builder.createBox (loc, boxTy, addr, getShapeFromDecl (op.getSrc ()),
472
+ /* slice=*/ nullptr , lenParams,
473
+ /* tdesc=*/ nullptr );
474
+ mlir::Value src = builder.createTemporary (loc, box.getType ());
475
+ builder.create <fir::StoreOp>(loc, box, src);
476
+ return src;
477
+ }
478
+
479
+ static mlir::Value emboxDst (mlir::PatternRewriter &rewriter,
480
+ cuf::DataTransferOp op,
481
+ const mlir::SymbolTable &symtab) {
482
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
483
+ mlir::Location loc = op.getLoc ();
484
+ fir::FirOpBuilder builder (rewriter, mod);
485
+ mlir::Type dstTy = fir::unwrapRefType (op.getDst ().getType ());
486
+ mlir::Value dstAddr = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
487
+ mlir::Type dstBoxTy = fir::BoxType::get (dstTy);
488
+ llvm::SmallVector<mlir::Value> lenParams;
489
+ mlir::Value dstBox =
490
+ builder.createBox (loc, dstBoxTy, dstAddr, getShapeFromDecl (op.getDst ()),
491
+ /* slice=*/ nullptr , lenParams,
492
+ /* tdesc=*/ nullptr );
493
+ mlir::Value dst = builder.createTemporary (loc, dstBox.getType ());
494
+ builder.create <fir::StoreOp>(loc, dstBox, dst);
495
+ return dst;
496
+ }
497
+
451
498
struct CUFDataTransferOpConversion
452
499
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
453
500
using OpRewritePattern::OpRewritePattern;
@@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
486
533
!mlir::isa<fir::BaseBoxType>(dstTy)) {
487
534
488
535
if (fir::isa_trivial (srcTy) && !fir::isa_trivial (dstTy)) {
489
- // TODO: scalar to array data transfer.
490
- mlir::emitError (loc,
491
- " not yet implemented: scalar to array data transfer\n " );
492
- return mlir::failure ();
536
+ // Initialization of an array from a scalar value should be implemented
537
+ // via a kernel launch. Use the flan runtime via the Assign function
538
+ // until we have more infrastructure.
539
+ mlir::Value src = emboxSrc (rewriter, op, symtab);
540
+ mlir::Value dst = emboxDst (rewriter, op, symtab);
541
+ mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey (
542
+ CUFDataTransferDescDescNoRealloc)>(loc, builder);
543
+ auto fTy = func.getFunctionType ();
544
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
545
+ mlir::Value sourceLine =
546
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (4 ));
547
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
548
+ builder, loc, fTy , dst, src, modeValue, sourceFile, sourceLine)};
549
+ builder.create <fir::CallOp>(loc, func, args);
550
+ rewriter.eraseOp (op);
551
+ return mlir::success ();
493
552
}
494
553
495
554
mlir::Type i64Ty = builder.getI64Type ();
@@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
548
607
mlir::Value dst = op.getDst ();
549
608
mlir::Value src = op.getSrc ();
550
609
551
- if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
552
- // If src is not a descriptor, create one.
553
- mlir::Value addr;
554
- if (fir::isa_trivial (srcTy) &&
555
- mlir::matchPattern (op.getSrc ().getDefiningOp (),
556
- mlir::m_Constant ())) {
557
- // Put constant in memory if it is not.
558
- mlir::Value alloc = builder.createTemporary (loc, srcTy);
559
- builder.create <fir::StoreOp>(loc, op.getSrc (), alloc);
560
- addr = alloc;
561
- } else {
562
- addr = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
563
- }
564
- mlir::Type boxTy = fir::BoxType::get (srcTy);
565
- llvm::SmallVector<mlir::Value> lenParams;
566
- mlir::Value box =
567
- builder.createBox (loc, boxTy, addr, getShapeFromDecl (src),
568
- /* slice=*/ nullptr , lenParams,
569
- /* tdesc=*/ nullptr );
570
- mlir::Value memBox = builder.createTemporary (loc, box.getType ());
571
- builder.create <fir::StoreOp>(loc, box, memBox);
572
- src = memBox;
573
- }
610
+ if (!mlir::isa<fir::BaseBoxType>(srcTy))
611
+ src = emboxSrc (rewriter, op, symtab);
574
612
575
613
auto fTy = func.getFunctionType ();
576
614
mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
@@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
582
620
rewriter.eraseOp (op);
583
621
} else {
584
622
// Transfer from a descriptor.
585
-
586
- mlir::Value addr = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
587
- mlir::Type boxTy = fir::BoxType::get (dstTy);
588
- llvm::SmallVector<mlir::Value> lenParams;
589
- mlir::Value box =
590
- builder.createBox (loc, boxTy, addr, getShapeFromDecl (op.getDst ()),
591
- /* slice=*/ nullptr , lenParams,
592
- /* tdesc=*/ nullptr );
593
- mlir::Value memBox = builder.createTemporary (loc, box.getType ());
594
- builder.create <fir::StoreOp>(loc, box, memBox);
623
+ mlir::Value dst = emboxDst (rewriter, op, symtab);
595
624
596
625
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey (
597
626
CUFDataTransferDescDescNoRealloc)>(loc, builder);
@@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
601
630
mlir::Value sourceLine =
602
631
fir::factory::locationToLineNo (builder, loc, fTy .getInput (4 ));
603
632
llvm::SmallVector<mlir::Value> args{
604
- fir::runtime::createArguments (builder, loc, fTy , memBox , op.getSrc (),
633
+ fir::runtime::createArguments (builder, loc, fTy , dst , op.getSrc (),
605
634
modeValue, sourceFile, sourceLine)};
606
635
builder.create <fir::CallOp>(loc, func, args);
607
636
rewriter.eraseOp (op);
0 commit comments