
Commit f63e8f6 (1 parent: 6a4cd4a)

Add transform test

Signed-off-by: dchigarev <[email protected]>

3 files changed: 65 additions & 13 deletions


lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 1 addition & 1 deletion
@@ -888,7 +888,7 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
   assert(loadShape.size() <= 2 &&
          "Require at most 2D tile size for eltwise lowering");
 
-  auto srcType = src.getType().cast<MemRefType>();
+  auto srcType = cast<MemRefType>(src.getType());
   assert(srcType.getRank() == 2 && "Expected a 2D memref");
   auto elemByteWidth = srcType.getElementType().getIntOrFloatBitWidth() / 8;

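Note: the C++ change swaps the member-style `src.getType().cast<MemRefType>()` for the free function `cast<MemRefType>(src.getType())`; the member form is deprecated in recent MLIR in favor of the free-function casting utilities. A minimal sketch of the resulting pattern in isolation (the standalone helper name is illustrative, not part of the pass):

#include <cassert>

#include "mlir/IR/BuiltinTypes.h" // MemRefType
#include "mlir/IR/Value.h"        // Value

using namespace mlir;

// Illustrative helper: compute the element byte width of a 2D memref
// value using the free-function cast<>. Like the old member cast, it
// asserts if the type is not actually a MemRefType.
static unsigned elemByteWidthOf(Value src) {
  auto srcType = cast<MemRefType>(src.getType());
  assert(srcType.getRank() == 2 && "Expected a 2D memref");
  return srcType.getElementType().getIntOrFloatBitWidth() / 8;
}
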
test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-slm.mlir

Lines changed: 63 additions & 12 deletions
@@ -1,39 +1,90 @@
-// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
-
-// TODO: write CHECK directives
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s
 
 #map = affine_map<(d0) -> (d0 * 64)>
 #map1 = affine_map<(d0) -> (d0 * 16)>
 
 func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>) {
+  // CHECK: %[[loadAccumMatmul:.+]] = arith.constant dense<0.000000e+00> : vector<4x32xf16>
+  // CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<32xf16>
+  // CHECK: %[[colTileShift:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271]> : vector<32xindex>
+  // CHECK: %[[loadOffset:.+]] = arith.constant dense<512> : vector<32xindex>
   %cst = arith.constant 0.000000e+00 : f16
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %c4 = arith.constant 4 : index
   %c16 = arith.constant 16 : index
-  gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) {
-    %x_group_idx = affine.apply #map(%arg5)
-    %y_group_idx = affine.apply #map(%arg6)
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) {
+    %x_group_idx = affine.apply #map(%arg3)
+    %y_group_idx = affine.apply #map(%arg4)
 
-    %x_thread_idx = affine.apply #map1(%arg8)
-    %y_thread_idx = affine.apply #map1(%arg9)
+    // CHECK: %[[X_THREAD_IDX:.+]] = affine.apply #map1(%arg6)
+    // CHECK: %[[Y_THREAD_IDX:.+]] = affine.apply #map1(%arg7)
+    %x_thread_idx = affine.apply #map1(%arg6)
+    %y_thread_idx = affine.apply #map1(%arg7)
 
     %x_global_idx = arith.addi %x_group_idx, %x_thread_idx : index
     %y_global_idx = arith.addi %y_group_idx, %y_thread_idx : index
 
     %a_subview = memref.subview %arg0[%x_global_idx, 0] [16, 1024] [1, 1] : memref<128x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
     %b_subview = memref.subview %arg1[0, %y_global_idx] [1024, 16] [1, 1] : memref<1024x1024xf16> to memref<1024x16xf16, strided<[1024, 1], offset: ?>>
 
+    // CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<64x256xf16, 3>
     %slm_buff = memref.alloc() : memref<64x256xf16, 3>
+    // CHECK-NOT: memref.subview %[[SLM_BUFF]]
+    // CHECK: %[[SLM_X_OFF:.+]] = arith.muli %[[X_THREAD_IDX]], %c256 : index
+    // CHECK: %[[SLM_THREAD_OFF:.+]] = arith.addi %[[SLM_X_OFF]], %[[Y_THREAD_IDX]] : index
+    // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c16384], strides: [%c1] : memref<64x256xf16, 3> to memref<16384xf16, 3>
     %slm_subview = memref.subview %slm_buff[%x_thread_idx, %y_thread_idx] [16, 16] [1, 1] : memref<64x256xf16, 3> to memref<16x16xf16, strided<[256, 1], offset: ?>, 3>
 
+    // CHECK: %[[SLM_THREAD_OFF_V:.+]] = vector.splat %[[SLM_THREAD_OFF]] : vector<32xindex>
+    // CHECK: %[[DESC_OFFSET0:.+]] = arith.addi %[[SLM_THREAD_OFF_V]], %[[colTileShift]] : vector<32xindex>
+    // CHECK: %[[ROOT_DESC:.+]] = xegpu.create_tdesc %[[FLAT_SLM]], %[[DESC_OFFSET0]] : memref<16384xf16, 3>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 1 : i64>>
+    // CHECK: %[[FILL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]]
+    // CHECK: %[[FILL_DESC2:.+]] = xegpu.update_offset %[[FILL_DESC1]], %[[loadOffset]]
+    // CHECK-COUNT-5: xegpu.update_offset
+
+    // CHECK: xegpu.store %[[ZERO]], %[[ROOT_DESC]]
+    // CHECK: xegpu.store %[[ZERO]], %[[FILL_DESC1]]
+    // CHECK-COUNT-6: xegpu.store
     linalg.fill ins(%cst : f16) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)
-    linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)
 
-    %a_add_subview = memref.subview %arg0[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>>
-    %out_subview = memref.subview %arg2[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>>
+    // CHECK: %[[MATMUL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]]
+    // CHECK: %[[MATMUL_DESC2:.+]] = xegpu.update_offset %[[MATMUL_DESC1]], %[[loadOffset]]
+    // CHECK-COUNT-5: xegpu.update_offset
+
+    // CHECK: %[[MATMUL_LOAD0:.+]] = xegpu.load %[[ROOT_DESC]]
+    // CHECK-NEXT: %[[loadAccumMatmul1:.+]] = vector.insert %[[MATMUL_LOAD0]], %[[loadAccumMatmul]] [0]
+    // CHECK-NEXT: %[[MATMUL_LOAD1:.+]] = xegpu.load %[[MATMUL_DESC1]]
+    // CHECK-NEXT: %[[loadAccumMatmul2:.+]] = vector.insert %[[MATMUL_LOAD1]], %[[loadAccumMatmul1]] [1]
+    // CHECK-COUNT-2: xegpu.load
+
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<4x32xf16> to vector<128xf16>
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<128xf16> to vector<8x16xf16>
 
-    linalg.add ins(%slm_subview, %a_add_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>, memref<16x16xf16, strided<[1024, 1], offset: ?>>) outs(%out_subview : memref<16x16xf16, strided<[1024, 1], offset: ?>>)
+    // CHECK-COUNT-4: xegpu.load
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<4x32xf16> to vector<128xf16>
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<128xf16> to vector<8x16xf16>
+
+    // STORE:
+    // %[[FLAT_MATMUL_RES0:.+]] = vector.shape_cast %[[MATMUL_RES0:.+]] : vector<8x16xf16> to vector<128xf16>
+    // %[[STORE_TILE0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE0]], %[[ROOT_DESC]]
+    // %[[STORE_TILE1:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1]], %[[MATMUL_DESC1]]
+    // CHECK-COUNT-2: xegpu.store
+
+    // %[[FLAT_MATMUL_RES1:.+]] = vector.shape_cast %[[MATMUL_RES1:.+]] : vector<8x16xf16> to vector<128xf16>
+    // %[[STORE_TILE1_0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1_0]]
+    // %[[STORE_TILE1_1:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1_1]]
+    // CHECK-COUNT-2: xegpu.store
+
+    linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)
     gpu.terminator
   }
   return
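Read together, the new CHECK lines pin down the SLM addressing scheme that replaces the `memref.subview`: the shared 64x256 buffer is reinterpreted as a flat 16384-element memref, each thread's base offset is x_thread_idx * 256 + y_thread_idx, and every 32-lane scattered descriptor covers two 16-element rows of the thread's 16x16 tile. That is why `colTileShift` is [0..15, 256..271], why each `xegpu.update_offset` advances by 512 (two rows), and why the root descriptor plus seven updates walk all 16 rows. A self-contained sketch of that index arithmetic (plain C++ for illustration only; the thread coordinates are made-up example values, not from the test):

#include <array>
#include <cstdio>

int main() {
  // Assumed geometry from the test: 64x256 f16 SLM buffer (row stride
  // 256), one 16x16 tile per thread, 32-lane scattered descriptors.
  const int rowStride = 256;
  const int xThreadIdx = 16, yThreadIdx = 32; // example values only

  // Per-thread base offset into the flattened 16384-element buffer.
  const int threadOff = xThreadIdx * rowStride + yThreadIdx;

  // colTileShift: lanes 0-15 cover one 16-wide row, lanes 16-31 the
  // next row, giving [0..15, 256..271] as in the CHECK line above.
  std::array<int, 32> colTileShift{};
  for (int lane = 0; lane < 32; ++lane)
    colTileShift[lane] = (lane / 16) * rowStride + (lane % 16);

  // Eight descriptors walk the tile two rows at a time:
  // desc[d][lane] = threadOff + colTileShift[lane] + d * 512.
  for (int d = 0; d < 8; ++d) {
    std::printf("desc%d:", d);
    for (int lane = 0; lane < 32; ++lane)
      std::printf(" %d", threadOff + colTileShift[lane] + d * 512);
    std::printf("\n");
  }
  return 0;
}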

test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_64x128_slm.mlir

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 // RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s
 
 module @fragment_name {
+  // This kernel requires using SLM
   func.func @entry(%0: tensor<64x128xf16>, %1: tensor<128x128xf16>, %2: tensor<64x128xf16>, %res: tensor<64x128xf16>) -> tensor<64x128xf16> {
     %3 = tensor.empty() : tensor<128x128xf16>
     %4 = tensor.empty() : tensor<64x128xf16>
