Commit b46f01e

Add support to avoid atomic operations.
1 parent 1430e84 commit b46f01e

File tree

5 files changed, +235 -76 lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 3 additions & 2 deletions
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
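
Usage sketch (not part of this commit): a downstream conversion pass could opt out of atomic writes through the new parameter roughly as follows. The pass boilerplate, header paths, kLoadStoreEmulateBitwidth, and the legality rule are illustrative assumptions; only the populateVectorNarrowTypeEmulationPatterns call reflects the signature added above.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper: emulate sub-byte vector loads/stores on wider storage,
// lowering partial stores to plain read-modify-write sequences.
static LogicalResult emulateNarrowTypesNoAtomics(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  constexpr unsigned kLoadStoreEmulateBitwidth = 8; // assumed target width
  arith::NarrowTypeEmulationConverter typeConverter(kLoadStoreEmulateBitwidth);

  RewritePatternSet patterns(ctx);
  // `useAtomicWrites = false`: the caller asserts there is no thread
  // contention on the bytes being partially written.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);

  ConversionTarget target(*ctx);
  // Legality sketch: treat an op as legal once its types are legal for the
  // converter (a real pass wires this up more precisely).
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });
  return applyPartialConversion(rootOp, target, std::move(patterns));
}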

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 119 additions & 71 deletions
@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                            newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
-static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -291,9 +280,31 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
   auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
   builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
-  return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
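
Aside (not part of the diff): the static_assert above pins atomicStore and rmwStore to one function type so a later helper (subEmulatedWidthStore, added further down) can select either at runtime through std::function<decltype(...)>. A minimal standalone sketch of that idiom, with toy functions standing in for the two store helpers:

#include <cstdio>
#include <functional>
#include <type_traits>
#include <utility>

// Toy stand-ins for atomicStore / rmwStore: same signature, different strategy.
static void storeAtomic(int addr, int value) {
  std::printf("atomic store of %d at %d\n", value, addr);
}
static void storeRMW(int addr, int value) {
  std::printf("rmw store of %d at %d\n", value, addr);
}

static_assert(std::is_same_v<decltype(storeAtomic), decltype(storeRMW)>,
              "both helpers must share one function type");

// Mirrors subEmulatedWidthStore: wrap whichever helper the flag selects in a
// std::function of the shared function type and forward the arguments.
template <typename... Args>
void dispatchStore(bool useAtomic, Args &&...args) {
  std::function<decltype(storeAtomic)> storeFunc =
      useAtomic ? storeAtomic : storeRMW;
  storeFunc(std::forward<Args>(args)...);
}

int main() {
  dispatchStore(true, /*addr=*/0, /*value=*/42); // atomic path
  dispatchStore(false, /*addr=*/1, /*value=*/7); // rmw path
}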
@@ -322,6 +333,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +358,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +374,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,62 +389,68 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         stridedMetadata.getConstifiedMixedStrides(),
         getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }
 
-    // conditions when atomic stores and all that are not needed:
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when subbyte store at the front is not needed:
     // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
          op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(),
+          op, bitCast.getResult(), emulatedMemref,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    // The index into the target memref we are storing to
     Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
-    if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+    // 1. Partial width store for the first byte, when the store address is not
+    // aligned to emulated width boundary, deal with the unaligned part so that
+    // the rest elements are aligned to width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
-        frontAtomicStoreElem = origElements;
+        frontSubWidthStoreElem = origElements;
       } else {
-        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -440,44 +461,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
-    if (nonAtomicStoreSize != 0) {
-      auto nonAtomicStorePart = staticallyExtractSubvector(
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
-          currentSourceIndex, numNonAtomicElements);
-
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
-
-      currentSourceIndex += numNonAtomicElements;
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
-          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }
 
-    // 3. atomic store for the last byte
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
-      auto atomicStorePart = extractSliceIntoByte(
+      auto subWidthStorePart = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
    }
 
    rewriter.eraseOp(op);
    return success();
  }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
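
Aside (not part of the diff): a worked example of the three-step split performed by the rewritten ConvertVectorStore, using the same formulas as above; the concrete sizes (7 sub-byte elements, 4 per emulated element, 2 elements of front padding) are chosen purely for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative sizes: store 7 sub-byte elements where 4 of them pack into
  // one emulated element, starting 2 elements past an emulated-width boundary.
  int64_t origElements = 7;
  int64_t numSrcElemsPerDest = 4;
  int64_t numFrontPadElems = 2;

  // 1. Partial-width store that fills up the first emulated element.
  int64_t frontSubWidthStoreElem =
      (numSrcElemsPerDest - numFrontPadElems) % numSrcElemsPerDest;   // 2
  int64_t currentSourceIndex = numSrcElemsPerDest - numFrontPadElems; // 2

  // 2. Full-width stores of whole emulated elements.
  int64_t fullWidthStoreSize =
      (origElements - currentSourceIndex) / numSrcElemsPerDest; // 1
  currentSourceIndex += fullWidthStoreSize * numSrcElemsPerDest; // 6

  // 3. Trailing partial-width store for whatever is left.
  int64_t remainingElements = origElements - currentSourceIndex; // 1

  std::printf("front=%lld full=%lld trailing=%lld\n",
              (long long)frontSubWidthStoreElem,
              (long long)fullWidthStoreSize,
              (long long)remainingElements);
  // Prints: front=2 full=1 trailing=1
}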
@@ -1673,12 +1716,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
