Skip to content

Commit 1a49810

Browse files
authored
[mlir][ArmSME][test] Prepare tests for tile allocation changes (#91358)
This patch: 1. Removes some duplicate test cases 2. Removes unnecessary uses of `-convert-arm-sme-to-llvm` 3. Ensures tile values have uses via `test.some_use()` 1 and 2 will make these tests easier to update. 3 will be needed as ArmSME operations will be pure.
1 parent 602df27 commit 1a49810

File tree

7 files changed

+234
-165
lines changed

7 files changed

+234
-165
lines changed

mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
2525
%c0 = arith.constant 0 : index
2626
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
2727
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
28+
"test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
2829
return
2930
}
3031

@@ -36,6 +37,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
3637
%c0 = arith.constant 0 : index
3738
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
3839
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
40+
"test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
3941
return
4042
}
4143

@@ -47,6 +49,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
4749
%c0 = arith.constant 0 : index
4850
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
4951
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
52+
"test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
5053
return
5154
}
5255

@@ -58,6 +61,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
5861
%c0 = arith.constant 0 : index
5962
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
6063
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
64+
"test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
6165
return
6266
}
6367

@@ -69,6 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
6973
%c0 = arith.constant 0 : index
7074
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
7175
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
76+
"test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
7277
return
7378
}
7479

@@ -80,6 +85,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
8085
%c0 = arith.constant 0 : index
8186
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
8287
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
88+
"test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
8389
return
8490
}
8591

@@ -91,6 +97,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
9197
%c0 = arith.constant 0 : index
9298
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
9399
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
100+
"test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
94101
return
95102
}
96103

@@ -102,6 +109,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
102109
%c0 = arith.constant 0 : index
103110
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
104111
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
112+
"test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
105113
return
106114
}
107115

@@ -113,6 +121,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
113121
%c0 = arith.constant 0 : index
114122
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
115123
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
124+
"test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
116125
return
117126
}
118127

@@ -124,6 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
124133
%c0 = arith.constant 0 : index
125134
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
126135
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
136+
"test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
127137
return
128138
}
129139

@@ -135,6 +145,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
135145
%c0 = arith.constant 0 : index
136146
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
137147
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
148+
"test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
138149
return
139150
}
140151

@@ -146,6 +157,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
146157
%c0 = arith.constant 0 : index
147158
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
148159
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
160+
"test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
149161
return
150162
}
151163

@@ -157,6 +169,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
157169
%c0 = arith.constant 0 : index
158170
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
159171
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
172+
"test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
160173
return
161174
}
162175

@@ -168,6 +181,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
168181
%c0 = arith.constant 0 : index
169182
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
170183
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
184+
"test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
171185
return
172186
}
173187

@@ -179,6 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
179193
%c0 = arith.constant 0 : index
180194
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
181195
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
196+
"test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
182197
return
183198
}
184199

@@ -190,6 +205,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
190205
%c0 = arith.constant 0 : index
191206
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
192207
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
208+
"test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
193209
return
194210
}
195211

@@ -201,6 +217,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
201217
%c0 = arith.constant 0 : index
202218
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
203219
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
220+
"test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
204221
return
205222
}
206223

@@ -212,6 +229,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
212229
%c0 = arith.constant 0 : index
213230
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
214231
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
232+
"test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
215233
return
216234
}
217235

@@ -441,7 +459,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
441459
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
442460
%c0 = arith.constant 0 : index
443461
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
444-
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
462+
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
463+
"test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
445464
return
446465
}
447466

@@ -452,7 +471,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
452471
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
453472
%c0 = arith.constant 0 : index
454473
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
455-
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
474+
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
475+
"test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
456476
return
457477
}
458478

mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
99
// expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
1010
// expected-error@+1 {{unsupported type}}
1111
%0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
12-
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
12+
"test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
1313
}
1414

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
2121
%c0 = arith.constant 0 : index
2222
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
23+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
2324
return
2425
}
2526

@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
3031
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
3132
%c0 = arith.constant 0 : index
3233
%tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
34+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
3335
return
3436
}
3537

@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
6062
%pad = arith.constant 0 : i32
6163
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
6264
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
65+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
6366
return
6467
}
6568

@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
9497
%c3 = arith.constant 3 : index
9598
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
9699
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
100+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
97101
return
98102
}
99103

@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
104108
%pad = arith.constant 0 : i32
105109
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
106110
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
111+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
107112
return
108113
}
109114

@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
113118
%c0 = arith.constant 0 : index
114119
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
115120
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
121+
"test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
116122
return
117123
}
118124

0 commit comments

Comments
 (0)