Commit d596704

committed
Matching unit test with latest implementation
1 parent a09cd51 commit d596704

2 files changed: +71 -54 lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 2 additions & 4 deletions
@@ -154,9 +154,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
-    memref::LinearizedMemRefInfo linearizedInfo;
     OpFoldResult linearizedIndices;
-    std::tie(linearizedInfo, linearizedIndices) =
+    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, elementBitWidth, elementBitWidth,
            stridedMetadata.getConstifiedMixedOffset(),
@@ -173,8 +172,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     // the maximum of each dimension size * stride.
     SmallVector<AffineExpr> productExpressions;
     SmallVector<Value> productResults;
-    unsigned sourceRank =
-        cast<ShapedType>(readOp.getSource().getType()).getRank();
+    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
 
     SmallVector<AffineExpr> symbols(2 * sourceRank);
     SmallVector<Value> offsetValues(2 * sourceRank);
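A note on the idiom in the first hunk: `std::tie` binds the results of a tuple-returning call to existing variables, and `std::ignore` discards the components that are not needed, so the unused `memref::LinearizedMemRefInfo` no longer has to be declared. A minimal standalone sketch of the pattern (the helper below is hypothetical, standing in for `memref::getLinearizedMemRefOffsetAndSize`):

```cpp
#include <string>
#include <tuple>
#include <utility>

// Hypothetical two-result helper, standing in for
// memref::getLinearizedMemRefOffsetAndSize.
std::pair<int, std::string> getOffsetAndIndices() { return {8, "indices"}; }

int main() {
  std::string linearizedIndices;
  // std::ignore swallows the first result; no dead variable is declared.
  std::tie(std::ignore, linearizedIndices) = getOffsetAndIndices();
  return 0;
}
```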

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 69 additions & 50 deletions
@@ -9,55 +9,72 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
-// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
-// CHECK: %[[C8:.*]] = arith.constant 8
-// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
-// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
-// CHECK: %[[C4:.*]] = arith.constant 4
-// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
-
-// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL4:.*]] = arith.muli
-
-// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
-// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
-
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
 
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
 
 // CHECK: return %[[IF]] : vector<4xf32>
 
 // -----
 
-// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
-// CHECK-SAME: %[[ARG1:.*]]: index
-// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
-  return %res : vector<4xf32>
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16(
+// CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
+// CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
+// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>)
+func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> {
+  %cf0 = arith.constant 0.0 : f16
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
+  return %res : vector<4xf16>
 }
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
+// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
+// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
+
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
+
+// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
+// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
+// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
+
+// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: return %[[IF]] : vector<4xf16>
+
+// -----
 
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> {
+  %cf0 = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+  return %res : vector<4xi8>
+}
 
-// CHECK: %[[C1_1:.*]] = arith.constant 1
-// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
-// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
+// CHECK: %[[IF:.*]] = scf.if
+// CHECK: return
 
 // -----

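The new `_f16` checks above spell out the guard the pass emits before choosing a branch: the two indices are linearized through `#map`, the remaining distance to the end of the flat buffer is compared against the vector length, and a byte-remainder condition is and-ed in; only when the combined condition holds does control reach the `vector.maskedload` branch. A scalar C++ sketch of that predicate, mirroring the CHECK lines one-for-one with the constants of the 8x8xf16 / vector<4xf16> case (a reading aid, not the pass's actual implementation):

```cpp
#include <cstdio>

// Scalar model of the guard in the f16 test; names follow the FileCheck
// captures (LINEAR, DELTA, COND1, DELTABYTES, REM, COND2, COND).
bool takesMaskedLoadBranch(long idx0, long idx1) {
  const long size = 64;       // arith.constant 64: 8 * 8 elements
  const long bytes = 2;       // arith.constant 2: bytes per f16
  const long vectorSize = 4;  // arith.constant 4: elements per load

  long linear = idx0 * 8 + idx1;         // affine.apply #map()[s0, s1]
  long delta = size - linear;            // arith.subi
  bool cond1 = delta <= vectorSize;      // arith.cmpi ule
  long deltaBytes = delta * bytes;       // arith.muli
  bool cond2 = deltaBytes % bytes != 0;  // arith.remui + arith.cmpi ne
  return cond1 && cond2;                 // arith.andi feeds the scf.if
}

int main() {
  // Near the end of the buffer (linear index 62, two elements left).
  printf("%d\n", takesMaskedLoadBranch(7, 6));
}
```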
@@ -70,8 +87,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -85,8 +102,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -103,10 +120,11 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
 
 // -----
@@ -122,7 +140,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
   return %res : vector<1xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
