@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
593
593
Operation *definingOp = operand.getDefiningOp ();
594
594
assert (definingOp);
595
595
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
596
- assert (shardOp.getAnnotateForUsers ());
597
596
return shardOp.getShard ();
598
597
});
599
598
return res;
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
615
614
assert (result.hasOneUse ());
616
615
Operation *userOp = *result.getUsers ().begin ();
617
616
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
618
- assert (!shardOp.getAnnotateForUsers ());
619
617
return shardOp.getShard ();
620
618
});
621
619
return res;
622
620
}
623
621
624
622
static LogicalResult
625
- spmdizeOperation (Operation &op , IRMapping &spmdizationMap,
623
+ spmdizeOperation (ShardOp shardOp , IRMapping &spmdizationMap,
626
624
SymbolTableCollection &symbolTableCollection,
627
625
OpBuilder &builder) {
628
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
629
- if (shardOp) {
630
- if (!shardOp.getAnnotateForUsers ()) {
631
- return success ();
632
- }
633
-
626
+ Value targetSpmdValue;
627
+
628
+ // Check if 2 shard ops are chained. If not there is no need for resharding
629
+ // as the source and target shared the same sharding.
630
+ ShardOp srcShardOp =
631
+ dyn_cast_or_null<ShardOp>(shardOp.getOperand ().getDefiningOp ());
632
+ if (!srcShardOp) {
633
+ targetSpmdValue = spmdizationMap.lookup (shardOp.getOperand ());
634
+ } else {
634
635
// Insert resharding.
635
- ShardOp srcShardOp =
636
- llvm::cast<ShardOp>(shardOp.getOperand ().getDefiningOp ());
637
- assert (!srcShardOp.getAnnotateForUsers ());
636
+ assert (!srcShardOp.getAnnotateForUsers () && shardOp.getAnnotateForUsers ());
638
637
TypedValue<ShapedType> srcSpmdValue =
639
638
spmdizationMap.lookup (srcShardOp.getOperand ())
640
639
.cast <TypedValue<ShapedType>>();
641
- Value targetSpmdValue = reshard (builder, srcShardOp, shardOp, srcSpmdValue,
642
- symbolTableCollection);
643
- assert (!spmdizationMap.contains (shardOp.getResult ()));
644
- spmdizationMap.map (shardOp.getResult (), targetSpmdValue);
645
- return success ();
640
+ targetSpmdValue = reshard (builder, srcShardOp, shardOp, srcSpmdValue,
641
+ symbolTableCollection);
642
+ }
643
+
644
+ assert (!spmdizationMap.contains (shardOp.getResult ()));
645
+ spmdizationMap.map (shardOp.getResult (), targetSpmdValue);
646
+ return success ();
647
+ }
648
+
649
+ static LogicalResult
650
+ spmdizeOperation (Operation &op, IRMapping &spmdizationMap,
651
+ SymbolTableCollection &symbolTableCollection,
652
+ OpBuilder &builder) {
653
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
654
+ if (shardOp) {
655
+ return spmdizeOperation (shardOp, spmdizationMap, symbolTableCollection,
656
+ builder);
646
657
}
647
658
648
659
SmallVector<Value> spmdizedOperands;
0 commit comments