-// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
-
-// TODO: write CHECK directives
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s
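+
+// Summary of what the CHECK directives below verify: linalg.fill and
+// linalg.matmul writing to a subview of a shared-local-memory buffer
+// (address space 3) are lowered so that the SLM subview is folded away
+// (see CHECK-NOT below), the buffer is flattened with
+// memref.reinterpret_cast, and all SLM accesses go through xegpu scattered
+// descriptors (xegpu.create_tdesc / xegpu.update_offset / xegpu.load /
+// xegpu.store).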

#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0) -> (d0 * 16)>

func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>) {
+  // CHECK: %[[loadAccumMatmul:.+]] = arith.constant dense<0.000000e+00> : vector<4x32xf16>
+  // CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<32xf16>
+  // CHECK: %[[colTileShift:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271]> : vector<32xindex>
+  // CHECK: %[[loadOffset:.+]] = arith.constant dense<512> : vector<32xindex>
  %cst = arith.constant 0.000000e+00 : f16
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c16 = arith.constant 16 : index
-  gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) {
-    %x_group_idx = affine.apply #map(%arg5)
-    %y_group_idx = affine.apply #map(%arg6)
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c16, %arg16 = %c1) {
+    %x_group_idx = affine.apply #map(%arg3)
+    %y_group_idx = affine.apply #map(%arg4)

-    %x_thread_idx = affine.apply #map1(%arg8)
-    %y_thread_idx = affine.apply #map1(%arg9)
+    // CHECK: %[[X_THREAD_IDX:.+]] = affine.apply #map1(%arg6)
+    // CHECK: %[[Y_THREAD_IDX:.+]] = affine.apply #map1(%arg7)
+    %x_thread_idx = affine.apply #map1(%arg6)
+    %y_thread_idx = affine.apply #map1(%arg7)

    %x_global_idx = arith.addi %x_group_idx, %x_thread_idx : index
    %y_global_idx = arith.addi %y_group_idx, %y_thread_idx : index

    %a_subview = memref.subview %arg0[%x_global_idx, 0] [16, 1024] [1, 1] : memref<128x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
    %b_subview = memref.subview %arg1[0, %y_global_idx] [1024, 16] [1, 1] : memref<1024x1024xf16> to memref<1024x16xf16, strided<[1024, 1], offset: ?>>

+    // CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<64x256xf16, 3>
    %slm_buff = memref.alloc() : memref<64x256xf16, 3>
+    // CHECK-NOT: memref.subview %[[SLM_BUFF]]
+    // CHECK: %[[SLM_X_OFF:.+]] = arith.muli %[[X_THREAD_IDX]], %c256 : index
+    // CHECK: %[[SLM_THREAD_OFF:.+]] = arith.addi %[[SLM_X_OFF]], %[[Y_THREAD_IDX]] : index
+    // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c16384], strides: [%c1] : memref<64x256xf16, 3> to memref<16384xf16, 3>
    %slm_subview = memref.subview %slm_buff[%x_thread_idx, %y_thread_idx] [16, 16] [1, 1] : memref<64x256xf16, 3> to memref<16x16xf16, strided<[256, 1], offset: ?>, 3>

+    // CHECK: %[[SLM_THREAD_OFF_V:.+]] = vector.splat %[[SLM_THREAD_OFF]] : vector<32xindex>
+    // CHECK: %[[DESC_OFFSET0:.+]] = arith.addi %[[SLM_THREAD_OFF_V]], %[[colTileShift]] : vector<32xindex>
+    // CHECK: %[[ROOT_DESC:.+]] = xegpu.create_tdesc %[[FLAT_SLM]], %[[DESC_OFFSET0]] : memref<16384xf16, 3>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 1 : i64>>
+    // CHECK: %[[FILL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]]
+    // CHECK: %[[FILL_DESC2:.+]] = xegpu.update_offset %[[FILL_DESC1]], %[[loadOffset]]
+    // CHECK-COUNT-5: xegpu.update_offset
+
+    // CHECK: xegpu.store %[[ZERO]], %[[ROOT_DESC]]
+    // CHECK: xegpu.store %[[ZERO]], %[[FILL_DESC1]]
+    // CHECK-COUNT-6: xegpu.store
    linalg.fill ins(%cst : f16) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)
-    linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)

-    %a_add_subview = memref.subview %arg0[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>>
-    %out_subview = memref.subview %arg2[%x_global_idx, %y_global_idx] [16, 16] [1, 1] : memref<128x1024xf16> to memref<16x16xf16, strided<[1024, 1], offset: ?>>
+    // CHECK: %[[MATMUL_DESC1:.+]] = xegpu.update_offset %[[ROOT_DESC]], %[[loadOffset]]
+    // CHECK: %[[MATMUL_DESC2:.+]] = xegpu.update_offset %[[MATMUL_DESC1]], %[[loadOffset]]
+    // CHECK-COUNT-5: xegpu.update_offset
+
+    // CHECK: %[[MATMUL_LOAD0:.+]] = xegpu.load %[[ROOT_DESC]]
+    // CHECK-NEXT: %[[loadAccumMatmul1:.+]] = vector.insert %[[MATMUL_LOAD0]], %[[loadAccumMatmul]] [0]
+    // CHECK-NEXT: %[[MATMUL_LOAD1:.+]] = xegpu.load %[[MATMUL_DESC1]]
+    // CHECK-NEXT: %[[loadAccumMatmul2:.+]] = vector.insert %[[MATMUL_LOAD1]], %[[loadAccumMatmul1]] [1]
+    // CHECK-COUNT-2: xegpu.load
+
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<4x32xf16> to vector<128xf16>
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<128xf16> to vector<8x16xf16>

-    linalg.add ins(%slm_subview, %a_add_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>, memref<16x16xf16, strided<[1024, 1], offset: ?>>) outs(%out_subview : memref<16x16xf16, strided<[1024, 1], offset: ?>>)
+    // CHECK-COUNT-4: xegpu.load
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<4x32xf16> to vector<128xf16>
+    // CHECK: vector.shape_cast
+    // CHECK-SAME: vector<128xf16> to vector<8x16xf16>
+
+    // STORE:
+    // %[[FLAT_MATMUL_RES0:.+]] = vector.shape_cast %[[MATMUL_RES0:.+]] : vector<8x16xf16> to vector<128xf16>
+    // %[[STORE_TILE0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE0]], %[[ROOT_DESC]]
+    // %[[STORE_TILE1:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES0]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1]], %[[MATMUL_DESC1]]
+    // CHECK-COUNT-2: xegpu.store
+
+    // %[[FLAT_MATMUL_RES1:.+]] = vector.shape_cast %[[MATMUL_RES1:.+]] : vector<8x16xf16> to vector<128xf16>
+    // %[[STORE_TILE1_0:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [0], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1_0]]
+    // %[[STORE_TILE1_1:.+]] = vector.extract_strided_slice %[[FLAT_MATMUL_RES1]] {offsets = [32], sizes = [32], strides = [1]} : vector<128xf16> to vector<32xf16>
+    // xegpu.store %[[STORE_TILE1_1]]
+    // CHECK-COUNT-2: xegpu.store
+
+    linalg.matmul ins(%a_subview, %b_subview : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x16xf16, strided<[1024, 1], offset: ?>>) outs(%slm_subview : memref<16x16xf16, strided<[256, 1], offset: ?>, 3>)
    gpu.terminator
  }
  return