Skip to content

Commit 13c590c

Browse files
committed
comments/docs
1 parent a0aa3eb commit 13c590c

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,18 +1066,20 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10661066
let summary = "Update halo data.";
10671067
let description = [{
10681068
This operation updates halo regions of shards, e.g. if their sharding
1069-
specified halos and the actual tensor data might have changed
1069+
specified halos and the actual tensor/memref data might have changed
10701070
on the remote devices. Changes might be caused by mutating operations
10711071
and/or if the new halo regions are larger than the existing ones.
10721072

1073+
Source and destination might have different halo sizes.
1074+
10731075
Assumes all devices hold tensors with same-sized halo data as specified
10741076
by `source_halo_sizes/static_source_halo_sizes` and
1075-
`destination_halo_sizes/static_destination_halo_sizes`
1077+
`destination_halo_sizes/static_destination_halo_sizes` in source shard
1078+
and destination/result shard.
10761079

10771080
`split_axes` specifies for each tensor axis along which mesh axes its halo
10781081
data is updated.
10791082

1080-
Source and destination might have different halo sizes.
10811083
}];
10821084
let arguments = (ins
10831085
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,12 +429,18 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
429429
return std::nullopt;
430430
}
431431

432+
// Detect a change in the halo size (only) and create necessary operations if
433+
// needed. Changed halo sizes require copying the "core" of the source tensor
434+
// into the "core" of the destination tensor followed by an update halo
435+
// operation.
432436
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
433437
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
434438
MeshSharding sourceSharding,
435439
MeshSharding targetSharding,
436440
ShapedType sourceUnshardedShape,
437441
TypedValue<ShapedType> sourceShard) {
442+
// Currently handles only cases where halo sizes differ but everything else
443+
// stays the same (from source to destination sharding)
438444
if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
439445
sourceSharding.getPartialAxes().empty() &&
440446
targetSharding.getPartialAxes().empty() &&
@@ -454,6 +460,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
454460
SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
455461
strides(rank, 1), outShape(sourceShard.getType().getShape()),
456462
coreShape(sourceShard.getType().getShape());
463+
464+
// Determine the "core" of the source and destination.
465+
// the core is the local part of the shard excluding halo regions
457466
for (auto i = 0u; i < rank; ++i) {
458467
if (i < splitAxes.size() && !splitAxes[i].empty()) {
459468
if (!srcHaloSizes.empty()) {
@@ -465,6 +474,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
465474
coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
466475
}
467476
}
477+
478+
// extract core from source and copy into destination core
468479
auto noVals = ValueRange{};
469480
auto initVal = builder.create<tensor::EmptyOp>(
470481
sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
@@ -476,6 +487,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
476487
auto initOprnd = builder.create<tensor::InsertSliceOp>(
477488
sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
478489
tgtCoreOffs, coreShape, strides);
490+
491+
// finally update the halo
479492
auto updateHaloResult = builder.create<UpdateHaloOp>(
480493
sourceShard.getLoc(),
481494
RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
@@ -546,10 +559,13 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
546559
MeshSharding targetSharding,
547560
TypedValue<ShapedType> sourceUnshardedValue,
548561
TypedValue<ShapedType> sourceShard) {
562+
// If source and destination sharding are the same, no need to do anything.
549563
if (sourceSharding == targetSharding) {
550564
return sourceShard;
551565
}
552566

567+
// Tries to handle the case where the resharding is needed because the halo
568+
// sizes are different. Supports arbitrary mesh dimensionality.
553569
if (auto tryRes = tryUpdateHaloInResharding(
554570
builder, mesh, sourceSharding, targetSharding,
555571
sourceUnshardedValue.getType(), sourceShard)) {

0 commit comments

Comments
 (0)