Skip to content

Commit 17afa5b

Browse files
authored
[mlir][nfc] Update tests for Contract -> Op transforms (#76054)
Updates two tests for vector.contract -> vector.outerproduct transformations: 1. Rename "vector-contract-to-outerproduct-transforms.mlir" as "vector-contract-to-outerproduct-matmul-transforms.mlir". The new name more accurate captures what's being tested. it is also consistent with "vector-contract-to-outerproduct-matvec-transforms.mlir", which covers vector matvec operations and makes finding relevant tests easier. 2. For matmul tests, move the traits definining the iteration spaces to the top of the file. This is consistent with how matvec tests are defined and also makes it easy to quickly identify what cases are covered. 3. For matmul tests, use more meaningful names for function arguments. This helps keep things consistent across the file (i.e. function definitions wih check lines and comments). 4. For matvec test, move a few tests around so that the most basic case (without masking) is first. 5. Update comments.
1 parent 513c215 commit 17afa5b

File tree

2 files changed

+149
-144
lines changed

2 files changed

+149
-144
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir renamed to mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir

Lines changed: 119 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

3-
// NOTE - tests in this file are duplicated so that there's a version for
4-
// * _fixed width_ and for _scalable_ vectors.
5-
// In order for the "vector.contract -> vector.outerproduct" patterns to work,
6-
// only the non-reduction dimension can be scalable (*). For Matmul operations
7-
// that is set to be the N dimension (i.e. rows of the output matrix), which
8-
// matches how matrix multiplication are normally implemented for e.g.
9-
// Arm SVE. However, making the M dimension scalable (i.e. columns of the
10-
// output matrix) should work as well.
11-
//
12-
// (*) The conversion tested in this file unrolls along the reduction
13-
// dimension, which is not supported for scalable vectors.
3+
/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
4+
/// matmul operations:
5+
/// C += A * B.
6+
/// (A, B and C are 2-d matrices). ATM three different variants / are tested:
7+
/// * plain (no mask, fixed-wdith vectors),
8+
/// * masked (fixed-width vectors,
9+
/// * scalable (mask + scalable vectors).
10+
/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
11+
/// only the non-reduction dimension can be scalable (*). For matmul operations
12+
/// that is set to be the N dimension (i.e. rows of the output matrix), which
13+
/// matches how matrix multiplication are normally implemented for e.g.
14+
/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
15+
/// output matrix) should work as well.
16+
///
17+
/// (*) The conversion tested in this file unrolls along the reduction
18+
/// dimension, which is not supported for scalable vectors.
1419

15-
// ============================================================================
16-
// Matmul 0 (plain + masked + mixed types)
17-
// ============================================================================
1820
#matmat_accesses_0 = [
1921
affine_map<(m, n, k) -> (m, k)>,
2022
affine_map<(m, n, k) -> (k, n)>,
@@ -25,6 +27,49 @@
2527
iterator_types = ["parallel", "parallel", "reduction"]
2628
}
2729

30+
#matmat_accesses_1 = [
31+
affine_map<(m, n, k) -> (m, k)>,
32+
affine_map<(m, n, k) -> (n, k)>,
33+
affine_map<(m, n, k) -> (m, n)>
34+
]
35+
#matmat_trait_1 = {
36+
indexing_maps = #matmat_accesses_1,
37+
iterator_types = ["parallel", "parallel", "reduction"]
38+
}
39+
40+
#matmat_accesses_2 = [
41+
affine_map<(m, n, k) -> (k, m)>,
42+
affine_map<(m, n, k) -> (k, n)>,
43+
affine_map<(m, n, k) -> (m, n)>
44+
]
45+
#matmat_trait_2 = {
46+
indexing_maps = #matmat_accesses_2,
47+
iterator_types = ["parallel", "parallel", "reduction"]
48+
}
49+
50+
#matmat_accesses_3 = [
51+
affine_map<(m, n, k) -> (k, m)>,
52+
affine_map<(m, n, k) -> (n, k)>,
53+
affine_map<(m, n, k) -> (m, n)>
54+
]
55+
#matmat_trait_3 = {
56+
indexing_maps = #matmat_accesses_3,
57+
iterator_types = ["parallel", "parallel", "reduction"]
58+
}
59+
60+
#matmat_accesses_4 = [
61+
affine_map<(m, n, k) -> (m, k)>,
62+
affine_map<(m, n, k) -> (k, n)>,
63+
affine_map<(m, n, k) -> (n, m)>
64+
]
65+
#matmat_trait_4 = {
66+
indexing_maps = #matmat_accesses_4,
67+
iterator_types = ["parallel", "parallel", "reduction"]
68+
}
69+
70+
// ============================================================================
71+
// Matmul 0 (plain + masked + mixed types)
72+
// ============================================================================
2873
// CHECK-LABEL: func @matmul
2974
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
3075
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -53,10 +98,10 @@
5398
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
5499
//
55100
// CHECK: return %[[c3]] : vector<2x3xf32>
56-
func.func @matmul(%arg0: vector<2x4xf32>,
57-
%arg1: vector<4x3xf32>,
58-
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
59-
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
101+
func.func @matmul(%A: vector<2x4xf32>,
102+
%B: vector<4x3xf32>,
103+
%C: vector<2x3xf32>) -> vector<2x3xf32> {
104+
%0 = vector.contract #matmat_trait_0 %A, %B, %C
60105
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
61106
return %0 : vector<2x3xf32>
62107
}
@@ -89,10 +134,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
89134
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
90135
//
91136
// CHECK: return %[[c3]] : vector<2x[3]xf32>
92-
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
93-
%arg1: vector<4x[3]xf32>,
94-
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
95-
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
137+
func.func @matmul_scalable(%A: vector<2x4xf32>,
138+
%B: vector<4x[3]xf32>,
139+
%C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
140+
%0 = vector.contract #matmat_trait_0 %A, %B, %C
96141
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
97142
return %0 : vector<2x[3]xf32>
98143
}
@@ -114,11 +159,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
114159
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
115160
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
116161

117-
func.func @masked_matmul(%arg0: vector<3x5xf32>,
118-
%arg1: vector<5x7xf32>,
119-
%arg2: vector<3x7xf32>,
162+
func.func @masked_matmul(%A: vector<3x5xf32>,
163+
%B: vector<5x7xf32>,
164+
%C: vector<3x7xf32>,
120165
%m : vector<3x7x5xi1>) -> vector<3x7xf32> {
121-
%0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
166+
%0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
122167
: vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
123168
return %0 : vector<3x7xf32>
124169
}
@@ -140,11 +185,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
140185
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
141186
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
142187

143-
func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
144-
%arg1: vector<5x[7]xf32>,
145-
%arg2: vector<3x[7]xf32>,
188+
func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
189+
%B: vector<5x[7]xf32>,
190+
%C: vector<3x[7]xf32>,
146191
%m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
147-
%0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
192+
%0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
148193
: vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
149194
return %0 : vector<3x[7]xf32>
150195
}
@@ -160,11 +205,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
160205
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
161206
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
162207
// CHECK: return %[[c0]] : vector<2x3xf32>
163-
func.func @matmul_mixed(%arg0: vector<2x1xf16>,
164-
%arg1: vector<1x3xf16>,
165-
%arg2: vector<2x3xf32>) -> vector<2x3xf32>
208+
func.func @matmul_mixed(%A: vector<2x1xf16>,
209+
%B: vector<1x3xf16>,
210+
%C: vector<2x3xf32>) -> vector<2x3xf32>
166211
{
167-
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
212+
%0 = vector.contract #matmat_trait_0 %A, %B, %C
168213
: vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
169214
return %0 : vector<2x3xf32>
170215
}
@@ -180,28 +225,18 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
180225
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
181226
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
182227
// CHECK: return %[[c0]] : vector<2x[3]xf32>
183-
func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
184-
%arg1: vector<1x[3]xf16>,
185-
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
228+
func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
229+
%B: vector<1x[3]xf16>,
230+
%C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
186231
{
187-
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
232+
%0 = vector.contract #matmat_trait_0 %A, %B, %C
188233
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
189234
return %0 : vector<2x[3]xf32>
190235
}
191236

192237
// ============================================================================
193-
// Matmul 1 (plain)
238+
// Matmul 1 (plain + scalable)
194239
// ============================================================================
195-
#matmat_accesses_1 = [
196-
affine_map<(m, n, k) -> (m, k)>,
197-
affine_map<(m, n, k) -> (n, k)>,
198-
affine_map<(m, n, k) -> (m, n)>
199-
]
200-
#matmat_trait_1 = {
201-
indexing_maps = #matmat_accesses_1,
202-
iterator_types = ["parallel", "parallel", "reduction"]
203-
}
204-
205240
// CHECK-LABEL: func @matmul_1
206241
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
207242
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -212,11 +247,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
212247
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
213248
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
214249
// CHECK: return %[[c0]] : vector<2x3xf32>
215-
func.func @matmul_1(%arg0: vector<2x1xf32>,
216-
%arg1: vector<3x1xf32>,
217-
%arg2: vector<2x3xf32>) -> vector<2x3xf32>
250+
func.func @matmul_1(%A: vector<2x1xf32>,
251+
%B: vector<3x1xf32>,
252+
%C: vector<2x3xf32>) -> vector<2x3xf32>
218253
{
219-
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
254+
%0 = vector.contract #matmat_trait_1 %A, %B, %C
220255
: vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
221256
return %0 : vector<2x3xf32>
222257
}
@@ -231,28 +266,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
231266
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
232267
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
233268
// CHECK: return %[[c0]] : vector<2x[3]xf32>
234-
func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
235-
%arg1: vector<[3]x1xf32>,
236-
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
269+
func.func @matmul_1_scalable(%A: vector<2x1xf32>,
270+
%B: vector<[3]x1xf32>,
271+
%C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
237272
{
238-
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
273+
%0 = vector.contract #matmat_trait_1 %A, %B, %C
239274
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
240275
return %0 : vector<2x[3]xf32>
241276
}
242277

243278
// ============================================================================
244-
// Matmul 2 (plain)
279+
// Matmul 2 (plain + scalable)
245280
// ============================================================================
246-
#matmat_accesses_2 = [
247-
affine_map<(m, n, k) -> (k, m)>,
248-
affine_map<(m, n, k) -> (k, n)>,
249-
affine_map<(m, n, k) -> (m, n)>
250-
]
251-
#matmat_trait_2 = {
252-
indexing_maps = #matmat_accesses_2,
253-
iterator_types = ["parallel", "parallel", "reduction"]
254-
}
255-
256281
// CHECK-LABEL: func @matmul_2
257282
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
258283
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -261,11 +286,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
261286
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
262287
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
263288
// CHECK: return %[[c0]] : vector<2x3xf32>
264-
func.func @matmul_2(%arg0: vector<1x2xf32>,
265-
%arg1: vector<1x3xf32>,
266-
%arg2: vector<2x3xf32>) -> vector<2x3xf32>
289+
func.func @matmul_2(%A: vector<1x2xf32>,
290+
%B: vector<1x3xf32>,
291+
%C: vector<2x3xf32>) -> vector<2x3xf32>
267292
{
268-
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
293+
%0 = vector.contract #matmat_trait_2 %A, %B, %C
269294
: vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
270295
return %0 : vector<2x3xf32>
271296
}
@@ -278,28 +303,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
278303
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
279304
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
280305
// CHECK: return %[[c0]] : vector<2x[3]xf32>
281-
func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
282-
%arg1: vector<1x[3]xf32>,
283-
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
306+
func.func @matmul_2_scalable(%A: vector<1x2xf32>,
307+
%B: vector<1x[3]xf32>,
308+
%C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
284309
{
285-
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
310+
%0 = vector.contract #matmat_trait_2 %A, %B, %C
286311
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
287312
return %0 : vector<2x[3]xf32>
288313
}
289314

290315
// ============================================================================
291-
// Matmul 3 (plain)
316+
// Matmul 3 (plain + scalable)
292317
// ============================================================================
293-
#matmat_accesses_3 = [
294-
affine_map<(m, n, k) -> (k, m)>,
295-
affine_map<(m, n, k) -> (n, k)>,
296-
affine_map<(m, n, k) -> (m, n)>
297-
]
298-
#matmat_trait_3 = {
299-
indexing_maps = #matmat_accesses_3,
300-
iterator_types = ["parallel", "parallel", "reduction"]
301-
}
302-
303318
// CHECK-LABEL: func @matmul_3
304319
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
305320
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -309,11 +324,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
309324
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
310325
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
311326
// CHECK: return %[[c0]] : vector<2x3xf32>
312-
func.func @matmul_3(%arg0: vector<1x2xf32>,
313-
%arg1: vector<3x1xf32>,
314-
%arg2: vector<2x3xf32>) -> vector<2x3xf32>
327+
func.func @matmul_3(%A: vector<1x2xf32>,
328+
%B: vector<3x1xf32>,
329+
%C: vector<2x3xf32>) -> vector<2x3xf32>
315330
{
316-
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
331+
%0 = vector.contract #matmat_trait_3 %A, %B, %C
317332
: vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
318333
return %0 : vector<2x3xf32>
319334
}
@@ -327,28 +342,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
327342
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
328343
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
329344
// CHECK: return %[[c0]] : vector<2x[3]xf32>
330-
func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
331-
%arg1: vector<[3]x1xf32>,
332-
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
345+
func.func @matmul_3_scalable(%A: vector<1x2xf32>,
346+
%B: vector<[3]x1xf32>,
347+
%C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
333348
{
334-
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
349+
%0 = vector.contract #matmat_trait_3 %A, %B, %C
335350
: vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
336351
return %0 : vector<2x[3]xf32>
337352
}
338353

339354
// ============================================================================
340-
// Matmul 4 (plain)
355+
// Matmul 4 (plain + scalable)
341356
// ============================================================================
342-
#matmat_accesses_4 = [
343-
affine_map<(m, n, k) -> (m, k)>,
344-
affine_map<(m, n, k) -> (k, n)>,
345-
affine_map<(m, n, k) -> (n, m)>
346-
]
347-
#matmat_trait_4 = {
348-
indexing_maps = #matmat_accesses_4,
349-
iterator_types = ["parallel", "parallel", "reduction"]
350-
}
351-
352357
// CHECK-LABEL: func @matmul_4
353358
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
354359
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -358,11 +363,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
358363
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
359364
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
360365
// CHECK: return %[[c0]] : vector<3x2xf32>
361-
func.func @matmul_4(%arg0: vector<2x1xf32>,
362-
%arg1: vector<1x3xf32>,
363-
%arg2: vector<3x2xf32>) -> vector<3x2xf32>
366+
func.func @matmul_4(%A: vector<2x1xf32>,
367+
%B: vector<1x3xf32>,
368+
%C: vector<3x2xf32>) -> vector<3x2xf32>
364369
{
365-
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
370+
%0 = vector.contract #matmat_trait_4 %A, %B, %C
366371
: vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
367372
return %0 : vector<3x2xf32>
368373
}
@@ -376,11 +381,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
376381
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
377382
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
378383
// CHECK: return %[[c0]] : vector<3x[2]xf32>
379-
func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
380-
%arg1: vector<1x3xf32>,
381-
%arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
384+
func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
385+
%B: vector<1x3xf32>,
386+
%C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
382387
{
383-
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
388+
%0 = vector.contract #matmat_trait_4 %A, %B, %C
384389
: vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
385390
return %0 : vector<3x[2]xf32>
386391
}

0 commit comments

Comments
 (0)