
Commit aec91d0

Add support to avoid atomic operations.
Parent: 1430e84


2 files changed: 122 additions & 71 deletions

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 2 deletions
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
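
As a usage note (not part of this commit): a caller that knows the emulated stores cannot race with other threads can now opt out of the atomic path through the new parameter. A minimal sketch, assuming the `arith::NarrowTypeEmulationConverter(targetBitwidth)` constructor and a hypothetical `ctx` and `loadStoreBitwidth` already in scope:

    // Hypothetical pass setup; only the last argument is new in this commit.
    arith::NarrowTypeEmulationConverter typeConverter(loadStoreBitwidth);
    RewritePatternSet patterns(ctx);
    vector::populateVectorNarrowTypeEmulationPatterns(
        typeConverter, patterns, /*useAtomicWrites=*/false);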

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 119 additions & 69 deletions
@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
+/// Atomically store a subbyte-sized value to memory, with a mask.
 static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+                         TypedValue<MemRefType> emulatedMemref,
+                         Value emulatedIndex, TypedValue<VectorType> value,
+                         Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
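
`atomicStore` wraps the masked merge of new sub-byte elements into the containing emulated element inside a `memref.generic_atomic_rmw` region, so concurrent writers of neighboring lanes cannot lose each other's updates. As a rough scalar analogue only (not the generated IR; the single-byte container and the `laneMask` bit-mask encoding are assumptions for illustration), the atomic path behaves like a retry loop on the containing byte:

    #include <atomic>
    #include <cstdint>

    // Merge `newBits` into the lanes selected by `laneMask`, retrying until the
    // byte is swapped in without interference from other writers.
    void atomicMaskedRmwByte(std::atomic<uint8_t> &byte, uint8_t newBits,
                             uint8_t laneMask) {
      uint8_t old = byte.load();
      uint8_t merged;
      do {
        merged = (old & ~laneMask) | (newBits & laneMask);
      } while (!byte.compare_exchange_weak(old, merged));
    }
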
@@ -294,6 +283,30 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static Value rmwStore(OpBuilder &rewriter, Location loc,
+                      TypedValue<MemRefType> emulatedMemref,
+                      Value emulatedIndex, TypedValue<VectorType> value,
+                      Value mask, int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  return rewriter
+      .create<vector::StoreOp>(loc, toBitcast, emulatedMemref, emulatedIndex)
+      ->getResult(0);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
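
By contrast, the new `rmwStore` emits the same merge as a plain `vector.load`, `arith.select`, `vector.store` sequence, which is only safe when no other thread can write the same emulated element concurrently. A scalar sketch of the value it computes, under the same assumed byte and bit-mask encoding as above:

    #include <cstdint>

    // Non-atomic read-modify-write: keep the unselected lanes, take the
    // selected lanes from the new value.
    void maskedRmwByte(uint8_t *byte, uint8_t newBits, uint8_t laneMask) {
      uint8_t old = *byte;                                        // vector.load
      uint8_t merged = (old & ~laneMask) | (newBits & laneMask);  // arith.select
      *byte = merged;                                             // vector.store
    }
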
@@ -322,6 +335,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +360,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +376,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
    auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,62 +391,68 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         stridedMetadata.getConstifiedMixedStrides(),
         getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
    }
 
-    // conditions when atomic stores and all that are not needed:
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when subbyte store at the front is not needed:
     // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
          op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(),
+          op, bitCast.getResult(), emulatedMemref,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    // The index into the target memref we are storing to
     Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
-    if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+    // 1. Partial width store for the first byte, when the store address is not
+    // aligned to emulated width boundary, deal with the unaligned part so that
+    // the rest elements are aligned to width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
-        frontAtomicStoreElem = origElements;
+        frontSubWidthStoreElem = origElements;
       } else {
-        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -440,44 +463,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
    }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
-    if (nonAtomicStoreSize != 0) {
-      auto nonAtomicStorePart = staticallyExtractSubvector(
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
-          currentSourceIndex, numNonAtomicElements);
-
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
-
-      currentSourceIndex += numNonAtomicElements;
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
-          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }
 
-    // 3. atomic store for the last byte
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
-      auto atomicStorePart = extractSliceIntoByte(
+      auto subWidthStorePart = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
    }
 
    rewriter.eraseOp(op);
    return success();
  }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  Value subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    return storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
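
To make the three-step decomposition in `matchAndRewrite` concrete, here is a small self-contained sketch of the index arithmetic; the shapes (i2 elements packed four per emulated i8 element, 9 elements stored, 1 element of front padding) are an assumed example, not taken from the commit:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t numSrcElemsPerDest = 4;     // four i2 elements per i8 byte
      const int64_t origElements = 9;           // elements in the stored vector
      const int64_t foldedNumFrontPadElems = 1; // store starts mid-byte

      // 1. Partial-width store into the first (unaligned) byte.
      int64_t frontSubWidthStoreElem =
          (numSrcElemsPerDest - foldedNumFrontPadElems) % numSrcElemsPerDest; // 3
      int64_t currentSourceIndex = numSrcElemsPerDest - foldedNumFrontPadElems; // 3

      // 2. Full-width stores of whole emulated bytes.
      int64_t fullWidthStoreSize =
          (origElements - currentSourceIndex) / numSrcElemsPerDest;  // 1 byte
      currentSourceIndex += fullWidthStoreSize * numSrcElemsPerDest; // now 7

      // 3. Partial-width store of the trailing elements.
      int64_t remainingElements = origElements - currentSourceIndex; // 2

      assert(frontSubWidthStoreElem + fullWidthStoreSize * numSrcElemsPerDest +
                 remainingElements ==
             origElements);
      return 0;
    }
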
@@ -1673,12 +1718,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
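
The `useAtomicWrites` flag registered here ends up in `ConvertVectorStore::subEmulatedWidthStore`, which picks between `atomicStore` and `rmwStore` through a `std::function<decltype(atomicStore)>`; the earlier `static_assert` guarantees both candidates share one function type. A tiny standalone sketch of that dispatch idiom, with placeholder functions rather than the ones in this patch:

    #include <functional>
    #include <utility>

    static int addOp(int a, int b) { return a + b; }
    static int mulOp(int a, int b) { return a * b; }

    // decltype(addOp) is the function type int(int, int), so a single
    // std::function<decltype(addOp)> can hold either candidate.
    template <typename... Args>
    int dispatch(bool useAdd, Args &&...args) {
      std::function<decltype(addOp)> fn = useAdd ? addOp : mulOp;
      return fn(std::forward<Args>(args)...);
    }

    // dispatch(true, 2, 3) == 5; dispatch(false, 2, 3) == 6.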
