
Commit edfe3d4

Implement vector stores
1 parent 9d85ba5 commit edfe3d4

File tree: 2 files changed, +313 −32 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 180 additions & 32 deletions
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -143,13 +144,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -160,9 +158,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
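Note: the change above drops the explicit `extractType` parameter; the result type of the emitted `vector.extract_strided_slice` is now derived from `subvecSize` and the source element type, so callers can no longer pass a mismatched type. A minimal sketch of the op this wrapper emits, with hypothetical operands (extracting 4 elements starting at offset 2):

    %sub = vector.extract_strided_slice %src
        {offsets = [2], sizes = [4], strides = [1]}
        : vector<8xi2> to vector<4xi2>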
@@ -171,12 +172,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
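Note: `staticallyInsertSubvector` is the inverse wrapper around `vector.insert_strided_slice`; the `[[maybe_unused]]` annotations keep the rank asserts without tripping unused-variable warnings in release builds. A sketch with hypothetical operands (inserting a 3-element slice at offset 1):

    %r = vector.insert_strided_slice %sub, %dest
        {offsets = [1], strides = [1]}
        : vector<3xi2> into vector<4xi2>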
@@ -243,6 +242,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                     newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = dyn_cast<VectorType>(value.getType());
+  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// atomically store a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
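Note: `atomicStore` wraps the read-modify-write of a partially covered byte in `memref.generic_atomic_rmw` so that two stores touching sub-byte lanes of the same emulated byte cannot clobber each other. A minimal sketch of the region it builds for scale = 4 (four i2 lanes per i8; operand names and the memref shape are hypothetical):

    %res = memref.generic_atomic_rmw %mem[%idx] : memref<8xi8> {
    ^bb0(%current : i8):
      %vec      = vector.from_elements %current : vector<1xi8>
      %unpacked = vector.bitcast %vec : vector<1xi8> to vector<4xi2>
      %merged   = arith.select %mask, %value, %unpacked
                    : vector<4xi1>, vector<4xi2>
      %packed   = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
      %byte     = vector.extract %packed[0] : i8 from vector<1xi8>
      memref.atomic_yield %byte : i8
    }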
@@ -263,7 +319,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -287,30 +344,124 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    // conditions when atomic stores and all that are not needed:
+    // 1. The source vector size is multiple of byte size
+    // 2. The address of the store is byte aligned
+    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // the index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // the index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. atomic store for the first byte
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    *foldedIntraVectorOffset, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. non-atomic store
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. atomic store for the last byte
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // back mask
+      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
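Note: the rewrite above splits an unaligned sub-byte store into at most three parts: a masked atomic RMW for the partially covered leading byte, plain widened stores for the fully covered middle bytes, and a masked atomic RMW for the partially covered trailing byte. A worked example under assumed types, with scale = 4 (i2 into i8) and a store of 7 elements at a linearized intra-byte offset of 2:

    vector.store %v, %src[%c2] : memref<12xi2>, vector<7xi2>

    // leading byte : lanes 2..3 take elements 0..1 of %v
    //                -> atomic RMW, mask [0, 0, 1, 1]
    // middle byte  : lanes 0..3 take elements 2..5 (fully covered)
    //                -> non-atomic store of the <4xi2> -> <1xi8> bitcast
    // trailing byte: lane 0 takes element 6
    //                -> atomic RMW, mask [1, 0, 0, 0]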
@@ -518,9 +669,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -679,9 +829,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -764,9 +913,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 