
Commit fe648aa

Implement vector stores
1 parent 2ed8c5d commit fe648aa

File tree

5 files changed (+436, -36 lines)


mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 2 deletions
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
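
A minimal usage sketch, not part of this commit, of how a caller might thread the new `useAtomicWrites` flag into pattern population. The wrapper name `addNarrowTypeEmulationPatterns` and its `singleThreaded` parameter are hypothetical; only `populateVectorNarrowTypeEmulationPatterns` and its signature come from the declaration above.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper: register the narrow-type emulation patterns, opting out
// of atomic writes when the caller knows there is no thread contention.
void addNarrowTypeEmulationPatterns(
    const mlir::arith::NarrowTypeEmulationConverter &typeConverter,
    mlir::RewritePatternSet &patterns, bool singleThreaded) {
  mlir::vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*useAtomicWrites=*/!singleThreaded);
}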

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 232 additions & 33 deletions
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -211,13 +212,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -228,9 +226,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -309,6 +310,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                      newLoad);
 }
 
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &builder, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t) {
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same signature, as per "
+              "the design to keep the code clean, which one to call is "
+              "determined by the `useAtomicWrites` flag.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
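
Both `atomicStore` and `rmwStore` compute the same byte-level merge: read the destination byte, bitcast it to a sub-byte vector, select the masked lanes from the new value, bitcast back, and write the byte. They differ only in whether that read-modify-write is wrapped in `memref.generic_atomic_rmw`. The sketch below is a conceptual model in plain C++, not code from this commit; the element layout within the byte is chosen only for illustration.

#include <cstdint>
#include <cstdio>

// Merge `numElems` sub-byte elements of `elemBits` bits each into `origByte`,
// taking element i from `newByte` where mask[i] is true and from `origByte`
// otherwise. Element 0 occupies the lowest bits (layout is illustrative only).
uint8_t mergeSubByteStore(uint8_t origByte, uint8_t newByte, const bool mask[],
                          int numElems, int elemBits) {
  uint8_t result = 0;
  for (int i = 0; i < numElems; ++i) {
    uint8_t elemMask = ((1u << elemBits) - 1) << (i * elemBits);
    uint8_t src = mask[i] ? newByte : origByte;
    result |= src & elemMask;
  }
  return result;
}

int main() {
  // A byte holding 4 x i2 elements; overwrite only elements 1 and 2.
  bool mask[4] = {false, true, true, false};
  uint8_t merged = mergeSubByteStore(/*origByte=*/0b11001100,
                                     /*newByte=*/0b00111100, mask,
                                     /*numElems=*/4, /*elemBits=*/2);
  std::printf("0x%02x\n", merged); // 0b11111100 -> 0xfc
}
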
@@ -318,6 +389,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -329,16 +404,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
+    auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
+    auto oldElementType = valueToStore.getType().getElementType();
+    auto newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
     if (dstBits % srcBits != 0) {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -353,32 +429,153 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
+      return failure();
+    }
+
+    auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when a subbyte store at the front is not needed:
+    // 1. The source vector size is a multiple of the byte size.
+    // 2. The address of the store is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Partial-width store for the first byte. When the store address is not
+    // aligned to the emulated width boundary, handle the unaligned part first
+    // so that the remaining elements are aligned to the width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full-width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, valueToStore, currentSourceIndex,
+          numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Handle trailing elements that are aligned to the emulated width but
+    // shorter than the emulated width.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate the back mask.
+      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
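
The three numbered phases above split an unaligned sub-byte store into a masked front byte, a run of full-width byte stores, and a masked tail byte. The standalone sketch below reproduces only that index arithmetic so the split can be checked in isolation; it is not code from this commit and the names are ad hoc.

#include <algorithm>
#include <cstdint>
#include <cstdio>

struct StorePlan {
  int64_t frontElems; // elements merged into the first, unaligned byte
  int64_t fullBytes;  // number of full-width (whole-byte) stores
  int64_t tailElems;  // elements merged into the last, partial byte
};

StorePlan planSubByteStore(int64_t origElements, int64_t numSrcElemsPerDest,
                           int64_t frontPadElems) {
  StorePlan plan{0, 0, 0};
  // 1. Partial store for the first byte if the start is not byte-aligned.
  int64_t front = (numSrcElemsPerDest - frontPadElems) % numSrcElemsPerDest;
  plan.frontElems = std::min(front, origElements);
  int64_t remaining = origElements - plan.frontElems;
  // 2. Full-width stores once the address is aligned.
  plan.fullBytes = remaining / numSrcElemsPerDest;
  // 3. Trailing elements that do not fill a whole byte.
  plan.tailElems = remaining % numSrcElemsPerDest;
  return plan;
}

int main() {
  // e.g. a vector<7xi2> stored with 1 element of front padding (4 x i2 per byte).
  StorePlan p = planSubByteStore(/*origElements=*/7, /*numSrcElemsPerDest=*/4,
                                 /*frontPadElems=*/1);
  std::printf("front=%lld full=%lld tail=%lld\n", (long long)p.frontElems,
              (long long)p.fullBytes, (long long)p.tailElems);
  // Expected: front=3 full=1 tail=0  (3 + 4 + 0 = 7)
}
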
@@ -584,9 +781,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -745,9 +941,8 @@ struct ConvertVectorMaskedLoad final
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -830,9 +1025,8 @@ struct ConvertVectorTransferRead final
                                        linearizedInfo.intraDataOffset,
                                        origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -1577,12 +1771,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose to
+  // avoid emitting atomic operations and reduce them to load-modify-write
+  // sequences when it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
