Skip to content

Commit 1cce571

Browse files
committed
fixing resharding
1 parent 5e4cada commit 1cce571

File tree

2 files changed

+20
-4
lines changed

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
476476
auto initOprnd = builder.create<tensor::InsertSliceOp>(
477477
sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
478478
tgtCoreOffs, coreShape, strides);
479-
auto targetShard = builder.create<UpdateHaloOp>(
479+
auto updateHaloResult = builder.create<UpdateHaloOp>(
480480
sourceShard.getLoc(),
481481
RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
482482
sourceShard, initOprnd, mesh.getSymName(),
@@ -485,9 +485,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
485485
sourceSharding.getDynamicHaloSizes(),
486486
sourceSharding.getStaticHaloSizes(),
487487
targetSharding.getDynamicHaloSizes(),
488-
targetSharding.getStaticHaloSizes());
488+
targetSharding.getStaticHaloSizes()).getResult();
489489
return std::make_tuple(
490-
cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
490+
cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
491491
}
492492
return std::nullopt;
493493
}
@@ -710,7 +710,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
710710
} else {
711711
// Insert resharding.
712712
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
713-
spmdizationMap.lookup(srcShardOp.getSrc()));
713+
spmdizationMap.lookup(srcShardOp));
714714
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
715715
symbolTableCollection);
716716
}

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,19 @@ func.func @ew_chain_with_halo(
219219
// CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
220220
return %sharding_annotated_6 : tensor<8x16xf32>
221221
}
222+
223+
// CHECK-LABEL: func @test_shard_update_halo
224+
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
225+
func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
226+
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
227+
// CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
228+
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
229+
// CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
230+
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
231+
%sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
232+
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
233+
%sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
234+
%sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
235+
// CHECK: return %[[UH]] : tensor<304x1200xi64>
236+
return %sharding_annotated_3 : tensor<1200x1200xi64>
237+
}

0 commit comments

Comments (0)