Skip to content

Commit ddee079

Browse files
committed
Addressing review feedback
1 parent d596704 commit ddee079

File tree

2 files changed

+61
-32
lines changed

2 files changed

+61
-32
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "mlir/Pass/Pass.h"
2525
#include "mlir/Support/LogicalResult.h"
2626
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27-
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
27+
#include "llvm/Support/MathExtras.h"
2828

2929
namespace mlir::amdgpu {
3030
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -76,6 +76,9 @@ static LogicalResult transferPreconditions(
7676
if (!memRefType.isLastDimUnitStride())
7777
return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
7878

79+
if (memRefType.getElementTypeBitWidth() < 8)
80+
return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
81+
7982
// If there is broadcasting involved then we first load the unbroadcasted
8083
// vector, and then broadcast it with `vector.broadcast`.
8184
ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -127,14 +130,17 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
127130
return res;
128131
}
129132

133+
static constexpr char kTransferReadNeedsMask[] =
134+
"amdgpu.buffer_transfer_read_needs_mask";
135+
130136
namespace {
131137

132138
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
133139
using OpRewritePattern::OpRewritePattern;
134140

135141
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
136142
PatternRewriter &rewriter) const override {
137-
if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
143+
if (readOp->hasAttr(kTransferReadNeedsMask))
138144
return failure();
139145

140146
bool requiresBroadcasting = false;
@@ -154,71 +160,96 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
154160

155161
auto stridedMetadata =
156162
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
163+
SmallVector<OpFoldResult> strides =
164+
stridedMetadata.getConstifiedMixedStrides();
165+
SmallVector<OpFoldResult> sizes =
166+
stridedMetadata.getConstifiedMixedSizes();
167+
OpFoldResult offset =
168+
stridedMetadata.getConstifiedMixedOffset();
157169
OpFoldResult linearizedIndices;
158170
std::tie(std::ignore, linearizedIndices) =
159-
memref::getLinearizedMemRefOffsetAndSize(
160-
rewriter, loc, elementBitWidth, elementBitWidth,
161-
stridedMetadata.getConstifiedMixedOffset(),
162-
stridedMetadata.getConstifiedMixedSizes(),
163-
stridedMetadata.getConstifiedMixedStrides(), indices);
164-
Value linearIndex =
165-
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
171+
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
172+
elementBitWidth, offset, sizes,
173+
strides, indices);
166174

167175
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
168176
// Note below doesn't give the correct result for the linearized size.
169177
// Value totalSize = getValueOrCreateConstantIndexOp(
170178
// rewriter, loc, linearizedInfo.linearizedSize);
171-
// It compute the mutiplied sizes of all dimensions instead of taking
179+
// It computes the multiplied sizes of all dimensions instead of taking
172180
// the maximum of each dimension size * stride.
173181
SmallVector<AffineExpr> productExpressions;
174182
SmallVector<Value> productResults;
175183
unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
176184

177185
SmallVector<AffineExpr> symbols(2 * sourceRank);
178-
SmallVector<Value> offsetValues(2 * sourceRank);
186+
SmallVector<Value> offsetValues;
179187
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
188+
189+
size_t symbolIndex = 0;
180190
for (size_t i = 0; i < sourceRank; ++i) {
181-
unsigned offsetIdx = 2 * i;
182-
productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
183-
offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
184-
offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
191+
AffineExpr strideExpr, sizeExpr;
192+
OpFoldResult stride = strides[i];
193+
OpFoldResult size = sizes[i];
194+
if (auto constantStride =
195+
getConstantIntValue(stride)) {
196+
strideExpr = rewriter.getAffineConstantExpr(*constantStride);
197+
} else {
198+
strideExpr = symbols[symbolIndex++];
199+
offsetValues.push_back(getValueOrCreateConstantIndexOp(
200+
rewriter, loc, stride));
201+
}
202+
203+
if (auto constantSize =
204+
getConstantIntValue(size)) {
205+
sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
206+
} else {
207+
sizeExpr = symbols[symbolIndex++];
208+
offsetValues.push_back(getValueOrCreateConstantIndexOp(
209+
rewriter, loc, size));
210+
}
211+
212+
productExpressions.push_back(strideExpr * sizeExpr);
185213
}
186214

187215
AffineMap maxMap = AffineMap::get(
188-
/*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
216+
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
189217
rewriter.getContext());
190218
Value totalSize =
191219
rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
192220

193221
// delta = bufferSize - linearizedOffset
194222
Value vectorSizeOffset =
195223
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
224+
Value linearIndex =
225+
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
196226
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
197227

198228
// 1) check if delta < vectorSize
199229
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
200-
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
230+
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
201231

202232
// 2) check if (delta_bytes % (32 / elementBitWidth) != 0)
203233
Value deltaBytes = rewriter.create<arith::MulIOp>(
204234
loc, delta,
205235
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
206236
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
207-
loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
237+
loc, llvm::divideCeil(32, elementBitWidth));
208238
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
209239
loc, arith::CmpIPredicate::ne,
210240
rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
211241
rewriter.create<arith::ConstantIndexOp>(loc, 0));
212242

213243
// We take the fallback of transfer_read default lowering only when it is both
214-
// out-of-bounds and not word aligned.
244+
// out-of-bounds and not word aligned. The fallback ensures correct results
245+
// when loading at the boundary of the buffer since buffer load returns
246+
// inconsistent zeros for the whole word when the boundary is crossed.
215247
Value ifCondition =
216248
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
217249

218250
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
219251
Operation *read = builder.clone(*readOp.getOperation());
220-
read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
221-
builder.getUnitAttr());
252+
read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
222253
Value readResult = read->getResult(0);
223254
builder.create<scf::YieldOp>(loc, readResult);
224255
};
@@ -243,7 +274,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
243274
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
244275
RewritePatternSet &patterns) {
245276
patterns.add<TransferReadLowering>(patterns.getContext());
246-
vector::populateVectorTransferLoweringPatterns(patterns);
247277
}
248278

249279
struct AmdgpuTransferReadToLoadPass final

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
1212

1313
// CHECK: %[[FALSE:.*]] = arith.constant false
1414
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
15-
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
15+
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
1616

1717
// CHECK: } else {
1818
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
@@ -39,23 +39,23 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
3939

4040
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
4141
// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
42-
// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
42+
// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
4343

4444
// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
4545
// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
4646
// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
4747

4848
// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
4949
// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
50-
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
50+
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]]
5151
// CHECK: } else {
5252
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
5353
// CHECK: return %[[IF]] : vector<4xf16>
5454

5555
// -----
5656

5757
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
58-
// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
58+
// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
5959
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
6060
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
6161
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
@@ -69,10 +69,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
6969
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
7070
// CHECK: %[[C0:.*]] = arith.constant 0 : index
7171
// CHECK: %[[C4:.*]] = arith.constant 4 : index
72-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
7372
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
7473
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
75-
// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
74+
// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
7675
// CHECK: %[[IF:.*]] = scf.if
7776
// CHECK: return
7877

@@ -87,8 +86,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
8786
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
8887
return %res : vector<4xf32>
8988
}
90-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
91-
// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
89+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
90+
// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
9291
// CHECK: return %[[RES]] : vector<4xf32>
9392

9493
// -----
@@ -102,8 +101,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
102101
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
103102
return %res : vector<4xf32>
104103
}
105-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
106-
// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
104+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
105+
// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
107106
// CHECK: return %[[RES]] : vector<4xf32>
108107

109108
// -----

0 commit comments

Comments
 (0)