
Commit 8b64109

Fix linalg.matmul_transpose_b for big tiles (#410)
Signed-off-by: dchigarev <[email protected]>
Parent: 672edc9

File tree

2 files changed: +17, -15 lines

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 3 additions & 4 deletions
@@ -706,13 +706,12 @@ static SmallVector<Value> createNdDescriptorTiles(
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
-      if (transpose) {
-        std::swap(newRowOffs, newColOffs);
-      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
-                          /*offsets=*/ValueRange{newRowOffs, newColOffs},
+                          /*offsets=*/
+                          transpose ? ValueRange{newColOffs, newRowOffs}
+                                    : ValueRange{newRowOffs, newColOffs},
                           SmallVector<int64_t>{ShapedType::kDynamic,
                                                ShapedType::kDynamic})
                       .getResult();
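
For context on why the old code misbehaved for big tiles: std::swap mutates newRowOffs/newColOffs in place inside the inner loop, so once the loaded tile spans more than one descriptor tile (the inner loop runs more than once) the second iteration swaps against an already-swapped newRowOffs and emits stale offsets. The new code leaves both values untouched and only picks the operand order at the use site. A minimal standalone sketch of the hazard, using plain ints in place of mlir::Value and illustrative 16x64 / 16x16 shapes borrowed from the updated test (the descriptor width and an array_length of 1 are assumptions, not values taken from the pass):

#include <cstdio>
#include <utility>

// Illustrative only: enumerate the (row, col) offsets for a transposed load
// the way the old and the new code path would, with ints standing in for
// mlir::Value handles.
int main() {
  const int loadRows = 16, loadCols = 64; // B tile shape from the updated test
  const int descRows = 16, descCols = 16; // per-descriptor tile shape (assumed)
  const bool transpose = true;

  std::printf("old (std::swap inside the loop):\n");
  for (int i = 0; i < loadRows; i += descRows) {
    int rowOffs = i;
    for (int j = 0; j < loadCols; j += descCols) {
      int colOffs = j;
      if (transpose)
        std::swap(rowOffs, colOffs); // rowOffs keeps the swapped value on
                                     // every later inner-loop iteration
      std::printf("  [%d, %d]\n", rowOffs, colOffs);
    }
  }

  std::printf("new (pick the operand order at the use):\n");
  for (int i = 0; i < loadRows; i += descRows) {
    const int rowOffs = i;
    for (int j = 0; j < loadCols; j += descCols) {
      const int colOffs = j;
      std::printf("  [%d, %d]\n", transpose ? colOffs : rowOffs,
                  transpose ? rowOffs : colOffs);
    }
  }
  return 0;
}

With these numbers the old path yields [0, 0], [16, 0], [32, 16], [48, 32] (the last two are wrong), while the new path yields [0, 0], [16, 0], [32, 0], [48, 0], matching the %c0/%c16/%c32/%c48 offsets the updated test expects.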

test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir

Lines changed: 14 additions & 11 deletions
@@ -3,13 +3,14 @@
 module {
   func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
-    %c32 = arith.constant 32 : index
+    %c16 = arith.constant 16 : index
+    %c64 = arith.constant 64 : index
     %c1024 = arith.constant 1024 : index
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
-      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [16, 64] [1, 1] : memref<1024x1024xf16> to memref<16x64xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [16, 1024] [1, 1] : memref<1024x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [64, 1024] [1, 1] : memref<1024x1024xf16> to memref<64x1024xf16, strided<[1024, 1], offset: ?>>
+      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<64x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<16x64xf16, strided<[1024, 1], offset: ?>>)
       scf.reduce
     }
     return
@@ -19,7 +20,7 @@ module {
 // CHECK-LABEL: func.func @matmul_transpose_b
 // CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

-// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
 // CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
 // CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
 // CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
@@ -43,9 +44,11 @@ module {
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
 // CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+// CHECK: %[[tB2:.+]] = xegpu.update_nd_offset %[[rootB]], [%c32, %c0]
+// CHECK: %[[tB3:.+]] = xegpu.update_nd_offset %[[rootB]], [%c48, %c0]

 // Create DPAS computation loop over tiled reduction dimension.
-// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+// CHECK: %[[res:.+]]:13 = scf.for{{.*}}%c0 to %c1024 step %c16
 // CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
 // CHECK-SAME: {

@@ -66,10 +69,10 @@ module {

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
-// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<16x16xf16> to vector<256xf16>
+// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<256xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
-// CHECK-COUNT-3: vector.extract_strided_slice
+// CHECK-COUNT-1: vector.extract_strided_slice

 // Perform DPAS computation.
 // CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
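
As a sanity check, the new FileCheck numbers follow from the 16x64 output tile: the transposed B operand now needs 64 / 16 = 4 descriptor tiles (tB through tB3 at row offsets 0, 16, 32, 48), which adds two loop-carried descriptors to the scf.for (11 + 2 = 13 results), and the A load shrinks to 16x16 = 256 elements, of which each 8x16 DPAS operand consumes 128, i.e. two extract_strided_slice ops (the named one plus CHECK-COUNT-1). A throwaway sketch of that arithmetic, with the shapes hard-coded from this test rather than queried from the pass:

#include <cstdio>

int main() {
  // New test case: 16x64 output tile (M x N), reduction tiled by 16.
  const int tileM = 16, tileN = 64, tileK = 16;
  const int descN = 16;            // per-descriptor width for B (assumed)
  const int dpasM = 8, dpasK = 16; // DPAS A-operand shape

  const int numBTiles = tileN / descN; // 4 -> tB, tB1, tB2, tB3
  for (int t = 0; t < numBTiles; ++t)
    std::printf("tB%d offset: [%d, 0]\n", t, t * descN); // 0, 16, 32, 48

  // The old test carried 11 values through scf.for with 2 B descriptors;
  // two extra descriptors give 13 results.
  const int oldResults = 11, oldBTiles = 2;
  std::printf("scf.for results: %d\n", oldResults + (numBTiles - oldBTiles));

  // A load: 16x16 = 256 elements; each DPAS chunk: 8x16 = 128 elements.
  std::printf("A chunks per load: %d\n", (tileM * tileK) / (dpasM * dpasK));
  return 0;
}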
