@@ -198,11 +198,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
198
198
llvm::adl_begin (outShape));
199
199
200
200
if (!shardedDimsOffsets.empty ()) {
201
+ auto isDynShape = ShapedType::isDynamicShape (meshShape);
201
202
uint64_t pos = 1 ;
202
203
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate (splitAxes)) {
203
204
if (!innerSplitAxes.empty ()) {
204
205
auto sz = shardedDimsOffsets[pos];
205
- bool same = !ShapedType::isDynamicShape (meshShape) ;
206
+ bool same = !isDynShape ;
206
207
if (same) {
207
208
// Find sharded dims in shardedDimsOffsets with same static size on
208
209
// all devices. Use kDynamic for dimensions with dynamic or
@@ -218,7 +219,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
218
219
break ;
219
220
}
220
221
}
221
- pos += numShards;
222
+ pos += numShards + 1 ;
222
223
}
223
224
outShape[tensorAxis] = same ? sz : ShapedType::kDynamic ;
224
225
}
@@ -544,6 +545,34 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
544
545
return emitError () << " sharded dims offsets are not allowed for "
545
546
" devices meshes with dynamic shape." ;
546
547
}
548
+
549
+ auto shardedDimsOffsets = getStaticShardedDimsOffsets ();
550
+ if (!shardedDimsOffsets.empty ()) {
551
+ auto meshShape = mesh.value ().getShape ();
552
+ assert (!ShapedType::isDynamicShape (meshShape));
553
+ uint64_t pos = 0 ;
554
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate (getSplitAxes ())) {
555
+ if (!innerSplitAxes.empty ()) {
556
+ int64_t numShards = 0 , off = 0 ;
557
+ for (auto i : innerSplitAxes.asArrayRef ()) {
558
+ numShards += meshShape[i];
559
+ }
560
+ for (int64_t i = 0 ; i <= numShards; ++i) {
561
+ if (shardedDimsOffsets.size () <= pos + i) {
562
+ return emitError () << " sharded dims offsets has wrong size." ;
563
+ }
564
+ if (!ShapedType::isDynamic (shardedDimsOffsets[pos + i])) {
565
+ if (shardedDimsOffsets[pos + i] < off) {
566
+ return emitError ()
567
+ << " sharded dims offsets must be non-decreasing." ;
568
+ }
569
+ off = shardedDimsOffsets[pos + i];
570
+ }
571
+ }
572
+ pos += numShards + 1 ;
573
+ }
574
+ }
575
+ }
547
576
return success ();
548
577
}
549
578
0 commit comments