#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
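Note on the include swap above: WalkPatternRewriteDriver.h is dropped, while llvm/Support/MathExtras.h appears to be pulled in for the llvm::divideCeil call added in a later hunk. A minimal standalone sketch (my own, linking only against LLVMSupport) of the values that call produces for common element widths:

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // Elements that fit in one 32-bit word for common element bit widths.
  assert(llvm::divideCeil(32, 8) == 4);   // i8:  4 elements per word
  assert(llvm::divideCeil(32, 16) == 2);  // f16: 2 elements per word
  assert(llvm::divideCeil(32, 32) == 1);  // f32: 1 element per word
  // Same results as the old expression it replaces:
  //   elementBitWidth < 32 ? 32 / elementBitWidth : 1
  return 0;
}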
@@ -76,6 +76,9 @@ static LogicalResult transferPreconditions(
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

+  if (memRefType.getElementTypeBitWidth() < 8)
+    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
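A note on the new sub-byte precondition (my reading; the patch itself does not state the rationale): the masked-load path converts an element delta to bytes with integer division by 8, which collapses to zero for element types narrower than a byte, so the pattern now bails out on those. A small arithmetic sketch:

#include <cassert>

int main() {
  // deltaBytes in the pattern is computed as delta * (elementBitWidth / 8).
  unsigned subByteWidth = 4;          // e.g. an i4 element type
  assert(subByteWidth / 8 == 0);      // the byte factor collapses to zero
  unsigned byteWidth = 16;            // f16 and wider stay meaningful
  assert(byteWidth / 8 == 2);
  return 0;
}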
@@ -127,14 +130,17 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
  return res;
}

+static constexpr char kTransferReadNeedsMask[] =
+    "amdgpu.buffer_transfer_read_needs_mask";
+
namespace {

struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
+    if (readOp->hasAttr(kTransferReadNeedsMask))
      return failure();

    bool requiresBroadcasting = false;
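The attribute string is now shared through kTransferReadNeedsMask between this early-exit check and the setAttr call in the fallback branch of a later hunk. Below is a condensed sketch of that marker-attribute idiom, my own reduction under the assumption of a standard MLIR build; the real pattern wraps the tagged clone in an scf.if rather than replacing the op directly:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

namespace {
// Mirrors kTransferReadNeedsMask from the patch.
constexpr char kNeedsMask[] = "amdgpu.buffer_transfer_read_needs_mask";

struct MarkerSketch final
    : mlir::OpRewritePattern<mlir::vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::TransferReadOp readOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Skip reads this pattern already produced, so the greedy driver does
    // not keep rewriting its own output.
    if (readOp->hasAttr(kNeedsMask))
      return mlir::failure();

    // Clone the read for the guarded path and tag the clone.
    mlir::Operation *clone = rewriter.clone(*readOp.getOperation());
    clone->setAttr(kNeedsMask, rewriter.getUnitAttr());
    rewriter.replaceOp(readOp, clone->getResults());
    return mlir::success();
  }
};
} // namespace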
@@ -154,71 +160,96 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes =
+        stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset =
+        stridedMetadata.getConstifiedMixedOffset();
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, elementBitWidth, elementBitWidth,
-            stridedMetadata.getConstifiedMixedOffset(),
-            stridedMetadata.getConstifiedMixedSizes(),
-            stridedMetadata.getConstifiedMixedStrides(), indices);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);

    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
    // Note below doesn't give the correct result for the linearized size.
    // Value totalSize = getValueOrCreateConstantIndexOp(
    //     rewriter, loc, linearizedInfo.linearizedSize);
-    // It compute the mutiplied sizes of all dimensions instead of taking
+    // It computes the multiplied sizes of all dimensions instead of taking
    // the maximum of each dimension size * stride.
    SmallVector<AffineExpr> productExpressions;
    SmallVector<Value> productResults;
    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();

    SmallVector<AffineExpr> symbols(2 * sourceRank);
-    SmallVector<Value> offsetValues(2 * sourceRank);
+    SmallVector<Value> offsetValues;
    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+    size_t symbolIndex = 0;
    for (size_t i = 0; i < sourceRank; ++i) {
-      unsigned offsetIdx = 2 * i;
-      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
-      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
-      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
+      AffineExpr strideExpr, sizeExpr;
+      OpFoldResult stride = strides[i];
+      OpFoldResult size = sizes[i];
+      if (auto constantStride =
+              getConstantIntValue(stride)) {
+        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+      } else {
+        strideExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, stride));
+      }
+
+      if (auto constantSize =
+              getConstantIntValue(size)) {
+        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+      } else {
+        sizeExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, size));
+      }
+
+      productExpressions.push_back(strideExpr * sizeExpr);
    }

    AffineMap maxMap = AffineMap::get(
-        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
        rewriter.getContext());
    Value totalSize =
        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);

    // delta = bufferSize - linearizedOffset
    Value vectorSizeOffset =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);

    // 1) check if delta < vectorSize
    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
    Value deltaBytes = rewriter.create<arith::MulIOp>(
        loc, delta,
        rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+        loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne,
        rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned.
+    // out-of-bounds and not word aligned. The fallback ensures correct results
+    // when loading at the boundary of the buffer since buffer load returns
+    // inconsistent zeros for the whole word when boundary is crossed.
    Value ifCondition =
        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);

    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
-                    builder.getUnitAttr());
+      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      builder.create<scf::YieldOp>(loc, readResult);
    };
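On the symbolIndex bookkeeping above: strides and sizes that fold to compile-time constants now become affine constant expressions, so only the dynamic ones consume a symbol and an affine.max operand. A standalone sketch (mine, assuming the core MLIR IR libraries) of the map this yields for one dimension with a constant stride and a dynamic size:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;

  // One dimension with a constant stride (128) and a dynamic size.
  unsigned symbolIndex = 0;
  mlir::AffineExpr strideExpr = mlir::getAffineConstantExpr(128, &ctx);
  mlir::AffineExpr sizeExpr = mlir::getAffineSymbolExpr(symbolIndex++, &ctx);

  // The per-dimension contribution to the affine.max map is stride * size;
  // folding the constant leaves one symbol and thus one map operand.
  mlir::AffineMap maxMap = mlir::AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/symbolIndex, strideExpr * sizeExpr);
  maxMap.dump(); // prints something like ()[s0] -> (s0 * 128)
  return 0;
}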
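On the predicate change from ule to ult in the delta check: as I read it, when exactly vectorSize elements remain before the buffer end the load still fits, so that case no longer counts as out of bounds (the fallback additionally requires the misalignment condition). In plain arithmetic, with hypothetical numbers:

#include <cassert>

int main() {
  long vectorSize = 4;   // elements the transfer_read loads
  long delta = 4;        // elements left until the end of the buffer
  bool outOfBoundsOld = (delta <= vectorSize);  // ule: exact fit flagged out of bounds
  bool outOfBoundsNew = (delta < vectorSize);   // ult: exact fit stays in bounds
  assert(outOfBoundsOld && !outOfBoundsNew);
  return 0;
}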
@@ -243,7 +274,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
-  vector::populateVectorTransferLoweringPatterns(patterns);
}

struct AmdgpuTransferReadToLoadPass final
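Finally, populateAmdgpuTransferReadToLoadPatterns() no longer adds the generic vector transfer lowering patterns. A caller that relied on that side effect would now, under that assumption, register both sets itself; a hedged usage sketch (the header paths are my guess at where these functions are declared):

#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"

void collectPatterns(mlir::RewritePatternSet &patterns) {
  mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(patterns);
  // No longer added by the call above; register explicitly if still wanted.
  mlir::vector::populateVectorTransferLoweringPatterns(patterns);
}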