Skip to content

Commit 30e6cd1

Browse files
committed
fixing halo send/rdv direction, not communication if no halo
1 parent 2e59db0 commit 30e6cd1

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -410,42 +410,47 @@ struct ConvertUpdateHaloOp
410410
// local data. Because subviews and halos can have mixed dynamic and static
411411
// shapes, OpFoldResults are used whenever possible.
412412

413+
auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
414+
adaptor.getHaloSizes(), rewriter);
415+
if (haloSizes.empty()) {
416+
// no halos -> nothing to do
417+
rewriter.replaceOp(op, adaptor.getDestination());
418+
return success();
419+
}
420+
413421
SymbolTableCollection symbolTableCollection;
414-
auto loc = op.getLoc();
422+
Location loc = op.getLoc();
415423

416424
// convert a OpFoldResult into a Value
417425
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
418426
if (auto value = dyn_cast<Value>(v))
419427
return value;
420-
return rewriter.create<::mlir::arith::ConstantOp>(
428+
return rewriter.create<arith::ConstantOp>(
421429
loc, rewriter.getIndexAttr(
422430
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
423431
};
424432

425-
auto dest = op.getDestination();
433+
auto dest = adaptor.getDestination();
426434
auto dstShape = cast<ShapedType>(dest.getType()).getShape();
427435
Value array = dest;
428436
if (isa<RankedTensorType>(array.getType())) {
429437
// If the destination is a memref, we need to cast it to a tensor
430438
auto tensorType = MemRefType::get(
431439
dstShape, cast<ShapedType>(array.getType()).getElementType());
432-
array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
433-
.getResult();
440+
array =
441+
rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
434442
}
435443
auto rank = cast<ShapedType>(array.getType()).getRank();
436-
auto opSplitAxes = op.getSplitAxes().getAxes();
437-
auto mesh = op.getMesh();
444+
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
445+
auto mesh = adaptor.getMesh();
438446
auto meshOp = getMesh(op, symbolTableCollection);
439-
auto haloSizes =
440-
getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
441447
// subviews need Index values
442448
for (auto &sz : haloSizes) {
443-
if (auto value = dyn_cast<Value>(sz)) {
449+
if (auto value = dyn_cast<Value>(sz))
444450
sz =
445451
rewriter
446452
.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
447453
.getResult();
448-
}
449454
}
450455

451456
// most of the offset/size/stride data is the same for all dims
@@ -530,8 +535,8 @@ struct ConvertUpdateHaloOp
530535
: haloSizes[currHaloDim * 2];
531536
// Check if we need to send and/or receive
532537
// Processes on the mesh borders have only one neighbor
533-
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
534-
auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
538+
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
539+
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
535540
auto hasFrom = rewriter.create<arith::CmpIOp>(
536541
loc, arith::CmpIPredicate::sge, from, zero);
537542
auto hasTo = rewriter.create<arith::CmpIOp>(
@@ -564,8 +569,25 @@ struct ConvertUpdateHaloOp
564569
offsets[dim] = orgOffset;
565570
};
566571

567-
genSendRecv(false);
568-
genSendRecv(true);
572+
auto get_i32val = [&](OpFoldResult &v) {
573+
return isa<Value>(v)
574+
? cast<Value>(v)
575+
: rewriter.create<arith::ConstantOp>(
576+
loc,
577+
rewriter.getI32IntegerAttr(
578+
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
579+
};
580+
581+
for (int i = 0; i < 2; ++i) {
582+
Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
583+
auto hasSize = rewriter.create<arith::CmpIOp>(
584+
loc, arith::CmpIPredicate::sgt, haloSz, zero);
585+
rewriter.create<scf::IfOp>(loc, hasSize,
586+
[&](OpBuilder &builder, Location loc) {
587+
genSendRecv(i > 0);
588+
builder.create<scf::YieldOp>(loc);
589+
});
590+
}
569591

570592
// the shape for lower dims include higher dims' halos
571593
dimSizes[dim] = shape[dim];
@@ -583,7 +605,7 @@ struct ConvertUpdateHaloOp
583605
loc, op.getResult().getType(), array,
584606
/*restrict=*/true, /*writable=*/true));
585607
}
586-
return mlir::success();
608+
return success();
587609
}
588610
};
589611

0 commit comments

Comments
 (0)