Commit bfdbed2

Implement vector stores

1 parent 36fa8bd

2 files changed: +313 additions, -33 deletions

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 178 additions & 33 deletions
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -143,19 +144,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  auto vectorType = llvm::cast<VectorType>(source.getType());
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
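
For illustration, the helper now derives its result type from `subvecSize` instead of taking an explicit `extractType`; a call like `staticallyExtractSubvector(rewriter, loc, src, /*frontOffset=*/2, /*subvecSize=*/4)` on a `vector<8xi4>` source would emit IR along these lines (a hand-written sketch with hypothetical SSA names, not pass output):

    %sub = vector.extract_strided_slice %src
        {offsets = [2], sizes = [4], strides = [1]}
        : vector<8xi4> to vector<4xi4>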

@@ -164,12 +165,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
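
Its insert counterpart wraps `vector.insert_strided_slice`; writing that `vector<4xi4>` back into a `vector<8xi4>` destination at offset 1 would look roughly like this (again a sketch, not pass output):

    %res = vector.insert_strided_slice %sub, %dst
        {offsets = [1], strides = [1]}
        : vector<4xi4> into vector<8xi4>
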
@@ -236,6 +235,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                             newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = cast<VectorType>(value.getType());
+  auto memrefElemType = cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// Atomically stores a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // Wrap the current i8 in a vector<1xi8>, then bitcast to `scale` sub-bytes.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extracts a slice of a vector and inserts it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
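
To make the read-modify-write concrete: for an i4-in-i8 case (scale = 2), the region built by `atomicStore` would look roughly as follows, where `%val : vector<2xi4>` is the byte-sized slice produced by `extractSliceIntoByte` and `%mask : vector<2xi1>` selects the lanes to overwrite (hand-written sketch with hypothetical SSA names, not verbatim pass output). `nonAtomicStore` needs no region at all: it is a `vector.bitcast` to the byte element type followed by a plain `vector.store`.

    %res = memref.generic_atomic_rmw %mem[%idx] : memref<8xi8> {
    ^bb0(%old: i8):
      %old_vec = vector.from_elements %old : vector<1xi8>
      %old_i4s = vector.bitcast %old_vec : vector<1xi8> to vector<2xi4>
      %merged = arith.select %mask, %val, %old_i4s : vector<2xi1>, vector<2xi4>
      %new_vec = vector.bitcast %merged : vector<2xi4> to vector<1xi8>
      %new = vector.extract %new_vec[0] : i8 from vector<1xi8>
      memref.atomic_yield %new : i8
    }
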
@@ -256,7 +312,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -280,30 +337,121 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // Unimplemented case for dynamic front padding size.
+      return failure();
+    }
+
+    if (!isUnalignedEmulation) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Atomic store for the first byte.
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    frontAtomicStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Non-atomic store for the full bytes in the middle.
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. Atomic store for the last byte.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Back mask: enable only the first `remainingElements` lanes.
+      auto maskValues = llvm::SmallVector<bool>(scale, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
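
Putting the pieces together, the pattern can now handle unaligned stores it previously rejected with `failure()`. As a hypothetical example (a sketch; the in-tree tests may use different shapes), storing a `vector<5xi2>` at element offset 1 of a `memref<7xi2>` (emulated as `memref<2xi8>`, scale = 4) takes the front path for elements 0..2 (lanes 1..3 of byte 0, atomic RMW), skips the non-atomic middle store because (5 - 3) / 4 == 0 full bytes remain, and takes the back path for elements 3..4 (lanes 0..1 of byte 1, atomic RMW):

    func.func @store_unaligned(%v: vector<5xi2>) {
      %mem = memref.alloc() : memref<7xi2>
      %c1 = arith.constant 1 : index
      // Front byte mask: dense<[false, true, true, true]> : vector<4xi1>
      // Back byte mask:  dense<[true, true, false, false]> : vector<4xi1>
      vector.store %v, %mem[%c1] : memref<7xi2>, vector<5xi2>
      return
    }
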
@@ -511,9 +659,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -672,9 +819,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);

@@ -757,9 +903,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
