-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s

 // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index,
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+    // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
     %cst = arith.constant 0.000000e+00 : f16

     // CHECK: [[shmA:%.+]] = memref.alloc
     // CHECK: [[shmB:%.+]] = memref.alloc
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>

-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>

-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
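Each swizzle block in the updated CHECK lines matches the same computation: the pass replaces a shared-memory column index with col XOR ((row AND mask) << 2), and this change tightens the mask from 7 to 6. A minimal standalone sketch of the pattern the CHECK lines match, with hypothetical %row and %col operands standing in for stRow/stCol or fragRow/fragCol (not part of the test itself):

    // Compute the permuted column: %colPerm = %col ^ ((%row & 6) << 2).
    %mask = arith.constant 6 : index
    %srcBits = arith.andi %row, %mask : index
    %shift = arith.constant 2 : index
    %xorBits = arith.shli %srcBits, %shift : index
    %colPerm = arith.xori %col, %xorBits : index

XOR-ing row-derived bits into the column index staggers which LDS bank each row's elements land in, which is why both the store-side (stColPerm) and load-side (fragColPerm) accesses above are rewritten with the same arith sequence.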