@@ -693,10 +693,7 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
693
693
694
694
if (getPartialAxes ().size () != rhs.getPartialAxes ().size () ||
695
695
(!getPartialAxes ().empty () && getPartialType () != rhs.getPartialType ()) ||
696
- !llvm::equal (
697
- llvm::make_range (getPartialAxes ().begin (), getPartialAxes ().end ()),
698
- llvm::make_range (rhs.getPartialAxes ().begin (),
699
- rhs.getPartialAxes ().end ()))) {
696
+ !llvm::equal (getPartialAxes (), rhs.getPartialAxes ())) {
700
697
return false ;
701
698
}
702
699
@@ -708,11 +705,9 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
708
705
return false ;
709
706
}
710
707
711
- return llvm::all_of (llvm::make_range (getSplitAxes ().begin () + minSize,
712
- getSplitAxes ().end ()),
708
+ return llvm::all_of (llvm::drop_begin (getSplitAxes (), minSize),
713
709
std::mem_fn (&MeshAxesAttr::empty)) &&
714
- llvm::all_of (llvm::make_range (rhs.getSplitAxes ().begin () + minSize,
715
- rhs.getSplitAxes ().end ()),
710
+ llvm::all_of (llvm::drop_begin (rhs.getSplitAxes (), minSize),
716
711
std::mem_fn (&MeshAxesAttr::empty));
717
712
}
718
713
@@ -723,37 +718,26 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
723
718
bool MeshSharding::equalShardSizes (const MeshSharding &rhs) const {
724
719
if (rhs.getStaticShardedDimsOffsets ().size () !=
725
720
getStaticShardedDimsOffsets ().size () ||
726
- !llvm::equal (llvm::make_range (getStaticShardedDimsOffsets ().begin (),
727
- getStaticShardedDimsOffsets ().end ()),
728
- llvm::make_range (rhs.getStaticShardedDimsOffsets ().begin (),
729
- rhs.getStaticShardedDimsOffsets ().end ()))) {
721
+ !llvm::equal (getStaticShardedDimsOffsets (),
722
+ rhs.getStaticShardedDimsOffsets ())) {
730
723
return false ;
731
724
}
732
725
if (rhs.getDynamicShardedDimsOffsets ().size () !=
733
726
getDynamicShardedDimsOffsets ().size () ||
734
- !llvm::equal (
735
- llvm::make_range (getDynamicShardedDimsOffsets ().begin (),
736
- getDynamicShardedDimsOffsets ().end ()),
737
- llvm::make_range (rhs.getDynamicShardedDimsOffsets ().begin (),
738
- rhs.getDynamicShardedDimsOffsets ().end ()))) {
727
+ !llvm::equal (getDynamicShardedDimsOffsets (),
728
+ rhs.getDynamicShardedDimsOffsets ())) {
739
729
return false ;
740
730
}
741
731
return true ;
742
732
}
743
733
744
734
bool MeshSharding::equalHaloSizes (const MeshSharding &rhs) const {
745
735
if (rhs.getStaticHaloSizes ().size () != getStaticHaloSizes ().size () ||
746
- !llvm::equal (llvm::make_range (getStaticHaloSizes ().begin (),
747
- getStaticHaloSizes ().end ()),
748
- llvm::make_range (rhs.getStaticHaloSizes ().begin (),
749
- rhs.getStaticHaloSizes ().end ()))) {
736
+ !llvm::equal (getStaticHaloSizes (), rhs.getStaticHaloSizes ())) {
750
737
return false ;
751
738
}
752
739
if (rhs.getDynamicHaloSizes ().size () != getDynamicHaloSizes ().size () ||
753
- !llvm::equal (llvm::make_range (getDynamicHaloSizes ().begin (),
754
- getDynamicHaloSizes ().end ()),
755
- llvm::make_range (rhs.getDynamicHaloSizes ().begin (),
756
- rhs.getDynamicHaloSizes ().end ()))) {
740
+ !llvm::equal (getDynamicHaloSizes (), rhs.getDynamicHaloSizes ())) {
757
741
return false ;
758
742
}
759
743
return true ;
0 commit comments