1
1
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2
2
3
- // NOTE - tests in this file are duplicated so that there's a version for
4
- // * _fixed width_ and for _scalable_ vectors.
5
- // In order for the "vector.contract -> vector.outerproduct" patterns to work,
6
- // only the non-reduction dimension can be scalable (*). For Matmul operations
7
- // that is set to be the N dimension (i.e. rows of the output matrix), which
8
- // matches how matrix multiplication are normally implemented for e.g.
9
- // Arm SVE. However, making the M dimension scalable (i.e. columns of the
10
- // output matrix) should work as well.
11
- //
12
- // (*) The conversion tested in this file unrolls along the reduction
13
- // dimension, which is not supported for scalable vectors.
3
+ /// Tests for `vector.contract` -> `vector.outerproduct` transformations for
4
+ /// matmul operations:
5
+ /// C += A * B.
6
+ /// (A, B and C are 2-d matrices). ATM three different variants / are tested:
7
+ /// * plain (no mask, fixed-wdith vectors),
8
+ /// * masked (fixed-width vectors,
9
+ /// * scalable (mask + scalable vectors).
10
+ /// In order for the "vector.contract -> vector.outerproduct" patterns to work,
11
+ /// only the non-reduction dimension can be scalable (*). For matmul operations
12
+ /// that is set to be the N dimension (i.e. rows of the output matrix), which
13
+ /// matches how matrix multiplication are normally implemented for e.g.
14
+ /// Arm SVE. However, making the M dimension scalable (i.e. columns of the
15
+ /// output matrix) should work as well.
16
+ ///
17
+ /// (*) The conversion tested in this file unrolls along the reduction
18
+ /// dimension, which is not supported for scalable vectors.
14
19
15
- // ============================================================================
16
- // Matmul 0 (plain + masked + mixed types)
17
- // ============================================================================
18
20
#matmat_accesses_0 = [
19
21
affine_map <(m , n , k ) -> (m , k )>,
20
22
affine_map <(m , n , k ) -> (k , n )>,
25
27
iterator_types = [" parallel" , " parallel" , " reduction" ]
26
28
}
27
29
30
+ #matmat_accesses_1 = [
31
+ affine_map <(m , n , k ) -> (m , k )>,
32
+ affine_map <(m , n , k ) -> (n , k )>,
33
+ affine_map <(m , n , k ) -> (m , n )>
34
+ ]
35
+ #matmat_trait_1 = {
36
+ indexing_maps = #matmat_accesses_1 ,
37
+ iterator_types = [" parallel" , " parallel" , " reduction" ]
38
+ }
39
+
40
+ #matmat_accesses_2 = [
41
+ affine_map <(m , n , k ) -> (k , m )>,
42
+ affine_map <(m , n , k ) -> (k , n )>,
43
+ affine_map <(m , n , k ) -> (m , n )>
44
+ ]
45
+ #matmat_trait_2 = {
46
+ indexing_maps = #matmat_accesses_2 ,
47
+ iterator_types = [" parallel" , " parallel" , " reduction" ]
48
+ }
49
+
50
+ #matmat_accesses_3 = [
51
+ affine_map <(m , n , k ) -> (k , m )>,
52
+ affine_map <(m , n , k ) -> (n , k )>,
53
+ affine_map <(m , n , k ) -> (m , n )>
54
+ ]
55
+ #matmat_trait_3 = {
56
+ indexing_maps = #matmat_accesses_3 ,
57
+ iterator_types = [" parallel" , " parallel" , " reduction" ]
58
+ }
59
+
60
+ #matmat_accesses_4 = [
61
+ affine_map <(m , n , k ) -> (m , k )>,
62
+ affine_map <(m , n , k ) -> (k , n )>,
63
+ affine_map <(m , n , k ) -> (n , m )>
64
+ ]
65
+ #matmat_trait_4 = {
66
+ indexing_maps = #matmat_accesses_4 ,
67
+ iterator_types = [" parallel" , " parallel" , " reduction" ]
68
+ }
69
+
70
+ // ============================================================================
71
+ // Matmul 0 (plain + masked + mixed types)
72
+ // ============================================================================
28
73
// CHECK-LABEL: func @matmul
29
74
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
30
75
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
53
98
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
54
99
//
55
100
// CHECK: return %[[c3]] : vector<2x3xf32>
56
- func.func @matmul (%arg0 : vector <2 x4 xf32 >,
57
- %arg1 : vector <4 x3 xf32 >,
58
- %arg2 : vector <2 x3 xf32 >) -> vector <2 x3 xf32 > {
59
- %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
101
+ func.func @matmul (%A : vector <2 x4 xf32 >,
102
+ %B : vector <4 x3 xf32 >,
103
+ %C : vector <2 x3 xf32 >) -> vector <2 x3 xf32 > {
104
+ %0 = vector.contract #matmat_trait_0 %A , %B , %C
60
105
: vector <2 x4 xf32 >, vector <4 x3 xf32 > into vector <2 x3 xf32 >
61
106
return %0 : vector <2 x3 xf32 >
62
107
}
@@ -89,10 +134,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
89
134
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
90
135
//
91
136
// CHECK: return %[[c3]] : vector<2x[3]xf32>
92
- func.func @matmul_scalable (%arg0 : vector <2 x4 xf32 >,
93
- %arg1 : vector <4 x[3 ]xf32 >,
94
- %arg2 : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 > {
95
- %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
137
+ func.func @matmul_scalable (%A : vector <2 x4 xf32 >,
138
+ %B : vector <4 x[3 ]xf32 >,
139
+ %C : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 > {
140
+ %0 = vector.contract #matmat_trait_0 %A , %B , %C
96
141
: vector <2 x4 xf32 >, vector <4 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
97
142
return %0 : vector <2 x[3 ]xf32 >
98
143
}
@@ -114,11 +159,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
114
159
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
115
160
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
116
161
117
- func.func @masked_matmul (%arg0 : vector <3 x5 xf32 >,
118
- %arg1 : vector <5 x7 xf32 >,
119
- %arg2 : vector <3 x7 xf32 >,
162
+ func.func @masked_matmul (%A : vector <3 x5 xf32 >,
163
+ %B : vector <5 x7 xf32 >,
164
+ %C : vector <3 x7 xf32 >,
120
165
%m : vector <3 x7 x5 xi1 >) -> vector <3 x7 xf32 > {
121
- %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
166
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %A , %B , %C
122
167
: vector <3 x5 xf32 >, vector <5 x7 xf32 > into vector <3 x7 xf32 > } : vector <3 x7 x5 xi1 > -> vector <3 x7 xf32 >
123
168
return %0 : vector <3 x7 xf32 >
124
169
}
@@ -140,11 +185,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
140
185
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
141
186
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
142
187
143
- func.func @masked_matmul_scalable (%arg0 : vector <3 x5 xf32 >,
144
- %arg1 : vector <5 x[7 ]xf32 >,
145
- %arg2 : vector <3 x[7 ]xf32 >,
188
+ func.func @masked_matmul_scalable (%A : vector <3 x5 xf32 >,
189
+ %B : vector <5 x[7 ]xf32 >,
190
+ %C : vector <3 x[7 ]xf32 >,
146
191
%m : vector <3 x[7 ]x5 xi1 >) -> vector <3 x[7 ]xf32 > {
147
- %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
192
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %A , %B , %C
148
193
: vector <3 x5 xf32 >, vector <5 x[7 ]xf32 > into vector <3 x[7 ]xf32 > } : vector <3 x[7 ]x5 xi1 > -> vector <3 x[7 ]xf32 >
149
194
return %0 : vector <3 x[7 ]xf32 >
150
195
}
@@ -160,11 +205,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
160
205
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
161
206
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
162
207
// CHECK: return %[[c0]] : vector<2x3xf32>
163
- func.func @matmul_mixed (%arg0 : vector <2 x1 xf16 >,
164
- %arg1 : vector <1 x3 xf16 >,
165
- %arg2 : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
208
+ func.func @matmul_mixed (%A : vector <2 x1 xf16 >,
209
+ %B : vector <1 x3 xf16 >,
210
+ %C : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
166
211
{
167
- %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
212
+ %0 = vector.contract #matmat_trait_0 %A , %B , %C
168
213
: vector <2 x1 xf16 >, vector <1 x3 xf16 > into vector <2 x3 xf32 >
169
214
return %0 : vector <2 x3 xf32 >
170
215
}
@@ -180,28 +225,18 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
180
225
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
181
226
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
182
227
// CHECK: return %[[c0]] : vector<2x[3]xf32>
183
- func.func @matmul_mixed_scalable (%arg0 : vector <2 x1 xf16 >,
184
- %arg1 : vector <1 x[3 ]xf16 >,
185
- %arg2 : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
228
+ func.func @matmul_mixed_scalable (%A : vector <2 x1 xf16 >,
229
+ %B : vector <1 x[3 ]xf16 >,
230
+ %C : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
186
231
{
187
- %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
232
+ %0 = vector.contract #matmat_trait_0 %A , %B , %C
188
233
: vector <2 x1 xf16 >, vector <1 x[3 ]xf16 > into vector <2 x[3 ]xf32 >
189
234
return %0 : vector <2 x[3 ]xf32 >
190
235
}
191
236
192
237
// ============================================================================
193
- // Matmul 1 (plain)
238
+ // Matmul 1 (plain + scalable )
194
239
// ============================================================================
195
- #matmat_accesses_1 = [
196
- affine_map <(m , n , k ) -> (m , k )>,
197
- affine_map <(m , n , k ) -> (n , k )>,
198
- affine_map <(m , n , k ) -> (m , n )>
199
- ]
200
- #matmat_trait_1 = {
201
- indexing_maps = #matmat_accesses_1 ,
202
- iterator_types = [" parallel" , " parallel" , " reduction" ]
203
- }
204
-
205
240
// CHECK-LABEL: func @matmul_1
206
241
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
207
242
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -212,11 +247,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
212
247
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
213
248
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
214
249
// CHECK: return %[[c0]] : vector<2x3xf32>
215
- func.func @matmul_1 (%arg0 : vector <2 x1 xf32 >,
216
- %arg1 : vector <3 x1 xf32 >,
217
- %arg2 : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
250
+ func.func @matmul_1 (%A : vector <2 x1 xf32 >,
251
+ %B : vector <3 x1 xf32 >,
252
+ %C : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
218
253
{
219
- %0 = vector.contract #matmat_trait_1 %arg0 , %arg1 , %arg2
254
+ %0 = vector.contract #matmat_trait_1 %A , %B , %C
220
255
: vector <2 x1 xf32 >, vector <3 x1 xf32 > into vector <2 x3 xf32 >
221
256
return %0 : vector <2 x3 xf32 >
222
257
}
@@ -231,28 +266,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
231
266
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
232
267
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
233
268
// CHECK: return %[[c0]] : vector<2x[3]xf32>
234
- func.func @matmul_1_scalable (%arg0 : vector <2 x1 xf32 >,
235
- %arg1 : vector <[3 ]x1 xf32 >,
236
- %arg2 : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
269
+ func.func @matmul_1_scalable (%A : vector <2 x1 xf32 >,
270
+ %B : vector <[3 ]x1 xf32 >,
271
+ %C : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
237
272
{
238
- %0 = vector.contract #matmat_trait_1 %arg0 , %arg1 , %arg2
273
+ %0 = vector.contract #matmat_trait_1 %A , %B , %C
239
274
: vector <2 x1 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
240
275
return %0 : vector <2 x[3 ]xf32 >
241
276
}
242
277
243
278
// ============================================================================
244
- // Matmul 2 (plain)
279
+ // Matmul 2 (plain + scalable )
245
280
// ============================================================================
246
- #matmat_accesses_2 = [
247
- affine_map <(m , n , k ) -> (k , m )>,
248
- affine_map <(m , n , k ) -> (k , n )>,
249
- affine_map <(m , n , k ) -> (m , n )>
250
- ]
251
- #matmat_trait_2 = {
252
- indexing_maps = #matmat_accesses_2 ,
253
- iterator_types = [" parallel" , " parallel" , " reduction" ]
254
- }
255
-
256
281
// CHECK-LABEL: func @matmul_2
257
282
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
258
283
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -261,11 +286,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
261
286
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
262
287
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
263
288
// CHECK: return %[[c0]] : vector<2x3xf32>
264
- func.func @matmul_2 (%arg0 : vector <1 x2 xf32 >,
265
- %arg1 : vector <1 x3 xf32 >,
266
- %arg2 : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
289
+ func.func @matmul_2 (%A : vector <1 x2 xf32 >,
290
+ %B : vector <1 x3 xf32 >,
291
+ %C : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
267
292
{
268
- %0 = vector.contract #matmat_trait_2 %arg0 , %arg1 , %arg2
293
+ %0 = vector.contract #matmat_trait_2 %A , %B , %C
269
294
: vector <1 x2 xf32 >, vector <1 x3 xf32 > into vector <2 x3 xf32 >
270
295
return %0 : vector <2 x3 xf32 >
271
296
}
@@ -278,28 +303,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
278
303
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
279
304
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
280
305
// CHECK: return %[[c0]] : vector<2x[3]xf32>
281
- func.func @matmul_2_scalable (%arg0 : vector <1 x2 xf32 >,
282
- %arg1 : vector <1 x[3 ]xf32 >,
283
- %arg2 : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
306
+ func.func @matmul_2_scalable (%A : vector <1 x2 xf32 >,
307
+ %B : vector <1 x[3 ]xf32 >,
308
+ %C : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
284
309
{
285
- %0 = vector.contract #matmat_trait_2 %arg0 , %arg1 , %arg2
310
+ %0 = vector.contract #matmat_trait_2 %A , %B , %C
286
311
: vector <1 x2 xf32 >, vector <1 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
287
312
return %0 : vector <2 x[3 ]xf32 >
288
313
}
289
314
290
315
// ============================================================================
291
- // Matmul 3 (plain)
316
+ // Matmul 3 (plain + scalable )
292
317
// ============================================================================
293
- #matmat_accesses_3 = [
294
- affine_map <(m , n , k ) -> (k , m )>,
295
- affine_map <(m , n , k ) -> (n , k )>,
296
- affine_map <(m , n , k ) -> (m , n )>
297
- ]
298
- #matmat_trait_3 = {
299
- indexing_maps = #matmat_accesses_3 ,
300
- iterator_types = [" parallel" , " parallel" , " reduction" ]
301
- }
302
-
303
318
// CHECK-LABEL: func @matmul_3
304
319
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
305
320
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -309,11 +324,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
309
324
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
310
325
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
311
326
// CHECK: return %[[c0]] : vector<2x3xf32>
312
- func.func @matmul_3 (%arg0 : vector <1 x2 xf32 >,
313
- %arg1 : vector <3 x1 xf32 >,
314
- %arg2 : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
327
+ func.func @matmul_3 (%A : vector <1 x2 xf32 >,
328
+ %B : vector <3 x1 xf32 >,
329
+ %C : vector <2 x3 xf32 >) -> vector <2 x3 xf32 >
315
330
{
316
- %0 = vector.contract #matmat_trait_3 %arg0 , %arg1 , %arg2
331
+ %0 = vector.contract #matmat_trait_3 %A , %B , %C
317
332
: vector <1 x2 xf32 >, vector <3 x1 xf32 > into vector <2 x3 xf32 >
318
333
return %0 : vector <2 x3 xf32 >
319
334
}
@@ -327,28 +342,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
327
342
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
328
343
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
329
344
// CHECK: return %[[c0]] : vector<2x[3]xf32>
330
- func.func @matmul_3_scalable (%arg0 : vector <1 x2 xf32 >,
331
- %arg1 : vector <[3 ]x1 xf32 >,
332
- %arg2 : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
345
+ func.func @matmul_3_scalable (%A : vector <1 x2 xf32 >,
346
+ %B : vector <[3 ]x1 xf32 >,
347
+ %C : vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 >
333
348
{
334
- %0 = vector.contract #matmat_trait_3 %arg0 , %arg1 , %arg2
349
+ %0 = vector.contract #matmat_trait_3 %A , %B , %C
335
350
: vector <1 x2 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
336
351
return %0 : vector <2 x[3 ]xf32 >
337
352
}
338
353
339
354
// ============================================================================
340
- // Matmul 4 (plain)
355
+ // Matmul 4 (plain + scalable )
341
356
// ============================================================================
342
- #matmat_accesses_4 = [
343
- affine_map <(m , n , k ) -> (m , k )>,
344
- affine_map <(m , n , k ) -> (k , n )>,
345
- affine_map <(m , n , k ) -> (n , m )>
346
- ]
347
- #matmat_trait_4 = {
348
- indexing_maps = #matmat_accesses_4 ,
349
- iterator_types = [" parallel" , " parallel" , " reduction" ]
350
- }
351
-
352
357
// CHECK-LABEL: func @matmul_4
353
358
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
354
359
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -358,11 +363,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
358
363
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
359
364
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
360
365
// CHECK: return %[[c0]] : vector<3x2xf32>
361
- func.func @matmul_4 (%arg0 : vector <2 x1 xf32 >,
362
- %arg1 : vector <1 x3 xf32 >,
363
- %arg2 : vector <3 x2 xf32 >) -> vector <3 x2 xf32 >
366
+ func.func @matmul_4 (%A : vector <2 x1 xf32 >,
367
+ %B : vector <1 x3 xf32 >,
368
+ %C : vector <3 x2 xf32 >) -> vector <3 x2 xf32 >
364
369
{
365
- %0 = vector.contract #matmat_trait_4 %arg0 , %arg1 , %arg2
370
+ %0 = vector.contract #matmat_trait_4 %A , %B , %C
366
371
: vector <2 x1 xf32 >, vector <1 x3 xf32 > into vector <3 x2 xf32 >
367
372
return %0 : vector <3 x2 xf32 >
368
373
}
@@ -376,11 +381,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
376
381
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
377
382
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
378
383
// CHECK: return %[[c0]] : vector<3x[2]xf32>
379
- func.func @matmul_4_scalable (%arg0 : vector <[2 ]x1 xf32 >,
380
- %arg1 : vector <1 x3 xf32 >,
381
- %arg2 : vector <3 x[2 ]xf32 >) -> vector <3 x[2 ]xf32 >
384
+ func.func @matmul_4_scalable (%A : vector <[2 ]x1 xf32 >,
385
+ %B : vector <1 x3 xf32 >,
386
+ %C : vector <3 x[2 ]xf32 >) -> vector <3 x[2 ]xf32 >
382
387
{
383
- %0 = vector.contract #matmat_trait_4 %arg0 , %arg1 , %arg2
388
+ %0 = vector.contract #matmat_trait_4 %A , %B , %C
384
389
: vector <[2 ]x1 xf32 >, vector <1 x3 xf32 > into vector <3 x[2 ]xf32 >
385
390
return %0 : vector <3 x[2 ]xf32 >
386
391
}
0 commit comments