Skip to content

Commit 730e0d0

Browse files
authored
[mlir][vector][nfc] Refactor vector.contract matvec tests (#72832)
Update tests in "vector-contract-matvec-transforms.mlir" so that they are consistent with similar tests in: * "vector-contract-to-outerproduct-transforms.mlir". This is to enable further refactoring in a follow-up patch, namely to: * remove duplication (this will be much easier once consistent naming is used), * extend tests in "vector-contract-matvec-transforms.mlir" with cases for scalable vectors, * merge "vector-contract-matvec-transforms.mlir" and "vector-contract-to-outerproduct-transforms.mlir" (there's no need for 2 different files testing identical transformations). Overview of changes in this patch: 1. Simplify the test by removing MemRef wrappers - this test verifies Vector -> Vector transformations and MemRefs are not needed. 2. Use (m, k) indices instead of (i, j). 3. Rename function names. This is part of a larger effort to improve test coverage for scalable vectors in the Vector dialect. Implements #72834.
1 parent 7a6fd49 commit 730e0d0

File tree

1 file changed

+90
-138
lines changed

1 file changed

+90
-138
lines changed

mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir

Lines changed: 90 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

33
#matvec_accesses = [
4-
affine_map<(i, j) -> (i, j)>,
5-
affine_map<(i, j) -> (j)>,
6-
affine_map<(i, j) -> (i)>
4+
affine_map<(m, k) -> (m, k)>,
5+
affine_map<(m, k) -> (k)>,
6+
affine_map<(m, k) -> (m)>
77
]
88
#matvec_trait = {
99
indexing_maps = #matvec_accesses,
@@ -16,196 +16,148 @@
1616
}
1717

1818
#mattransvec_accesses = [
19-
affine_map<(i, j) -> (j, i)>,
20-
affine_map<(i, j) -> (j)>,
21-
affine_map<(i, j) -> (i)>
19+
affine_map<(m, k) -> (k, m)>,
20+
affine_map<(m, k) -> (k)>,
21+
affine_map<(m, k) -> (m)>
2222
]
2323
#mattransvec_trait = {
2424
indexing_maps = #mattransvec_accesses,
2525
iterator_types = ["parallel", "reduction"]
2626
}
2727

2828
#vecmat_accesses = [
29-
affine_map<(i, j) -> (j)>,
30-
affine_map<(i, j) -> (i, j)>,
31-
affine_map<(i, j) -> (i)>
29+
affine_map<(m, k) -> (k)>,
30+
affine_map<(m, k) -> (m, k)>,
31+
affine_map<(m, k) -> (m)>
3232
]
3333
#vecmat_trait = {
3434
indexing_maps = #vecmat_accesses,
3535
iterator_types = ["parallel", "reduction"]
3636
}
3737

3838
#vecmattrans_accesses = [
39-
affine_map<(i, j) -> (j)>,
40-
affine_map<(i, j) -> (j, i)>,
41-
affine_map<(i, j) -> (i)>
39+
affine_map<(m, k) -> (k)>,
40+
affine_map<(m, k) -> (k, m)>,
41+
affine_map<(m, k) -> (m)>
4242
]
4343
#vecmattrans_trait = {
4444
indexing_maps = #vecmattrans_accesses,
4545
iterator_types = ["parallel", "reduction"]
4646
}
4747

4848
#redpar_vecmattrans_accesses = [
49-
affine_map<(i, j) -> (i)>,
50-
affine_map<(i, j) -> (i, j)>,
51-
affine_map<(i, j) -> (j)>
49+
affine_map<(m, k) -> (m)>,
50+
affine_map<(m, k) -> (m, k)>,
51+
affine_map<(m, k) -> (k)>
5252
]
5353
#redpar_vecmattrans_trait = {
5454
indexing_maps = #redpar_vecmattrans_accesses,
5555
iterator_types = ["reduction", "parallel"]
5656
}
5757

58-
// CHECK-LABEL: func @matvec2x2
59-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
60-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
61-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
62-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
63-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
64-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
65-
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
58+
// CHECK-LABEL: func @matvec_mk_k_m
59+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
60+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
61+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
62+
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
6663
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
67-
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
68-
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
64+
// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
65+
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
6966
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
70-
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
67+
// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
7168
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
72-
// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
73-
// CHECK: return
74-
func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
75-
%arg2: memref<vector<2xf32>>) {
76-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
77-
%x = memref.load %arg1[] : memref<vector<2xf32>>
78-
%b = memref.load %arg2[] : memref<vector<2xf32>>
69+
func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
70+
%x: vector<2xf32>,
71+
%b: vector<2xf32>) -> vector<2xf32> {
7972
%0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
80-
memref.store %0, %arg2[] : memref<vector<2xf32>>
81-
return
73+
return %0 : vector<2xf32>
8274
}
8375

84-
// CHECK-LABEL: func @matvecmax2x2
85-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
86-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
87-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
88-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
89-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
90-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
91-
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
76+
// CHECK-LABEL: func @matvec_mk_k_m_max
77+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
78+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
79+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
80+
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
9281
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
93-
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
94-
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
82+
// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
83+
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
9584
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
96-
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
85+
// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
9786
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
98-
// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
99-
// CHECK: return
100-
func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
101-
%arg2: memref<vector<2xf32>>) {
102-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
103-
%x = memref.load %arg1[] : memref<vector<2xf32>>
104-
%b = memref.load %arg2[] : memref<vector<2xf32>>
87+
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
88+
%x: vector<2xf32>,
89+
%b: vector<2xf32>) -> vector<2xf32> {
10590
%0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
106-
memref.store %0, %arg2[] : memref<vector<2xf32>>
107-
return
91+
return %0 : vector<2xf32>
10892
}
10993

110-
// CHECK-LABEL: func @mattransvec2x2
111-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
112-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
113-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
114-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
115-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
116-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
117-
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
118-
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
119-
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
120-
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
121-
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
94+
// CHECK-LABEL: func @matvec_km_k_m
95+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
96+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
97+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
98+
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
99+
// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
100+
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
101+
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
102+
// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
122103
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
123-
// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
124-
// CHECK: return
125-
func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
126-
%arg2: memref<vector<2xf32>>) {
127-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
128-
%x = memref.load %arg1[] : memref<vector<2xf32>>
129-
%b = memref.load %arg2[] : memref<vector<2xf32>>
104+
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
105+
%x: vector<2xf32>,
106+
%b: vector<2xf32>) -> vector<2xf32> {
130107
%0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
131-
memref.store %0, %arg2[] : memref<vector<2xf32>>
132-
return
108+
return %0 : vector<2xf32>
133109
}
134110

135-
// CHECK-LABEL: func @vecmat2x2
136-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
137-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
138-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
139-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
140-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
141-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
142-
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
111+
// CHECK-LABEL: func @matvec_k_mk_m
112+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
113+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
114+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
115+
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
143116
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
144-
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
145-
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
117+
// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
118+
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
146119
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
147-
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
120+
// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
148121
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
149-
// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
150-
// CHECK: return
151-
func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
152-
%arg2: memref<vector<2xf32>>) {
153-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
154-
%x = memref.load %arg1[] : memref<vector<2xf32>>
155-
%b = memref.load %arg2[] : memref<vector<2xf32>>
122+
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
123+
%x: vector<2xf32>,
124+
%b: vector<2xf32>) -> vector<2xf32> {
156125
%0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
157-
memref.store %0, %arg2[] : memref<vector<2xf32>>
158-
return
126+
return %0 : vector<2xf32>
159127
}
160128

161-
// CHECK-LABEL: func @vecmattrans2x2
162-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
163-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
164-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
165-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
166-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
167-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
168-
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
169-
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
170-
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
171-
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
172-
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
129+
// CHECK-LABEL: func @matvec_k_km_m
130+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
131+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
132+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
133+
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
134+
// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
135+
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
136+
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
137+
// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
173138
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
174-
// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
175-
// CHECK: return
176-
func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
177-
%arg2: memref<vector<2xf32>>) {
178-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
179-
%x = memref.load %arg1[] : memref<vector<2xf32>>
180-
%b = memref.load %arg2[] : memref<vector<2xf32>>
139+
func.func @matvec_k_km_m(%A: vector<2x2xf32>,
140+
%x: vector<2xf32>,
141+
%b: vector<2xf32>) -> vector<2xf32> {
181142
%0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
182-
memref.store %0, %arg2[] : memref<vector<2xf32>>
183-
return
143+
return %0 : vector<2xf32>
184144
}
185145

186-
// CHECK-LABEL: func @redpar_vecmattrans2x2
187-
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
188-
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
189-
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
190-
// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
191-
// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
192-
// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
193-
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
194-
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
195-
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
196-
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
197-
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
146+
// CHECK-LABEL: func @matvec_m_mk_k
147+
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
148+
// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
149+
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
150+
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
151+
// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
152+
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
153+
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
154+
// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
198155
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
199-
// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
200-
// CHECK: return
201-
func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
202-
%arg2: memref<vector<2xf32>>) {
203-
%A = memref.load %arg0[] : memref<vector<2x2xf32>>
204-
%x = memref.load %arg1[] : memref<vector<2xf32>>
205-
%b = memref.load %arg2[] : memref<vector<2xf32>>
156+
func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
157+
%x: vector<2xf32>,
158+
%b: vector<2xf32>) -> vector<2xf32> {
206159
%0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
207-
memref.store %0, %arg2[] : memref<vector<2xf32>>
208-
return
160+
return %0 : vector<2xf32>
209161
}
210162

211163
module attributes {transform.with_named_sequence} {

0 commit comments

Comments
 (0)