Skip to content

Commit 2c31325

Browse files
authored
[MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (#113411)
Previously, the pass only supported emulation of loading vector sizes that are multiples of the emulated data type. This patch expands its support for emulating sizes that are not multiples of byte sizes. In such cases, the element values are packed back-to-back to preserve memory space. To give a concrete example: if an input has type `memref<3x3xi2>`, it is actually occupying 3 bytes in memory, with the first 18 bits storing the values and the last 6 bits as padding. The slice of `vector<3xi2>` at index `[2, 0]` is stored in memory from bit 12 to bit 18. To properly load the elements from bit 12 to bit 18 from memory, first load byte 2 and byte 3, and convert it to a vector of `i2` type; then extract bits 4 to 10 (element index 2-5) to form a `vector<3xi2>`. A limitation of this patch is that the linearized index of the unaligned vector has to be known at compile time. Extra code needs to be emitted to handle it if the condition does not hold. The following ops are updated: * `vector::LoadOp` * `vector::TransferReadOp` * `vector::MaskedLoadOp`
1 parent c62130f commit 2c31325

File tree

4 files changed

+264
-55
lines changed

4 files changed

+264
-55
lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ namespace memref {
3232
bool isStaticShapeAndContiguousRowMajor(MemRefType type);
3333

3434
/// For a `memref` with `offset`, `sizes` and `strides`, returns the
35-
/// offset and size to use for the linearized `memref`.
35+
/// offset, size, and potentially the size padded at the front to use for the
36+
/// linearized `memref`.
3637
/// - If the linearization is done for emulating load/stores of
3738
/// element type with bitwidth `srcBits` using element type with
3839
/// bitwidth `dstBits`, the linearized offset and size are
@@ -42,9 +43,14 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
4243
/// index to use in the linearized `memref`. The linearized index
4344
/// is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
4445
/// 0, is returned for the linearized index.
46+
/// - If the size of the load/store is smaller than the linearized memref
47+
/// load/store, the memory region emulated is larger than the actual memory
48+
/// region needed. `intraDataOffset` returns the element offset of the data
49+
/// relevant at the beginning.
4550
struct LinearizedMemRefInfo {
4651
OpFoldResult linearizedOffset;
4752
OpFoldResult linearizedSize;
53+
OpFoldResult intraDataOffset;
4854
};
4955
std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
5056
OpBuilder &builder, Location loc, int srcBits, int dstBits,

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
8181

8282
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8383
int64_t scaler = dstBits / srcBits;
84-
addMulMap = addMulMap.floorDiv(scaler);
8584
mulMap = mulMap.floorDiv(scaler);
8685

8786
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
88-
builder, loc, addMulMap, offsetValues);
87+
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
8988
OpFoldResult linearizedSize =
9089
affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
9190

@@ -95,7 +94,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
9594
OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
9695
builder, loc, s0.floorDiv(scaler), {offset});
9796

98-
return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
97+
OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
98+
builder, loc, addMulMap % scaler, offsetValues);
99+
100+
return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
101+
linearizedIndices};
99102
}
100103

101104
LinearizedMemRefInfo

0 commit comments

Comments
 (0)