@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
539
539
if (std::size (values) != std::size (shardings)) {
540
540
return false ;
541
541
}
542
- return llvm::all_of (llvm::zip (std::forward<ValueRange>(values),
543
- std::forward<MeshShardingAttrRage>(shardings)),
542
+ return llvm::all_of (llvm::zip_equal (
543
+ std::forward<ValueRange>(values),
544
+ std::forward<MeshShardingAttrRage>(shardings)),
544
545
[](auto valueAndSharding) {
545
546
return isValueCompatibleWithFullReplicationSharding (
546
547
std::get<0 >(valueAndSharding),
@@ -588,11 +589,9 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
588
589
SmallVector<MeshShardingAttr> operatorAndResultShardings;
589
590
operatorAndResultShardings.reserve (operandShardings.size () +
590
591
resultShardings.size ());
591
- operatorAndResultShardings.insert (operatorAndResultShardings.end (),
592
- operandShardings.begin (),
593
- operandShardings.end ());
592
+ llvm::append_range (operatorAndResultShardings, operandShardings);
594
593
for (auto [sharding, affineMap] :
595
- llvm::zip (operatorAndResultShardings, indexingMaps)) {
594
+ llvm::zip_equal (operatorAndResultShardings, indexingMaps)) {
596
595
if (!sharding) {
597
596
continue ;
598
597
}
@@ -602,6 +601,12 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
602
601
meshAxesAssignmentForTensorAxis.asArrayRef (), indexingExpr,
603
602
meshAxisAssignmentForLoopIterators);
604
603
}
604
+ // Missing trailing split axes means replication on those tensor dimensions.
605
+ for (unsigned i = sharding.getSplitAxes ().size ();
606
+ i < affineMap.getNumResults (); ++i) {
607
+ updateMeshAxisAssignmentForLoopIterators (
608
+ {}, affineMap.getResults ()[i], meshAxisAssignmentForLoopIterators);
609
+ }
605
610
}
606
611
607
612
ShardingArray res;
@@ -619,7 +624,7 @@ bool mesh::isAtLeastOneReductionIteratorSharded(
619
624
ArrayRef<utils::IteratorType> loopIteratorTypes,
620
625
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
621
626
for (auto [loopIteratorType, meshAxisAssignment] :
622
- llvm::zip (loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
627
+ llvm::zip_equal (loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
623
628
if (loopIteratorType == utils::IteratorType::reduction &&
624
629
!meshAxisAssignment.empty ()) {
625
630
return true ;
@@ -633,10 +638,9 @@ SmallVector<MeshAxis> mesh::getReductionMeshAxes(
633
638
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
634
639
SmallVector<MeshAxis> meshAxes;
635
640
for (auto [loopIteratorType, meshAxisAssignment] :
636
- llvm::zip (loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
641
+ llvm::zip_equal (loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
637
642
if (loopIteratorType == utils::IteratorType::reduction) {
638
- meshAxes.insert (meshAxes.end (), meshAxisAssignment.begin (),
639
- meshAxisAssignment.end ());
643
+ llvm::append_range (meshAxes, meshAxisAssignment);
640
644
}
641
645
}
642
646
return meshAxes;
@@ -651,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation(
651
655
Operation *newOp = builder.clone (op, spmdizationMap);
652
656
// Set the result types to the sharded counterparts.
653
657
for (auto [oldResult, newResult, sharding] :
654
- llvm::zip (op.getResults (), newOp->getResults (), resultShardings)) {
658
+ llvm::zip_equal (op.getResults (), newOp->getResults (), resultShardings)) {
655
659
newResult.setType (shardType (newResult.getType (),
656
660
getMesh (&op, sharding.getMesh (), symbolTable),
657
661
sharding));
0 commit comments