@@ -230,8 +230,23 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
230
230
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor <1200 x1200 xi64 >
231
231
%sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0 ]] halo_sizes = [2 , 2 ] : !mesh.sharding
232
232
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor <1200 x1200 xi64 >
233
- %sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0 ]] sharded_dims_offsets = [298 , 598 , 898 , 1000 ] : !mesh.sharding
234
233
%sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor <1200 x1200 xi64 >
235
234
// CHECK: return %[[UH]] : tensor<304x1200xi64>
236
235
return %sharding_annotated_3 : tensor <1200 x1200 xi64 >
237
236
}
237
+
238
// 2-D halo test case. A 1200x1200 tensor is sharded over the 4x4 mesh @mesh4x4
// with split_axes = [[0], [1]], so the FileCheck directives below expect a
// 300x300 per-shard input. Re-annotating with halo_sizes = [1, 2, 3, 4] is
// expected to yield a 303x307 buffer (300+1+2 by 300+3+4): the 300x300 core is
// inserted at offset [1, 3] and a mesh.update_halo op with
// destination_halo_sizes = [1, 2, 3, 4] produces the returned value.
// NOTE(review): the pass/pipeline driving this test is not visible in this
// hunk -- confirm against the file's RUN line.
+ mesh.mesh @mesh4x4 (shape = 4 x4 )
239
+ // CHECK-LABEL: func @test_shard_update_halo2d
240
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
241
+ func.func @test_shard_update_halo2d (%arg0: tensor <1200 x1200 xi64 >) -> tensor <1200 x1200 xi64 > {
242
+ %sharding = mesh.sharding @mesh4x4 split_axes = [[0 ], [1 ]] : !mesh.sharding
243
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
244
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
245
+ // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
246
+ %sharding_annotated = mesh.shard %arg0 to %sharding : tensor <1200 x1200 xi64 >
247
// Same mesh and split axes as %sharding, now with halos; the expected shapes
// above are consistent with [1, 2] applying to dim 0 and [3, 4] to dim 1.
+ %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0 ], [1 ]] halo_sizes = [1 , 2 , 3 , 4 ] : !mesh.sharding
248
+ %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor <1200 x1200 xi64 >
249
+ %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor <1200 x1200 xi64 >
250
+ // CHECK: return %[[UH]] : tensor<303x307xi64>
251
+ return %sharding_annotated_3 : tensor <1200 x1200 xi64 >
252
+ }
0 commit comments