Skip to content

Commit a0aa3eb

Browse files
committed
adding test update_halo 2d
1 parent 1cce571 commit a0aa3eb

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,23 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
230230
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
231231
%sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
232232
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
233-
%sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
234233
%sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
235234
// CHECK: return %[[UH]] : tensor<304x1200xi64>
236235
return %sharding_annotated_3 : tensor<1200x1200xi64>
237236
}
237+
238+
mesh.mesh @mesh4x4(shape = 4x4)
239+
// CHECK-LABEL: func @test_shard_update_halo2d
240+
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
241+
func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
242+
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
243+
// CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
244+
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
245+
// CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
246+
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
247+
%sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
248+
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
249+
%sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
250+
// CHECK: return %[[UH]] : tensor<303x307xi64>
251+
return %sharding_annotated_3 : tensor<1200x1200xi64>
252+
}

0 commit comments

Comments
 (0)