@@ -115,34 +115,30 @@ func.func @update_halo_3d(
115
115
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
116
116
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
117
117
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
118
- // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
119
118
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
120
119
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
121
- // CHECK-NEXT: mpi.send([[vcast ]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8 >, i32, i32
122
- // CHECK-NEXT: mpi.recv([[vcast ]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8 >, i32, i32
120
+ // CHECK-NEXT: mpi.send([[valloc ]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8 >, i32, i32
121
+ // CHECK-NEXT: mpi.recv([[valloc ]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8 >, i32, i32
123
122
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
124
123
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
125
124
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
126
125
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
127
- // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
128
126
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
129
127
// CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
130
- // CHECK-NEXT: mpi.send([[vcast_2 ]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8 >, i32, i32
131
- // CHECK-NEXT: mpi.recv([[vcast_2 ]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8 >, i32, i32
128
+ // CHECK-NEXT: mpi.send([[valloc_1 ]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8 >, i32, i32
129
+ // CHECK-NEXT: mpi.recv([[valloc_1 ]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8 >, i32, i32
132
130
// CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
133
131
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
134
132
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
135
133
// CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
136
- // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
137
- // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
134
+ // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
138
135
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
139
136
// CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
140
137
// CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
141
138
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
142
- // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
143
139
// CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
144
140
// CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
145
- // CHECK-NEXT: mpi.send([[vcast_9 ]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8 >, i32, i32
141
+ // CHECK-NEXT: mpi.send([[valloc_8 ]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8 >, i32, i32
146
142
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
147
143
// CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
148
144
// CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -170,34 +166,30 @@ func.func @update_halo_3d_tensor(
170
166
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
171
167
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
172
168
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
173
- // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
174
169
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
175
170
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
176
- // CHECK-NEXT: mpi.send([[vcast ]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8 >, i32, i32
177
- // CHECK-NEXT: mpi.recv([[vcast ]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8 >, i32, i32
171
+ // CHECK-NEXT: mpi.send([[valloc ]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8 >, i32, i32
172
+ // CHECK-NEXT: mpi.recv([[valloc ]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8 >, i32, i32
178
173
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
179
174
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
180
175
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
181
176
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
182
- // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
183
177
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
184
178
// CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
185
- // CHECK-NEXT: mpi.send([[vcast_2 ]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8 >, i32, i32
186
- // CHECK-NEXT: mpi.recv([[vcast_2 ]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8 >, i32, i32
179
+ // CHECK-NEXT: mpi.send([[valloc_1 ]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8 >, i32, i32
180
+ // CHECK-NEXT: mpi.recv([[valloc_1 ]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8 >, i32, i32
187
181
// CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
188
182
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
189
183
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
190
184
// CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
191
- // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
192
- // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
185
+ // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
193
186
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
194
187
// CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
195
188
// CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
196
189
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
197
- // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
198
190
// CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
199
191
// CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
200
- // CHECK-NEXT: mpi.send([[vcast_9 ]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8 >, i32, i32
192
+ // CHECK-NEXT: mpi.send([[valloc_8 ]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8 >, i32, i32
201
193
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
202
194
// CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
203
195
// CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -209,7 +201,7 @@ func.func @update_halo_3d_tensor(
209
201
// CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
210
202
// CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
211
203
// CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
212
- // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
204
+ // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
213
205
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2 ], [1 ], [0 ]] halo_sizes = [1 , 2 , 3 , 4 , 5 , 6 ] : tensor <120 x120 x120 xi8 >
214
206
// CHECK: return [[v1]] : tensor<120x120x120xi8>
215
207
return %res : tensor <120 x120 x120 xi8 >
0 commit comments