@@ -357,46 +357,35 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
357
357
return %r0 , %1 : vector <4 xf32 >, vector <2 x2 xf32 >
358
358
}
359
359
360
- // MATRIX-LABEL: func @column_major_matmul
361
- // MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>,
362
- // MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>,
363
- // MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
364
- // MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32>
365
- // MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32>
366
- // MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
367
- // MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32>
368
- // MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
369
- // MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32>
370
- // MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
371
- // MATRIX: %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32>
372
- // MATRIX: %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
373
- // MATRIX: %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32>
374
- // MATRIX: %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
375
- // MATRIX: %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32>
376
- // MATRIX: %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
377
- // MATRIX: %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32>
378
- // MATRIX: %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
379
- // MATRIX: %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32>
380
- // MATRIX: %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
381
- // MATRIX: %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32>
382
- // MATRIX: %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
383
- // MATRIX: %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32>
384
- // MATRIX: %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
385
- // MATRIX: %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32>
386
- // MATRIX: %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32>
387
- #column_major_matmat_accesses = [
388
- affine_map <(i , j , k ) -> (k , j )>,
389
- affine_map <(i , j , k ) -> (i , k )>,
390
- affine_map <(i , j , k ) -> (j , i )>
391
- ]
392
- #column_major_matmat_trait = {
393
- indexing_maps = #column_major_matmat_accesses ,
394
- iterator_types = [" parallel" , " parallel" , " reduction" ]
395
- }
396
- func @column_major_matmul (%arg0: vector <4 x3 xf32 >,
397
- %arg1: vector <2 x4 xf32 >,
398
- %arg2: vector <3 x2 xf32 >) -> vector <3 x2 xf32 > {
399
- %0 = vector.contract #column_major_matmat_trait %arg0 , %arg1 , %arg2
400
- : vector <4 x3 xf32 >, vector <2 x4 xf32 > into vector <3 x2 xf32 >
401
- return %0 : vector <3 x2 xf32 >
360
+ // MATRIX-LABEL: func @matmul
361
+ // MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
362
+ // MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
363
+ // MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
364
+ // MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<8xf32>
365
+ // MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<12xf32>
366
+ // MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
367
+ // MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
368
+ // MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
369
+ // MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
370
+ // MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
371
+ // MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
372
+ // MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
373
+ // MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
374
+ // MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
375
+ // MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
376
+ // MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
377
+ // MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
378
+ // MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
379
+ // MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
380
+ // MATRIX: %[[mm2:.*]] = vector.strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
381
+ // MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
382
+ // MATRIX: %[[mm4:.*]] = vector.strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
383
+ // MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
384
+ // MATRIX: %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32>
385
+ func @matmul (%arg0: vector <2 x4 xf32 >,
386
+ %arg1: vector <4 x3 xf32 >,
387
+ %arg2: vector <2 x3 xf32 >) -> vector <2 x3 xf32 > {
388
+ %0 = vector.contract #matmat_trait %arg0 , %arg1 , %arg2
389
+ : vector <2 x4 xf32 >, vector <4 x3 xf32 > into vector <2 x3 xf32 >
390
+ return %0 : vector <2 x3 xf32 >
402
391
}
0 commit comments