@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
489
489
tgtCoreOffs, coreShape, strides);
490
490
491
491
// finally update the halo
492
- auto updateHaloResult = builder.create <UpdateHaloOp>(
493
- sourceShard.getLoc (),
494
- RankedTensorType::get (outShape, sourceShard.getType ().getElementType ()),
495
- sourceShard, initOprnd, mesh.getSymName (),
496
- MeshAxesArrayAttr::get (builder.getContext (),
497
- sourceSharding.getSplitAxes ()),
498
- sourceSharding.getDynamicHaloSizes (),
499
- sourceSharding.getStaticHaloSizes (),
500
- targetSharding.getDynamicHaloSizes (),
501
- targetSharding.getStaticHaloSizes ()).getResult ();
502
- return std::make_tuple (
503
- cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
492
+ auto updateHaloResult =
493
+ builder
494
+ .create <UpdateHaloOp>(
495
+ sourceShard.getLoc (),
496
+ RankedTensorType::get (outShape,
497
+ sourceShard.getType ().getElementType ()),
498
+ sourceShard, initOprnd, mesh.getSymName (),
499
+ MeshAxesArrayAttr::get (builder.getContext (),
500
+ sourceSharding.getSplitAxes ()),
501
+ sourceSharding.getDynamicHaloSizes (),
502
+ sourceSharding.getStaticHaloSizes (),
503
+ targetSharding.getDynamicHaloSizes (),
504
+ targetSharding.getStaticHaloSizes ())
505
+ .getResult ();
506
+ return std::make_tuple (cast<TypedValue<ShapedType>>(updateHaloResult),
507
+ targetSharding);
504
508
}
505
509
return std::nullopt;
506
510
}
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
725
729
targetSpmdValue = spmdizationMap.lookup (shardOp.getSrc ());
726
730
} else {
727
731
// Insert resharding.
728
- TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
729
- spmdizationMap.lookup (srcShardOp));
732
+ TypedValue<ShapedType> srcSpmdValue =
733
+ cast<TypedValue<ShapedType>>( spmdizationMap.lookup (srcShardOp));
730
734
targetSpmdValue = reshard (builder, srcShardOp, shardOp, srcSpmdValue,
731
735
symbolTableCollection);
732
736
}
0 commit comments