module {
  func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
    %c0 = arith.constant 0 : index
-   %c32 = arith.constant 32 : index
+   %c16 = arith.constant 16 : index
+   %c64 = arith.constant 64 : index
    %c1024 = arith.constant 1024 : index
-   scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
-     %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
-     %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-     %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-     linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+   scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
+     %subview_0 = memref.subview %arg2[%arg3, %arg4] [16, 64] [1, 1] : memref<1024x1024xf16> to memref<16x64xf16, strided<[1024, 1], offset: ?>>
+     %subview_1 = memref.subview %arg0[%arg3, 0] [16, 1024] [1, 1] : memref<1024x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
+     %subview_2 = memref.subview %arg1[%arg4, 0] [64, 1024] [1, 1] : memref<1024x1024xf16> to memref<64x1024xf16, strided<[1024, 1], offset: ?>>
+     linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<64x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<16x64xf16, strided<[1024, 1], offset: ?>>)
      scf.reduce
    }
    return
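
For reference, `linalg.matmul_transpose_b` computes C[i, j] = sum_k A[i, k] * B[j, k], i.e. both operands are traversed row-wise along the reduction dimension. A 16x64 tile of C therefore needs the 16 corresponding rows of A and the 64 corresponding rows of B, which is exactly what the three subviews above select.
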
@@ -19,7 +20,7 @@ module {
// CHECK-LABEL: func.func @matmul_transpose_b
// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

- // CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+ // CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
@@ -43,9 +44,11 @@ module {
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+ // CHECK: %[[tB2:.+]] = xegpu.update_nd_offset %[[rootB]], [%c32, %c0]
+ // CHECK: %[[tB3:.+]] = xegpu.update_nd_offset %[[rootB]], [%c48, %c0]
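
Because the B operand is transposed, its tile now spans 64 rows while each descriptor covers 16 rows at a time (the offsets step by %c16), so the lowering materializes four row-shifted B descriptors at offsets 0, 16, 32, and 48, where the old 32x32 tiling needed only two.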

// Create DPAS computation loop over tiled reduction dimension.
- // CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+ // CHECK: %[[res:.+]]:13 = scf.for{{.*}}%c0 to %c1024 step %c16
// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
// CHECK-SAME: {
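
The bump from 11 to 13 loop-carried values appears to follow directly from the tile shapes: with 8x16 DPAS output fragments (see the vector<8x16xf16> casts below), a 16x64 C tile splits into (16/8) * (64/16) = 8 accumulators, plus one A descriptor and four B descriptors, giving 13; the old 32x32 tile gave (32/8) * (32/16) = 8 accumulators, one A descriptor and two B descriptors, giving 11.
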
@@ -66,10 +69,10 @@ module {

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
- // CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
- // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+ // CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<16x16xf16> to vector<256xf16>
+ // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<256xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
- // CHECK-COUNT-3: vector.extract_strided_slice
+ // CHECK-COUNT-1: vector.extract_strided_slice
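
As a standalone sketch of the pattern these checks match (value names here are illustrative, not produced by the pass):

    // Flatten the loaded 16x16 A tile, carve out one 8x16 DPAS operand,
    // then restore its 2-D shape; the second chunk starts at offset 128.
    %a_flat = vector.shape_cast %a : vector<16x16xf16> to vector<256xf16>
    %a_chunk = vector.extract_strided_slice %a_flat
        {offsets = [0], sizes = [128], strides = [1]}
        : vector<256xf16> to vector<128xf16>
    %a_dpas = vector.shape_cast %a_chunk : vector<128xf16> to vector<8x16xf16>

A 256-element flat vector yields exactly two such 128-element chunks, hence only one further extract_strided_slice remains (CHECK-COUNT-1) versus three for the old 512-element tile.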

// Perform DPAS computation.
// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
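
For context, a minimal sketch of the op being matched, assuming the plain (non-VNNI-packed) fp16 operand form and illustrative value names:

    // One DPAS: an 8x16 A fragment times a 16x16 B fragment,
    // accumulated into an 8x16 f32 result.
    %d = xegpu.dpas %a_dpas, %b_tile, %acc
        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>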