1
+ // REQUIRES: arm-emulator
2
+
3
+ // DEFINE: %{compile} = mlir-opt %s \
4
+ // DEFINE: --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve' \
5
+ // DEFINE: --expand-strided-metadata --lower-affine --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
6
+ // DEFINE: -o %t
7
+
8
+ // DEFINE: %{entry_point} = main
9
+
10
+ // DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve" \
11
+ // DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
12
+
13
+ // RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
14
+
15
// Test a vector.transfer_read whose result vector type has a non-trailing
// scalable dimension, as transformed by the LegalizeTransferRead pattern.
17
+
18
// Reads a vector<[4]x8xi8> (scalable leading dimension, fixed trailing
// dimension) from position [0, 0] of %M and prints its contents.
//   %vs : value forwarded to @setArmVLBits before the read; the callers in
//         this file pass 128 and 256.
//   %M  : source buffer with a dynamic number of 8-element i8 rows.
func.func @transfer_read_scalable_non_trailing(%vs : i32, %M : memref<?x8xi8>) {
  // Done first so that the read below runs after the call.
  // NOTE(review): presumably this selects the SVE vector length in bits --
  // confirm against the runner utils implementation.
  func.call @setArmVLBits(%vs) : (i32) -> ()

  // Read an LLVM-illegal vector. The read starts at the origin of the buffer
  // and is marked in-bounds in both dimensions; zero is the padding value.
  %pad = arith.constant 0 : i8
  %origin = arith.constant 0 : index
  %tile = vector.transfer_read %M[%origin, %origin], %pad {in_bounds = [true, true]} : memref<?x8xi8>, vector<[4]x8xi8>

  // Flatten to the 1-D type accepted by @printVec and print it, for
  // verification.
  %flat = vector.shape_cast %tile : vector<[4]x8xi8> to vector<[32]xi8>
  func.call @printVec(%flat) : (vector<[32]xi8>) -> ()

  return
}
32
+
33
// Test driver: fills an 8x8 i8 buffer with known values, then calls
// @transfer_read_scalable_non_trailing twice on it, once with 128 and once
// with 256 as the vector-length argument. The printed output is verified by
// the FileCheck directives below.
func.func @main() {
  %c0 = arith.constant 0 : index
  %vl128 = arith.constant 128 : i32
  %vl256 = arith.constant 256 : i32

  // Prepare an 8x8 buffer with test data. Each element encodes its position
  // as (row + 1) * 10 + (column + 1). The test performs two reads of a [4]x8
  // vector from this buffer: the read with a vector length of 128 bits covers
  // the first half of the buffer, the read with a vector length of 256 bits
  // covers the entire buffer.
  %buf = memref.alloca() : memref<8x8xi8>
  %data = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
                                [21, 22, 23, 24, 25, 26, 27, 28],
                                [31, 32, 33, 34, 35, 36, 37, 38],
                                [41, 42, 43, 44, 45, 46, 47, 48],
                                [51, 52, 53, 54, 55, 56, 57, 58],
                                [61, 62, 63, 64, 65, 66, 67, 68],
                                [71, 72, 73, 74, 75, 76, 77, 78],
                                [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<8x8xi8>
  vector.transfer_write %data, %buf[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>

  // The callee takes a buffer with a dynamic number of rows.
  %dynBuf = memref.cast %buf : memref<8x8xi8> to memref<?x8xi8>

  // NOTE(review): the string literals passed to vector.print below begin and
  // end with a space; they are kept exactly as found -- confirm whether the
  // spaces are intended.

  // First read: rows 0-3, printed as two lines of 16 elements.
  // CHECK-LABEL: Result(VL128):
  // CHECK:( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
  // CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
  vector.print str " Result(VL128):\n "
  func.call @transfer_read_scalable_non_trailing(%vl128, %dynBuf) : (i32, memref<?x8xi8>) -> ()

  // Second read: rows 0-7, printed as two lines of 32 elements.
  // CHECK-LABEL: Result(VL256):
  // CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
  // CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
  vector.print str " Result(VL256):\n "
  func.call @transfer_read_scalable_non_trailing(%vl256, %dynBuf) : (i32, memref<?x8xi8>) -> ()

  return
}
70
+
71
// Prints a vector<[32]xi8> as two vector<[16]xi8> halves, each with its own
// vector.print: first the half extracted at offset 0, then the half extracted
// at offset 16.
func.func private @printVec(%v : vector<[32]xi8>) {
  %lo = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
  vector.print %lo : vector<[16]xi8>

  %hi = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
  vector.print %hi : vector<[16]xi8>

  return
}
78
+
79
// External declaration; no body is provided in this file.
// NOTE(review): presumably resolved from one of the shared libraries named in
// the run command and used to set the SVE vector length in bits -- confirm.
func.func private @setArmVLBits(%bits : i32)
0 commit comments