Skip to content

Commit f0ce235

Browse files
authored
[mlir][ArmSME][NFC] Move conversion tests (llvm#75446)
* Move -vector-to-arm-sme tests to mlir/test/Conversion/VectorToArmSME * Move -arm-sme-to-llvm tests to mlir/test/Conversion/ArmSMEToLLVM * Separate unsupported tests.
1 parent 0e06694 commit f0ce235

File tree

5 files changed

+190
-175
lines changed

5 files changed

+190
-175
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
2+
3+
//===----------------------------------------------------------------------===//
4+
// arm_sme.outerproduct
5+
//===----------------------------------------------------------------------===//
6+
7+
func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
8+
%acc = arm_sme.get_tile : vector<[16]x[16]xi8>
9+
// expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
10+
// expected-error@+1 {{unsupported type}}
11+
%0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
12+
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
13+
}
14+
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// vector.transfer_read
5+
//===----------------------------------------------------------------------===//
6+
7+
// CHECK-LABEL: @transfer_read_2d__bad_type
8+
// CHECK-NOT: arm_sme.tile_load
9+
// CHECK: vector.transfer_read
10+
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
11+
%c0 = arith.constant 0 : index
12+
%pad = arith.constant 0.0 : f64
13+
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
14+
"prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
15+
return
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: @transfer_read_2d__non_memref_type
21+
// CHECK-NOT: arm_sme.tile_load
22+
// CHECK: vector.transfer_read
23+
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
24+
%c0 = arith.constant 0 : index
25+
%pad = arith.constant 0.0 : f64
26+
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
27+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
28+
return
29+
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
34+
// CHECK-NOT: arm_sme.tile_load
35+
// CHECK: vector.transfer_read
36+
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
37+
%c0 = arith.constant 0 : index
38+
%pad = arith.constant 0.0 : f64
39+
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
40+
"prevent.dce"(%0) : (vector<[2]xf64>) -> ()
41+
return
42+
}
43+
44+
// -----
45+
46+
// CHECK-LABEL: @transfer_read_2d__non_transpose
47+
// CHECK-NOT: arm_sme.tile_load
48+
// CHECK: vector.transfer_read
49+
func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
50+
%c0 = arith.constant 0 : index
51+
%pad = arith.constant 0.0 : f64
52+
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
53+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
54+
return
55+
}
56+
57+
// -----
58+
59+
// CHECK-LABEL: @transfer_read_2d__out_of_bounds
60+
// CHECK-NOT: arm_sme.tile_load
61+
// CHECK: vector.transfer_read
62+
func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
63+
%c0 = arith.constant 0 : index
64+
%pad = arith.constant 0.0 : f64
65+
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
66+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
67+
return
68+
}
69+
70+
//===----------------------------------------------------------------------===//
71+
// vector.transfer_write
72+
//===----------------------------------------------------------------------===//
73+
74+
// -----
75+
76+
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
77+
// lowering only occurs for vector types of correct rank, shape, element size
78+
// and number of scalable dims.
79+
80+
// CHECK-LABEL: @transfer_write_2d_zero__bad_type
81+
// CHECK: vector.transfer_write
82+
// CHECK-NOT: arm_sme.intr.zero
83+
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
84+
%c0 = arith.constant 0 : index
85+
%cst = arith.constant dense<0> : vector<[16]x[16]xi4>
86+
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
87+
return
88+
}
89+
90+
// -----
91+
92+
// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
93+
// CHECK: vector.transfer_write
94+
// CHECK-NOT: arm_sme.tile_store
95+
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
96+
%c0 = arith.constant 0 : index
97+
%cst = arith.constant dense<0> : vector<[8]x[8]xi8>
98+
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
99+
return
100+
}
101+
102+
// -----
103+
104+
// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
105+
// CHECK: vector.transfer_write
106+
// CHECK-NOT: arm_sme.tile_store
107+
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
108+
%c0 = arith.constant 0 : index
109+
%cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
110+
vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
111+
return
112+
}
113+
114+
// -----
115+
116+
// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
117+
// CHECK: vector.transfer_write
118+
// CHECK-NOT: arm_sme.tile_store
119+
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
120+
%c0 = arith.constant 0 : index
121+
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
122+
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
123+
return %0 : tensor<?x?xi8>
124+
}
125+
126+
// -----
127+
128+
// CHECK-LABEL: @transfer_write_2d__fixed
129+
// CHECK: vector.transfer_write
130+
// CHECK-NOT: arm_sme.tile_store
131+
func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
132+
%c0 = arith.constant 0 : index
133+
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
134+
return
135+
}
136+
137+
// -----
138+
139+
// CHECK-LABEL: @transfer_write_2d__out_of_bounds
140+
// CHECK: vector.transfer_write
141+
// CHECK-NOT: arm_sme.tile_store
142+
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
143+
%c0 = arith.constant 0 : index
144+
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
145+
return
146+
}
147+
148+
//===----------------------------------------------------------------------===//
149+
// vector.outerproduct
150+
//===----------------------------------------------------------------------===//
151+
152+
// -----
153+
154+
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
155+
// expected-error@+1 {{AXPY operations not supported}}
156+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
157+
return %0 : vector<[2]xf64>
158+
}
159+
160+
// -----
161+
162+
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
163+
%acc = arm_sme.get_tile : vector<[2]x[2]xf64>
164+
// expected-error@+1 {{unsupported kind}}
165+
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
166+
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
167+
}
168+
169+
// -----
170+
171+
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
172+
// CHECK: vector.outerproduct
173+
%acc = arm_sme.get_tile : vector<[4]x[4]xf32>
174+
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
175+
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
176+
}

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir renamed to mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 0 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -150,71 +150,6 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
150150

151151
// -----
152152

153-
// CHECK-LABEL: @transfer_read_2d__bad_type
154-
// CHECK-NOT: arm_sme.tile_load
155-
// CHECK: vector.transfer_read
156-
func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
157-
%c0 = arith.constant 0 : index
158-
%pad = arith.constant 0.0 : f64
159-
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
160-
"prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
161-
return
162-
}
163-
164-
// -----
165-
166-
// CHECK-LABEL: @transfer_read_2d__non_memref_type
167-
// CHECK-NOT: arm_sme.tile_load
168-
// CHECK: vector.transfer_read
169-
func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
170-
%c0 = arith.constant 0 : index
171-
%pad = arith.constant 0.0 : f64
172-
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
173-
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
174-
return
175-
}
176-
177-
// -----
178-
179-
// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
180-
// CHECK-NOT: arm_sme.tile_load
181-
// CHECK: vector.transfer_read
182-
func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
183-
%c0 = arith.constant 0 : index
184-
%pad = arith.constant 0.0 : f64
185-
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
186-
"prevent.dce"(%0) : (vector<[2]xf64>) -> ()
187-
return
188-
}
189-
190-
// -----
191-
192-
// CHECK-LABEL: @transfer_read_2d__non_transpose
193-
// CHECK-NOT: arm_sme.tile_load
194-
// CHECK: vector.transfer_read
195-
func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
196-
%c0 = arith.constant 0 : index
197-
%pad = arith.constant 0.0 : f64
198-
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
199-
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
200-
return
201-
}
202-
203-
// -----
204-
205-
// CHECK-LABEL: @transfer_read_2d__out_of_bounds
206-
// CHECK-NOT: arm_sme.tile_load
207-
// CHECK: vector.transfer_read
208-
func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
209-
%c0 = arith.constant 0 : index
210-
%pad = arith.constant 0.0 : f64
211-
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
212-
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
213-
return
214-
}
215-
216-
// -----
217-
218153
//===----------------------------------------------------------------------===//
219154
// vector.transfer_write
220155
//===----------------------------------------------------------------------===//
@@ -366,80 +301,6 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
366301
return
367302
}
368303

369-
// -----
370-
371-
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
372-
// lowering only occurs for vector types of correct rank, shape, element size
373-
// and number of scalable dims.
374-
375-
// CHECK-LABEL: @transfer_write_2d_zero__bad_type
376-
// CHECK: vector.transfer_write
377-
// CHECK-NOT: arm_sme.intr.zero
378-
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
379-
%c0 = arith.constant 0 : index
380-
%cst = arith.constant dense<0> : vector<[16]x[16]xi4>
381-
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
382-
return
383-
}
384-
385-
// -----
386-
387-
// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
388-
// CHECK: vector.transfer_write
389-
// CHECK-NOT: arm_sme.tile_store
390-
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
391-
%c0 = arith.constant 0 : index
392-
%cst = arith.constant dense<0> : vector<[8]x[8]xi8>
393-
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
394-
return
395-
}
396-
397-
// -----
398-
399-
// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
400-
// CHECK: vector.transfer_write
401-
// CHECK-NOT: arm_sme.tile_store
402-
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
403-
%c0 = arith.constant 0 : index
404-
%cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
405-
vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
406-
return
407-
}
408-
409-
// -----
410-
411-
// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
412-
// CHECK: vector.transfer_write
413-
// CHECK-NOT: arm_sme.tile_store
414-
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
415-
%c0 = arith.constant 0 : index
416-
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
417-
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
418-
return %0 : tensor<?x?xi8>
419-
}
420-
421-
// -----
422-
423-
// CHECK-LABEL: @transfer_write_2d__fixed
424-
// CHECK: vector.transfer_write
425-
// CHECK-NOT: arm_sme.tile_store
426-
func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
427-
%c0 = arith.constant 0 : index
428-
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
429-
return
430-
}
431-
432-
// -----
433-
434-
// CHECK-LABEL: @transfer_write_2d__out_of_bounds
435-
// CHECK: vector.transfer_write
436-
// CHECK-NOT: arm_sme.tile_store
437-
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
438-
%c0 = arith.constant 0 : index
439-
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
440-
return
441-
}
442-
443304
//===----------------------------------------------------------------------===//
444305
// vector.broadcast
445306
//===----------------------------------------------------------------------===//

mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -469,42 +469,6 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
469469
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
470470
}
471471

472-
// -----
473-
474-
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
475-
// expected-error@+1 {{AXPY operations not supported}}
476-
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
477-
return %0 : vector<[2]xf64>
478-
}
479-
480-
// -----
481-
482-
func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
483-
%acc = arm_sme.get_tile : vector<[16]x[16]xi8>
484-
// expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
485-
// expected-error@+1 {{unsupported type}}
486-
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
487-
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
488-
}
489-
490-
// -----
491-
492-
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
493-
%acc = arm_sme.get_tile : vector<[2]x[2]xf64>
494-
// expected-error@+1 {{unsupported kind}}
495-
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
496-
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
497-
}
498-
499-
// -----
500-
501-
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
502-
// CHECK: vector.outerproduct
503-
%acc = arm_sme.get_tile : vector<[4]x[4]xf32>
504-
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
505-
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
506-
}
507-
508472
//===----------------------------------------------------------------------===//
509473
// vector.insert
510474
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)