
Commit 5069134

Implement vector stores
1 parent 44a41b0 commit 5069134

5 files changed (+433, -35 lines)


mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 2 deletions
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
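A minimal caller-side sketch of the updated entry point (not part of this commit; the pass wiring, the usual includes, and the /*targetBitwidth=*/8 converter argument are assumptions). A caller that knows the packed bytes are not shared across threads can opt out of atomics:

  // Fragment of a hypothetical pass body requesting non-atomic RMW stores.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(&getContext());
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);

Leaving the new argument at its default keeps the previous behavior of emitting atomic writes.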

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 229 additions & 32 deletions
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -208,13 +209,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -225,9 +223,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -306,6 +307,73 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    newLoad);
 }
 
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
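Both helpers perform the same per-byte merge: lanes selected by the mask are taken from the value being stored, while the other lanes keep whatever is already in memory; only the synchronization differs (a memref.generic_atomic_rmw region versus a plain load/store pair). A scalar model of that merge, given purely as an illustrative sketch and not code from this commit:

  #include <cstdint>

  // Scalar model of the masked sub-byte merge: `laneMask` has every bit of a
  // lane set when that lane comes from the stored value, and clear when the
  // lane must be preserved from the byte currently in memory.
  constexpr uint8_t mergeSubByteStore(uint8_t currentByte, uint8_t newBits,
                                      uint8_t laneMask) {
    return static_cast<uint8_t>((currentByte & ~laneMask) |
                                (newBits & laneMask));
  }

  // Storing the i2 value 0b11 into the lowest lane of a byte whose i2 lanes
  // are 10|11|01|00 (MSB to LSB) leaves the other three lanes untouched.
  static_assert(mergeSubByteStore(0b10110100, 0b00000011, 0b00000011) ==
                    0b10110111,
                "");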
@@ -315,6 +383,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -326,7 +398,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -335,7 +408,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -350,32 +423,154 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
+      return failure();
+    }
+
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when subbyte store at the front is not needed:
+    // 1. The source vector size is multiple of byte size
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
+    // The index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. Partial width store for the first byte, when the store address is not
+    // aligned to emulated width boundary, deal with the unaligned part so that
+    // the rest elements are aligned to width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
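To make the three steps above concrete, here is the index arithmetic for one hypothetical case (the numbers are illustrative assumptions, not taken from the commit or its tests): a vector<7xi2> stored into an i8 memref where the linearized address leaves 3 sub-byte elements of front padding.

  #include <cstdint>

  // Worked example of the arithmetic in steps 1-3 of ConvertVectorStore;
  // the concrete numbers (i2 elements, a 7-element vector, 3 elements of
  // front padding) are illustrative assumptions.
  constexpr int64_t srcBits = 2, dstBits = 8;
  constexpr int64_t numSrcElemsPerDest = dstBits / srcBits; // 4 x i2 per i8
  constexpr int64_t origElements = 7, frontPadElems = 3;
  // 1. Partial first byte: (4 - 3) % 4 == 1 element is masked into byte 0.
  constexpr int64_t frontElems =
      (numSrcElemsPerDest - frontPadElems) % numSrcElemsPerDest;
  // 2. Full-width part: (7 - 1) / 4 == 1 whole byte holding 4 elements.
  constexpr int64_t fullWidthStores =
      (origElements - frontElems) / numSrcElemsPerDest;
  // 3. Trailing partial byte: 7 - 1 - 4 == 2 elements are masked into byte 2.
  constexpr int64_t tailElems =
      origElements - frontElems - fullWidthStores * numSrcElemsPerDest;
  static_assert(frontElems == 1 && fullWidthStores == 1 && tailElems == 2, "");

In this case only the first and last bytes go through the masked subEmulatedWidthStore path; the middle byte is written with a plain full-width vector.store.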
@@ -581,9 +776,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -742,9 +936,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -827,9 +1020,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -1574,12 +1766,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
