@@ -410,42 +410,47 @@ struct ConvertUpdateHaloOp
410
410
// local data. Because subviews and halos can have mixed dynamic and static
411
411
// shapes, OpFoldResults are used whenever possible.
412
412
413
+ auto haloSizes = getMixedValues (adaptor.getStaticHaloSizes (),
414
+ adaptor.getHaloSizes (), rewriter);
415
+ if (haloSizes.empty ()) {
416
+ // no halos -> nothing to do
417
+ rewriter.replaceOp (op, adaptor.getDestination ());
418
+ return success ();
419
+ }
420
+
413
421
SymbolTableCollection symbolTableCollection;
414
- auto loc = op.getLoc ();
422
+ Location loc = op.getLoc ();
415
423
416
424
// convert an OpFoldResult into a Value
417
425
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
418
426
if (auto value = dyn_cast<Value>(v))
419
427
return value;
420
- return rewriter.create <::mlir:: arith::ConstantOp>(
428
+ return rewriter.create <arith::ConstantOp>(
421
429
loc, rewriter.getIndexAttr (
422
430
cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
423
431
};
424
432
425
- auto dest = op .getDestination ();
433
+ auto dest = adaptor .getDestination ();
426
434
auto dstShape = cast<ShapedType>(dest.getType ()).getShape ();
427
435
Value array = dest;
428
436
if (isa<RankedTensorType>(array.getType ())) {
429
437
// If the destination is a tensor, we need to cast it to a memref
430
438
auto tensorType = MemRefType::get (
431
439
dstShape, cast<ShapedType>(array.getType ()).getElementType ());
432
- array = rewriter. create <bufferization::ToMemrefOp>(loc, tensorType, array)
433
- . getResult ( );
440
+ array =
441
+ rewriter. create <bufferization::ToMemrefOp>(loc, tensorType, array );
434
442
}
435
443
auto rank = cast<ShapedType>(array.getType ()).getRank ();
436
- auto opSplitAxes = op .getSplitAxes ().getAxes ();
437
- auto mesh = op .getMesh ();
444
+ auto opSplitAxes = adaptor .getSplitAxes ().getAxes ();
445
+ auto mesh = adaptor .getMesh ();
438
446
auto meshOp = getMesh (op, symbolTableCollection);
439
- auto haloSizes =
440
- getMixedValues (op.getStaticHaloSizes (), op.getHaloSizes (), rewriter);
441
447
// subviews need Index values
442
448
for (auto &sz : haloSizes) {
443
- if (auto value = dyn_cast<Value>(sz)) {
449
+ if (auto value = dyn_cast<Value>(sz))
444
450
sz =
445
451
rewriter
446
452
.create <arith::IndexCastOp>(loc, rewriter.getIndexType (), value)
447
453
.getResult ();
448
- }
449
454
}
450
455
451
456
// most of the offset/size/stride data is the same for all dims
@@ -530,8 +535,8 @@ struct ConvertUpdateHaloOp
530
535
: haloSizes[currHaloDim * 2 ];
531
536
// Check if we need to send and/or receive
532
537
// Processes on the mesh borders have only one neighbor
533
- auto to = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
534
- auto from = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
538
+ auto to = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
539
+ auto from = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
535
540
auto hasFrom = rewriter.create <arith::CmpIOp>(
536
541
loc, arith::CmpIPredicate::sge, from, zero);
537
542
auto hasTo = rewriter.create <arith::CmpIOp>(
@@ -564,8 +569,25 @@ struct ConvertUpdateHaloOp
564
569
offsets[dim] = orgOffset;
565
570
};
566
571
567
- genSendRecv (false );
568
- genSendRecv (true );
572
+ auto get_i32val = [&](OpFoldResult &v) {
573
+ return isa<Value>(v)
574
+ ? cast<Value>(v)
575
+ : rewriter.create <arith::ConstantOp>(
576
+ loc,
577
+ rewriter.getI32IntegerAttr (
578
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
579
+ };
580
+
581
+ for (int i = 0 ; i < 2 ; ++i) {
582
+ Value haloSz = get_i32val (haloSizes[currHaloDim * 2 + i]);
583
+ auto hasSize = rewriter.create <arith::CmpIOp>(
584
+ loc, arith::CmpIPredicate::sgt, haloSz, zero);
585
+ rewriter.create <scf::IfOp>(loc, hasSize,
586
+ [&](OpBuilder &builder, Location loc) {
587
+ genSendRecv (i > 0 );
588
+ builder.create <scf::YieldOp>(loc);
589
+ });
590
+ }
569
591
570
592
// the shape for lower dims include higher dims' halos
571
593
dimSizes[dim] = shape[dim];
@@ -583,7 +605,7 @@ struct ConvertUpdateHaloOp
583
605
loc, op.getResult ().getType (), array,
584
606
/* restrict=*/ true , /* writable=*/ true ));
585
607
}
586
- return mlir:: success ();
608
+ return success ();
587
609
}
588
610
};
589
611
0 commit comments