
Commit 9b81a3f

Add support to avoid atomic operations.
1 parent 1430e84 commit 9b81a3f

File tree

2 files changed, +97 -56 lines changed


mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 2 deletions
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);

 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
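
For illustration only (not part of this commit): a caller that can rule out concurrent writes to the same emulated byte would opt out of atomics when populating the patterns. The wrapper name below is hypothetical and unrelated includes are elided.

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
// (other includes elided; `NarrowTypeEmulationConverter` comes from the Arith
// dialect transforms headers)

using namespace mlir;

// Hypothetical helper: request the narrow-type emulation patterns with the
// non-atomic (load-modify-write) lowering for partially covered bytes.
static void addNarrowTypeEmulationNoAtomics(
    const arith::NarrowTypeEmulationConverter &converter,
    RewritePatternSet &patterns) {
  vector::populateVectorNarrowTypeEmulationPatterns(converter, patterns,
                                                    /*useAtomicWrites=*/false);
}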

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 94 additions & 54 deletions
@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {

-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");

   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                           newLoad);
 }

-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
+/// Atomically store a subbyte-sized value to memory, with a mask.
 static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+                         TypedValue<MemRefType> emulatedMemref,
+                         Value emulatedIndex, TypedValue<VectorType> value,
+                         Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -294,6 +283,27 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   return atomicOp;
 }

+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static Value rmwStore(OpBuilder &rewriter, Location loc,
+                      TypedValue<MemRefType> emulatedMemref,
+                      Value emulatedIndex, TypedValue<VectorType> value,
+                      Value mask, int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  return rewriter
+      .create<vector::StoreOp>(loc, toBitcast, emulatedMemref, emulatedIndex)
+      ->getResult(0);
+}
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
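
For illustration only (not from this patch): assuming i2 elements emulated on an i8 memref, so numSrcElemsPerDest is 4, rmwStore lowers one masked sub-byte store of `%value : vector<4xi2>` under `%mask : vector<4xi1>` to a plain read-modify-write of the touched byte, roughly:

// Illustrative IR sketch; SSA names and memref shape are placeholders.
%old     = vector.load %mem[%idx] : memref<16xi8>, vector<1xi8>
%oldSub  = vector.bitcast %old : vector<1xi8> to vector<4xi2>
%merged  = arith.select %mask, %value, %oldSub : vector<4xi1>, vector<4xi2>
%newByte = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
vector.store %newByte, %mem[%idx] : memref<16xi8>, vector<1xi8>

atomicStore performs the same masked merge, but inside the region of the memref.generic_atomic_rmw it creates above, so the read-modify-write of the shared byte happens as a single atomic operation.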
@@ -322,6 +332,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +357,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;

     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +373,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>

     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,21 +388,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));

-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;

-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }

-    // conditions when atomic stores and all that are not needed:
+    // Conditions when atomic stores and all that are not needed:
     // 1. The source vector size is multiple of byte size
     // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
           op.getValueToStore());
@@ -398,38 +412,41 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return llvm::success();
     }

-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+    // The index into the target memref we are storing to
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto atomicMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;

-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    // 1. Atomic store for the first byte
+    auto frontAtomicStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontAtomicStoreElem = origElements;
       } else {
         std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));

-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontAtomicStoreElem, *foldedNumFrontPadElems);

-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                   numSrcElemsPerDest);

       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
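
As a worked example (illustrative, not from the commit): with i2 elements emulated on an i8 memref, numSrcElemsPerDest is 4. Storing a vector<7xi2> whose linearized start falls one element into a byte gives *foldedNumFrontPadElems = 1, so frontAtomicStoreElem = (4 - 1) % 4 = 3: the masked store above writes source elements 0-2 into the three remaining element slots of the first byte, and currentSourceIndex advances to 3. The remaining four elements fill exactly one byte and are handled by the plain store in step 2 of the next hunk, leaving nothing for the masked tail store in step 3.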
@@ -440,44 +457,62 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }

-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    // 2. Non-atomic store
+    int64_t nonAtomicStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * numSrcElemsPerDest;
     if (nonAtomicStoreSize != 0) {
       auto nonAtomicStorePart = staticallyExtractSubvector(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
           currentSourceIndex, numNonAtomicElements);

-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
+      auto originType = dyn_cast<VectorType>(nonAtomicStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        nonAtomicStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);

       currentSourceIndex += numNonAtomicElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex,
           rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
     }

-    // 3. atomic store for the last byte
+    // 3. Atomic store for the last byte
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
       auto atomicStorePart = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
           currentSourceIndex, remainingElements, 0);

-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(atomicMaskType, maskValues));

-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(atomicStorePart),
+                   backMask.getResult(), numSrcElemsPerDest);
     }

     rewriter.eraseOp(op);
     return success();
   }
+
+  template <typename... Args>
+  Value subByteStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    return storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };

 //===----------------------------------------------------------------------===//
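
The subByteStore helper added above relies on atomicStore and rmwStore now sharing an identical signature (which is why atomicStore's parameters were retyped in this patch): decltype(atomicStore) names that common function type, so a std::function of it can hold either helper and be chosen at rewrite time. A minimal standalone sketch of the idiom, using placeholder functions rather than the MLIR helpers:

#include <functional>
#include <utility>

// Two helpers with the same signature, standing in for atomicStore/rmwStore.
static int mergeAtomic(int byte, int update) { return byte | update; }
static int mergeRMW(int byte, int update) { return byte | update; }

// decltype(mergeAtomic) is the function type int(int, int), so std::function
// can wrap whichever helper is selected at runtime, just as subByteStore does
// with useAtomicWrites_.
template <typename... Args>
static int dispatchMerge(bool useAtomic, Args &&...args) {
  std::function<decltype(mergeAtomic)> fn = useAtomic ? mergeAtomic : mergeRMW;
  return fn(std::forward<Args>(args)...);
}

int main() {
  // Pick the non-atomic path at runtime, mirroring useAtomicWrites = false.
  return dispatchMerge(/*useAtomic=*/false, /*byte=*/0, /*update=*/1) == 1 ? 0 : 1;
}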
@@ -1673,12 +1708,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {

-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }

 void vector::populateVectorNarrowTypeRewritePatterns(
