@@ -476,7 +476,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
476
476
auto initOprnd = builder.create <tensor::InsertSliceOp>(
477
477
sourceShard.getLoc (), core, initVal, noVals, noVals, noVals,
478
478
tgtCoreOffs, coreShape, strides);
479
- auto targetShard = builder.create <UpdateHaloOp>(
479
+ auto updateHaloResult = builder.create <UpdateHaloOp>(
480
480
sourceShard.getLoc (),
481
481
RankedTensorType::get (outShape, sourceShard.getType ().getElementType ()),
482
482
sourceShard, initOprnd, mesh.getSymName (),
@@ -485,9 +485,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
485
485
sourceSharding.getDynamicHaloSizes (),
486
486
sourceSharding.getStaticHaloSizes (),
487
487
targetSharding.getDynamicHaloSizes (),
488
- targetSharding.getStaticHaloSizes ());
488
+ targetSharding.getStaticHaloSizes ()). getResult () ;
489
489
return std::make_tuple (
490
- cast<TypedValue<ShapedType>>(targetShard. getResult () ), targetSharding);
490
+ cast<TypedValue<ShapedType>>(updateHaloResult ), targetSharding);
491
491
}
492
492
return std::nullopt;
493
493
}
@@ -710,7 +710,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
710
710
} else {
711
711
// Insert resharding.
712
712
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
713
- spmdizationMap.lookup (srcShardOp. getSrc () ));
713
+ spmdizationMap.lookup (srcShardOp));
714
714
targetSpmdValue = reshard (builder, srcShardOp, shardOp, srcSpmdValue,
715
715
symbolTableCollection);
716
716
}
0 commit comments