@@ -429,12 +429,18 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
429
429
return std::nullopt;
430
430
}
431
431
432
+ // Detect a change in the halo size (only) and create necessary operations if
433
+ // needed. A changed halo size requires copying the "core" of the source tensor
434
+ // into the "core" of the destination tensor followed by an update halo
435
+ // operation.
432
436
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
433
437
tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh,
434
438
MeshSharding sourceSharding,
435
439
MeshSharding targetSharding,
436
440
ShapedType sourceUnshardedShape,
437
441
TypedValue<ShapedType> sourceShard) {
442
+ // currently handles only cases where halo sizes differ but everything else
443
+ // stays the same (from source to destination sharding)
438
444
if (sourceSharding.equalSplitAndPartialAxes (targetSharding) &&
439
445
sourceSharding.getPartialAxes ().empty () &&
440
446
targetSharding.getPartialAxes ().empty () &&
@@ -454,6 +460,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
454
460
SmallVector<int64_t > srcCoreOffs (rank, 0 ), tgtCoreOffs (rank, 0 ),
455
461
strides (rank, 1 ), outShape (sourceShard.getType ().getShape ()),
456
462
coreShape (sourceShard.getType ().getShape ());
463
+
464
+ // determine "core" of source and destination
465
+ // the core is the local part of the shard excluding halo regions
457
466
for (auto i = 0u ; i < rank; ++i) {
458
467
if (i < splitAxes.size () && !splitAxes[i].empty ()) {
459
468
if (!srcHaloSizes.empty ()) {
@@ -465,6 +474,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
465
474
coreShape[i] + tgtHaloSizes[i * 2 ] + tgtHaloSizes[i * 2 + 1 ];
466
475
}
467
476
}
477
+
478
+ // extract core from source and copy into destination core
468
479
auto noVals = ValueRange{};
469
480
auto initVal = builder.create <tensor::EmptyOp>(
470
481
sourceShard.getLoc (), outShape, sourceShard.getType ().getElementType ());
@@ -476,6 +487,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
476
487
auto initOprnd = builder.create <tensor::InsertSliceOp>(
477
488
sourceShard.getLoc (), core, initVal, noVals, noVals, noVals,
478
489
tgtCoreOffs, coreShape, strides);
490
+
491
+ // finally update the halo
479
492
auto updateHaloResult = builder.create <UpdateHaloOp>(
480
493
sourceShard.getLoc (),
481
494
RankedTensorType::get (outShape, sourceShard.getType ().getElementType ()),
@@ -546,10 +559,13 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
546
559
MeshSharding targetSharding,
547
560
TypedValue<ShapedType> sourceUnshardedValue,
548
561
TypedValue<ShapedType> sourceShard) {
562
+ // If source and destination sharding are the same, no need to do anything.
549
563
if (sourceSharding == targetSharding) {
550
564
return sourceShard;
551
565
}
552
566
567
+ // Tries to handle the case where resharding is needed because the halo
568
+ // sizes are different. Supports arbitrary mesh dimensionality.
553
569
if (auto tryRes = tryUpdateHaloInResharding (
554
570
builder, mesh, sourceSharding, targetSharding,
555
571
sourceUnshardedValue.getType (), sourceShard)) {
0 commit comments