 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"

@@ -102,13 +104,64 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }

+/// When writing a subbyte size, masked bitwise operations are used to only
+/// modify the relevant bits. This function returns an and mask for clearing
+/// the destination bits in a subbyte write. E.g., when writing to the second
+/// i4 in an i32, 0xFFFFFF0F is created.
+static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                 int64_t srcBits, int64_t dstBits,
+                                 Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned = builder.create<arith::ConstantOp>(
+      loc, dstIntegerType, maskRightAlignedAttr);
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {

 //===----------------------------------------------------------------------===//
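
For reference, here is a minimal standalone sketch of the arithmetic the new helpers encode as IR, assuming i4 elements emulated in i32 storage. The index and the 0xFFFFFF0F result mirror the example in the getSubByteWriteMask comment; everything else is hypothetical.

// Sketch only: evaluates the same offset/mask arithmetic that
// getOffsetForBitwidth and getSubByteWriteMask build as IR.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t srcBits = 4, dstBits = 32;
  const int64_t scaleFactor = dstBits / srcBits; // 8 i4 values per i32 word

  // Writing the second i4 (linear index 1) of the emulated data.
  const int64_t linearIndex = 1;

  // getOffsetForBitwidth: (index % scaleFactor) * srcBits.
  const uint32_t bitOffset = (linearIndex % scaleFactor) * srcBits; // 4

  // getSubByteWriteMask: shift a right-aligned i4 mask into place, then
  // flip all bits (arith.shli followed by arith.xori with -1).
  const uint32_t maskRightAligned = (1u << srcBits) - 1;       // 0xF
  const uint32_t writeMask = ~(maskRightAligned << bitOffset); // 0xFFFFFF0F

  assert(writeMask == 0xFFFFFF0Fu);
  return 0;
}
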
@@ -218,32 +271,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));

       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
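
For reference, a minimal sketch of what the rewritten load computes once the indices are linearized and scaled, again assuming i4 data emulated in i32 storage. The buffer contents and index below are hypothetical, and the final shift-and-mask mirrors the "shift the bits to the rightmost" step noted above.

// Sketch only: models the emulated load of element 9 from i4 data that has
// been packed into i32 words.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t srcBits = 4, dstBits = 32;
  const int64_t scaler = dstBits / srcBits; // 8 i4 values per i32 word

  // Hypothetical storage: i4 value 0xA at linear index 9, i.e. in bits
  // [4, 8) of the second i32 word.
  const uint32_t storage[2] = {0x00000000u, 0x000000A0u};
  const int64_t linearIndex = 9;

  // getIndicesForLoadOrStore: floor-divide down to i32 granularity.
  const int64_t loadIndex = linearIndex / scaler; // 1
  // getOffsetForBitwidth: bit offset of the element inside the loaded word.
  const uint32_t bitOffset = (linearIndex % scaler) * srcBits; // 4

  // Shift the loaded word right and keep the low srcBits.
  const uint32_t loaded = storage[loadIndex];
  const uint32_t value = (loaded >> bitOffset) & ((1u << srcBits) - 1);
  assert(value == 0xAu);
  return 0;
}
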
@@ -305,6 +341,63 @@ struct ConvertMemRefReinterpretCast final
   }
 };

+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. No need for masking.
+    if (convertedType.getRank() == 0) {
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
+                                           extendedInput, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
+                                          dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits.
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+    // Clear the destination bits.
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write the source bits into the destination.
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
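
For reference, a minimal sketch of the clear-then-set sequence that ConvertMemrefStore emits as two memref.atomic_rmw ops (andi, then ori), evaluated here on a plain 32-bit word. The stored value and the initial word contents are hypothetical.

// Sketch only: stores i4 value 0x3 at linear index 1 into an i32 word that
// already holds other i4 elements, without disturbing them.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t srcBits = 4, dstBits = 32;
  const int64_t scaler = dstBits / srcBits;

  uint32_t word = 0xFFFFFFFFu;   // existing i32 storage word
  const int64_t linearIndex = 1; // second i4 in the word
  const uint32_t value = 0x3u;   // i4 value to store (already zero-extended)

  const uint32_t bitOffset = (linearIndex % scaler) * srcBits;      // 4
  const uint32_t writeMask = ~(((1u << srcBits) - 1) << bitOffset); // 0xFFFFFF0F
  const uint32_t alignedVal = value << bitOffset;                   // 0x30

  word &= writeMask;  // memref.atomic_rmw andi: clear the destination bits
  word |= alignedVal; // memref.atomic_rmw ori: write the source bits
  assert(word == 0xFFFFFF3Fu);
  return 0;
}
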
@@ -350,9 +443,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
                ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
-               ConvertMemRefReinterpretCast>(typeConverter,
-                                             patterns.getContext());
+               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+      typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }

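For reference, a rough sketch of how a pass might drive these patterns. The converter and populate calls outside this file are assumptions based on the existing narrow-type emulation infrastructure, not code from this patch, and the conversion-target setup is elided.

// Sketch only: assumed driver, not part of this patch.
static LogicalResult runNarrowTypeEmulation(Operation *root,
                                            unsigned targetBitwidth) {
  MLIRContext *ctx = root->getContext();

  // Convert sub-byte integer memrefs into memrefs of targetBitwidth integers.
  arith::NarrowTypeEmulationConverter typeConverter(targetBitwidth);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  ConversionTarget target(*ctx);
  // ... mark ops legal or illegal based on typeConverter ...
  return applyPartialConversion(root, target, std::move(patterns));
}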