Commit cfc1e1d

[MLIR] VectorEmulateNarrowType support unaligned cases
Previously the pass only supported emulation of vectors whose size is a multiple of the emulated data type (i8). This patch extends it to vectors whose size is not a multiple of a byte, such as `vector<3xi2>`. A limitation of this patch is that the linearized index of the unaligned vector has to be known at compile time; extra code would need to be emitted to handle the case where it is not. The following ops are updated:
* `vector::LoadOp`
* `vector::StoreOp`
* `vector::TransferReadOp`
1 parent 2a25200 commit cfc1e1d
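To make the new behavior concrete: the vector below does not fill whole bytes, so the pass now loads the covering i8 bytes, bitcasts them to the narrow type, and slices out the requested elements. This is only a sketch that mirrors the `vector<3xi2>` case checked by the new test at the bottom of this diff; the SSA names and constants are illustrative.

// Unaligned source op: three i2 elements starting at flat element index 6.
%1 = vector.load %src[%c2, %c0] : memref<3x3xi2>, vector<3xi2>

// Emulated form: load the two covering bytes, reinterpret them as i2,
// then extract the three elements at front-padding offset 2.
%bytes = vector.load %alloc[%c1] : memref<3xi8>, vector<2xi8>
%bits = vector.bitcast %bytes : vector<2xi8> to vector<8xi2>
%res = vector.extract_strided_slice %bits {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>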

File tree: 4 files changed (+204 lines, -31 lines)


mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 8 additions & 1 deletion
@@ -32,7 +32,8 @@ namespace memref {
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
 /// For a `memref` with `offset`, `sizes` and `strides`, returns the
-/// offset and size to use for the linearized `memref`.
+/// offset, size, and potentially the size padded at the front to use for the
+/// linearized `memref`.
 /// - If the linearization is done for emulating load/stores of
 /// element type with bitwidth `srcBits` using element type with
 /// bitwidth `dstBits`, the linearized offset and size are
@@ -42,9 +43,15 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 /// index to use in the linearized `memref`. The linearized index
 /// is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
 /// 0, is returned for the linearized index.
+/// - If the size of the load/store is smaller than the linearized memref
+/// load/store,
+/// the memory region emulated is larger than the actual memory region needed.
+/// `frontPaddingSize` returns the size of the irrelevant offset at the
+/// beginning.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
+  OpFoldResult frontPaddingSize;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
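For intuition, here is the arithmetic behind `frontPaddingSize`, assuming i2 elements emulated with i8 memory (scale factor 8/2 = 4) and the access used in the new test; the numbers are illustrative only.

// Access: vector.load %0[%c2, %c0] on a row-major memref<3x3xi2>.
// flat element index    = 2 * 3 + 0    = 6
// linearizedOffset (i8) = 6 floordiv 4 = 1
// frontPaddingSize (i2) = 6 mod 4      = 2

This is why the emulated byte access starts at index 1 and the relevant elements start at offset 2 within the bit-cast vector.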

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 6 additions & 3 deletions
@@ -81,11 +81,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
-      builder, loc, addMulMap, offsetValues);
+      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
   OpFoldResult linearizedSize =
       affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
@@ -95,7 +94,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, s0.floorDiv(scaler), {offset});
 
-  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+  OpFoldResult frontPaddingSize = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap % scaler, offsetValues);
+
+  return {{adjustBaseOffset, linearizedSize, frontPaddingSize},
+          linearizedIndices};
 }
 
 LinearizedMemRefInfo
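In terms of the IR this helper composes, the two quantities reduce to a floordiv and a mod of the flat element index by the scale factor. A loose sketch for the i2-in-i8 case (scaler = 4); the exact ops depend on what makeComposedFoldedAffineApply manages to fold, and the SSA names here are made up for illustration.

%word_index = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%flat_index]
%front_pad = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%flat_index]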

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 135 additions & 27 deletions
@@ -24,6 +24,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 
@@ -102,6 +103,23 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+///
+static std::optional<int64_t>
+getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
+                    const memref::LinearizedMemRefInfo linearizedInfo,
+                    bool isUnalignedEmulation) {
+  if (!isUnalignedEmulation)
+    return 0;
+  auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp(
+      rewriter, loc, linearizedInfo.frontPaddingSize);
+  // try to fold the front padding size into a constant
+  if (auto frontPadding = dyn_cast_or_null<arith::ConstantIndexOp>(
+          foldedFrontPaddingSize.getDefiningOp())) {
+    return frontPadding.value();
+  }
+  return std::nullopt;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
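Note that `getFrontPaddingSize` only yields a value when the padding folds to an `arith.constant`; otherwise the three patterns below return `failure()` and the op is left un-emulated, matching the limitation stated in the commit message. A hypothetical input that is currently rejected, because the outer index is only known at runtime:

// %row : index is dynamic, so the i2 front padding cannot be folded to a
// compile-time constant and the pattern bails out.
%v = vector.load %m[%row, %c0] : memref<3x3xi2>, vector<3xi2>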
@@ -142,29 +160,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    // if the size of vector we are loading is not byte-aligned, extra handling
+    // is needed
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto newVectorType = VectorType::get(numElements, newElementType);
+
+    if (isUnalignedEmulation) {
+      auto insertedVectorType =
+          VectorType::get(numElements * scale, oldElementType);
+
+      auto linearizedIndicesValue =
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+      auto passThru =
+          rewriter.create<vector::LoadOp>(loc, newVectorType, adaptor.getBase(),
+                                          ValueRange{linearizedIndicesValue});
+      auto bitcastedPassThru =
+          rewriter.create<vector::BitCastOp>(loc, insertedVectorType, passThru);
+
+      // just extract it and use it for the strided slice offset
+      auto insertStridedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, insertedVectorType, op.getValueToStore(), bitcastedPassThru,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({1}));
+      // bit cast the vector to the original type
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        insertStridedSlice);
+
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(), linearizedIndicesValue);
+    } else {
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    }
     return success();
   }
 };
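In effect, an unaligned store becomes a read-modify-write of the covering bytes. The sketch below uses illustrative SSA names and mirrors the FileCheck expectations of `@vector_store_i2` in the new test.

// Read the bytes that cover the destination and reinterpret them as i2.
%old_bytes = vector.load %alloc[%c1] : memref<3xi8>, vector<2xi8>
%old_bits = vector.bitcast %old_bytes : vector<2xi8> to vector<8xi2>
// Splice the three stored elements in at front-padding offset 2.
%new_bits = vector.insert_strided_slice %value, %old_bits {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>
// Write the bytes back.
%new_bytes = vector.bitcast %new_bits : vector<8xi2> to vector<2xi8>
vector.store %new_bytes, %alloc[%c1] : memref<3xi8>, vector<2xi8>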
@@ -294,35 +349,67 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
     // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
     //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not byte-aligned,
+    // for example:
+    //
+    // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // we will have to load extra bytes and extract the exact slice in between.
+    //
+    // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    // %3 = vector.extract_strided_slice %1 {offsets = [2], sizes = [3], strides
+    // = [1]}
+    //        : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto loadVectorType = VectorType::get(numElements, newElementType);
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
+        loc, loadVectorType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-
-    rewriter.replaceOp(op, bitCast->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
+
+    if (newBitCastType.getNumElements() != origElements) {
+      auto extractStridedSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+      rewriter.replaceOp(op, extractStridedSlice.getResult());
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
     return success();
   }
 };
@@ -464,8 +551,8 @@ struct ConvertVectorTransferRead final
     int scale = dstBits / srcBits;
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -474,26 +561,47 @@
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
     auto newReadType = VectorType::get(numElements, newElementType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, newReadType, adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
+    auto bitCastType = VectorType::get(numElements * scale, oldElementType);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+        rewriter.create<vector::BitCastOp>(loc, bitCastType, newRead);
+
+    if (isUnalignedEmulation) {
+      // we only extract a portion of the vector.
+      rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
+          op, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };
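The transfer_read path follows the same load/bitcast/extract scheme as `ConvertVectorLoad`; the only extra step is that the narrow padding value is zero-extended to the wider memory element type before the new read is built. A small sketch under the same i2-in-i8 assumption, with illustrative names:

%pad_i8 = arith.extui %pad : i2 to i8
%bytes = vector.transfer_read %alloc[%c1], %pad_i8 : memref<3xi8>, vector<2xi8>
// ...then bitcast to vector<8xi2> and extract_strided_slice as in the load case.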
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+  return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+//-----
+
+func.func @vector_store_i2(%arg0: vector<3xi2>) {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+  return
+}
+
+// CHECK: func @vector_store_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %arg0, %[[BITCAST1]] {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: vector.store %[[BITCAST2]], %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+
+//-----
+
+func.func @vector_transfer_read_i2() -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0i2 = arith.constant 0 : i2
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
