Skip to content

Commit b6ff1a8

Browse files
committed
checking for non-decreasing dims sizes in sharding
1 parent 07887db commit b6ff1a8

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
198198
llvm::adl_begin(outShape));
199199

200200
if (!shardedDimsOffsets.empty()) {
201+
auto isDynShape = ShapedType::isDynamicShape(meshShape);
201202
uint64_t pos = 1;
202203
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
203204
if (!innerSplitAxes.empty()) {
204205
auto sz = shardedDimsOffsets[pos];
205-
bool same = !ShapedType::isDynamicShape(meshShape);
206+
bool same = !isDynShape;
206207
if (same) {
207208
// Find sharded dims in shardedDimsOffsets with same static size on
208209
// all devices. Use kDynamic for dimensions with dynamic or
@@ -218,7 +219,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
218219
break;
219220
}
220221
}
221-
pos += numShards;
222+
pos += numShards + 1;
222223
}
223224
outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
224225
}
@@ -544,6 +545,34 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
544545
return emitError() << "sharded dims offsets are not allowed for "
545546
"devices meshes with dynamic shape.";
546547
}
548+
549+
auto shardedDimsOffsets = getStaticShardedDimsOffsets();
550+
if (!shardedDimsOffsets.empty()) {
551+
auto meshShape = mesh.value().getShape();
552+
assert(!ShapedType::isDynamicShape(meshShape));
553+
uint64_t pos = 0;
554+
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
555+
if (!innerSplitAxes.empty()) {
556+
int64_t numShards = 0, off = 0;
557+
for (auto i : innerSplitAxes.asArrayRef()) {
558+
numShards += meshShape[i];
559+
}
560+
for (int64_t i = 0; i <= numShards; ++i) {
561+
if (shardedDimsOffsets.size() <= pos + i) {
562+
return emitError() << "sharded dims offsets has wrong size.";
563+
}
564+
if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
565+
if (shardedDimsOffsets[pos + i] < off) {
566+
return emitError()
567+
<< "sharded dims offsets must be non-decreasing.";
568+
}
569+
off = shardedDimsOffsets[pos + i];
570+
}
571+
}
572+
pos += numShards + 1;
573+
}
574+
}
575+
}
547576
return success();
548577
}
549578

mlir/test/Dialect/Mesh/canonicalization.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func.func @test_halo_sizes() -> !mesh.sharding {
204204
// CHECK-LABEL: func @test_shard_offs
205205
func.func @test_shard_offs() -> !mesh.sharding {
206206
%c2_i64 = arith.constant 2 : i64
207-
// CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 0, 2, 22] : !mesh.sharding
208-
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 0, %c2_i64, 22] : !mesh.sharding
207+
// CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
208+
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
209209
return %sharding : !mesh.sharding
210210
}

mlir/test/Dialect/Mesh/invalid.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,26 @@ func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
107107

108108
// -----
109109

110+
mesh.mesh @mesh0(shape = 2x4)
111+
func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
112+
// expected-error@+1 {{sharded dims offsets has wrong size}}
113+
%s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
114+
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
115+
return
116+
}
117+
118+
// -----
119+
120+
mesh.mesh @mesh0(shape = 4)
121+
func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
122+
// expected-error@+1 {{sharded dims offsets must be non-decreasing}}
123+
%s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
124+
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
125+
return
126+
}
127+
128+
// -----
129+
110130
mesh.mesh @mesh0(shape = 2x4)
111131

112132
func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {

0 commit comments

Comments
 (0)