|
1 | 1 | // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
|
2 | 2 |
|
3 | 3 | #matvec_accesses = [
|
4 |
| - affine_map<(i, j) -> (i, j)>, |
5 |
| - affine_map<(i, j) -> (j)>, |
6 |
| - affine_map<(i, j) -> (i)> |
| 4 | + affine_map<(m, k) -> (m, k)>, |
| 5 | + affine_map<(m, k) -> (k)>, |
| 6 | + affine_map<(m, k) -> (m)> |
7 | 7 | ]
|
8 | 8 | #matvec_trait = {
|
9 | 9 | indexing_maps = #matvec_accesses,
|
|
16 | 16 | }
|
17 | 17 |
|
18 | 18 | #mattransvec_accesses = [
|
19 |
| - affine_map<(i, j) -> (j, i)>, |
20 |
| - affine_map<(i, j) -> (j)>, |
21 |
| - affine_map<(i, j) -> (i)> |
| 19 | + affine_map<(m, k) -> (k, m)>, |
| 20 | + affine_map<(m, k) -> (k)>, |
| 21 | + affine_map<(m, k) -> (m)> |
22 | 22 | ]
|
23 | 23 | #mattransvec_trait = {
|
24 | 24 | indexing_maps = #mattransvec_accesses,
|
25 | 25 | iterator_types = ["parallel", "reduction"]
|
26 | 26 | }
|
27 | 27 |
|
28 | 28 | #vecmat_accesses = [
|
29 |
| - affine_map<(i, j) -> (j)>, |
30 |
| - affine_map<(i, j) -> (i, j)>, |
31 |
| - affine_map<(i, j) -> (i)> |
| 29 | + affine_map<(m, k) -> (k)>, |
| 30 | + affine_map<(m, k) -> (m, k)>, |
| 31 | + affine_map<(m, k) -> (m)> |
32 | 32 | ]
|
33 | 33 | #vecmat_trait = {
|
34 | 34 | indexing_maps = #vecmat_accesses,
|
35 | 35 | iterator_types = ["parallel", "reduction"]
|
36 | 36 | }
|
37 | 37 |
|
38 | 38 | #vecmattrans_accesses = [
|
39 |
| - affine_map<(i, j) -> (j)>, |
40 |
| - affine_map<(i, j) -> (j, i)>, |
41 |
| - affine_map<(i, j) -> (i)> |
| 39 | + affine_map<(m, k) -> (k)>, |
| 40 | + affine_map<(m, k) -> (k, m)>, |
| 41 | + affine_map<(m, k) -> (m)> |
42 | 42 | ]
|
43 | 43 | #vecmattrans_trait = {
|
44 | 44 | indexing_maps = #vecmattrans_accesses,
|
45 | 45 | iterator_types = ["parallel", "reduction"]
|
46 | 46 | }
|
47 | 47 |
|
48 | 48 | #redpar_vecmattrans_accesses = [
|
49 |
| - affine_map<(i, j) -> (i)>, |
50 |
| - affine_map<(i, j) -> (i, j)>, |
51 |
| - affine_map<(i, j) -> (j)> |
| 49 | + affine_map<(m, k) -> (m)>, |
| 50 | + affine_map<(m, k) -> (m, k)>, |
| 51 | + affine_map<(m, k) -> (k)> |
52 | 52 | ]
|
53 | 53 | #redpar_vecmattrans_trait = {
|
54 | 54 | indexing_maps = #redpar_vecmattrans_accesses,
|
55 | 55 | iterator_types = ["reduction", "parallel"]
|
56 | 56 | }
|
57 | 57 |
|
58 |
| -// CHECK-LABEL: func @matvec2x2 |
59 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
60 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
61 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
62 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
63 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
64 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
65 |
| -// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
| 58 | +// CHECK-LABEL: func @matvec_mk_k_m |
| 59 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 60 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 61 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 62 | +// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
66 | 63 | // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
|
67 |
| -// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
68 |
| -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
| 64 | +// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 65 | +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
69 | 66 | // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
|
70 |
| -// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 67 | +// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
71 | 68 | // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
|
72 |
| -// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>> |
73 |
| -// CHECK: return |
74 |
| -func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
75 |
| - %arg2: memref<vector<2xf32>>) { |
76 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
77 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
78 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 69 | +func.func @matvec_mk_k_m(%A: vector<2x2xf32>, |
| 70 | + %x: vector<2xf32>, |
| 71 | + %b: vector<2xf32>) -> vector<2xf32> { |
79 | 72 | %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
|
80 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
81 |
| - return |
| 73 | + return %0 : vector<2xf32> |
82 | 74 | }
|
83 | 75 |
|
84 |
| -// CHECK-LABEL: func @matvecmax2x2 |
85 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
86 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
87 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
88 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
89 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
90 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
91 |
| -// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
| 76 | +// CHECK-LABEL: func @matvec_mk_k_m_max |
| 77 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 78 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 79 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 80 | +// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
92 | 81 | // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
|
93 |
| -// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
94 |
| -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32 |
| 82 | +// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 83 | +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32 |
95 | 84 | // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
|
96 |
| -// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 85 | +// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
97 | 86 | // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
|
98 |
| -// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>> |
99 |
| -// CHECK: return |
100 |
| -func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
101 |
| - %arg2: memref<vector<2xf32>>) { |
102 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
103 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
104 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 87 | +func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>, |
| 88 | + %x: vector<2xf32>, |
| 89 | + %b: vector<2xf32>) -> vector<2xf32> { |
105 | 90 | %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
|
106 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
107 |
| - return |
| 91 | + return %0 : vector<2xf32> |
108 | 92 | }
|
109 | 93 |
|
110 |
| -// CHECK-LABEL: func @mattransvec2x2 |
111 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
112 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
113 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
114 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
115 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
116 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
117 |
| -// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32> |
118 |
| -// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
119 |
| -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
120 |
| -// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32> |
121 |
| -// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 94 | +// CHECK-LABEL: func @matvec_km_k_m |
| 95 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 96 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 97 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 98 | +// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> |
| 99 | +// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 100 | +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
| 101 | +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> |
| 102 | +// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
122 | 103 | // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
|
123 |
| -// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>> |
124 |
| -// CHECK: return |
125 |
| -func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
126 |
| - %arg2: memref<vector<2xf32>>) { |
127 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
128 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
129 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 104 | +func.func @matvec_km_k_m(%A: vector<2x2xf32>, |
| 105 | + %x: vector<2xf32>, |
| 106 | + %b: vector<2xf32>) -> vector<2xf32> { |
130 | 107 | %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
|
131 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
132 |
| - return |
| 108 | + return %0 : vector<2xf32> |
133 | 109 | }
|
134 | 110 |
|
135 |
| -// CHECK-LABEL: func @vecmat2x2 |
136 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
137 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
138 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
139 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
140 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
141 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
142 |
| -// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
| 111 | +// CHECK-LABEL: func @matvec_k_mk_m |
| 112 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 113 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 114 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 115 | +// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> |
143 | 116 | // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
|
144 |
| -// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
145 |
| -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
| 117 | +// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 118 | +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
146 | 119 | // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
|
147 |
| -// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 120 | +// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
148 | 121 | // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
|
149 |
| -// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>> |
150 |
| -// CHECK: return |
151 |
| -func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
152 |
| - %arg2: memref<vector<2xf32>>) { |
153 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
154 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
155 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 122 | +func.func @matvec_k_mk_m(%A: vector<2x2xf32>, |
| 123 | + %x: vector<2xf32>, |
| 124 | + %b: vector<2xf32>) -> vector<2xf32> { |
156 | 125 | %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
|
157 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
158 |
| - return |
| 126 | + return %0 : vector<2xf32> |
159 | 127 | }
|
160 | 128 |
|
161 |
| -// CHECK-LABEL: func @vecmattrans2x2 |
162 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
163 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
164 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
165 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
166 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
167 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
168 |
| -// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32> |
169 |
| -// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
170 |
| -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
171 |
| -// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32> |
172 |
| -// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 129 | +// CHECK-LABEL: func @matvec_k_km_m |
| 130 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 131 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 132 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 133 | +// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> |
| 134 | +// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 135 | +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
| 136 | +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> |
| 137 | +// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
173 | 138 | // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
|
174 |
| -// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>> |
175 |
| -// CHECK: return |
176 |
| -func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
177 |
| - %arg2: memref<vector<2xf32>>) { |
178 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
179 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
180 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 139 | +func.func @matvec_k_km_m(%A: vector<2x2xf32>, |
| 140 | + %x: vector<2xf32>, |
| 141 | + %b: vector<2xf32>) -> vector<2xf32> { |
181 | 142 | %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
|
182 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
183 |
| - return |
| 143 | + return %0 : vector<2xf32> |
184 | 144 | }
|
185 | 145 |
|
186 |
| -// CHECK-LABEL: func @redpar_vecmattrans2x2 |
187 |
| -// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>> |
188 |
| -// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>> |
189 |
| -// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>> |
190 |
| -// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>> |
191 |
| -// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>> |
192 |
| -// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>> |
193 |
| -// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32> |
194 |
| -// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32> |
195 |
| -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
196 |
| -// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32> |
197 |
| -// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32> |
| 146 | +// CHECK-LABEL: func @matvec_m_mk_k |
| 147 | +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> |
| 148 | +// CHECK-SAME: %[[B:.*1]]: vector<2xf32> |
| 149 | +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> |
| 150 | +// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> |
| 151 | +// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> |
| 152 | +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 |
| 153 | +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> |
| 154 | +// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> |
198 | 155 | // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
|
199 |
| -// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>> |
200 |
| -// CHECK: return |
201 |
| -func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>, |
202 |
| - %arg2: memref<vector<2xf32>>) { |
203 |
| - %A = memref.load %arg0[] : memref<vector<2x2xf32>> |
204 |
| - %x = memref.load %arg1[] : memref<vector<2xf32>> |
205 |
| - %b = memref.load %arg2[] : memref<vector<2xf32>> |
| 156 | +func.func @matvec_m_mk_k(%A: vector<2x2xf32>, |
| 157 | + %x: vector<2xf32>, |
| 158 | + %b: vector<2xf32>) -> vector<2xf32> { |
206 | 159 | %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
|
207 |
| - memref.store %0, %arg2[] : memref<vector<2xf32>> |
208 |
| - return |
| 160 | + return %0 : vector<2xf32> |
209 | 161 | }
|
210 | 162 |
|
211 | 163 | module attributes {transform.with_named_sequence} {
|
|
0 commit comments