@@ -363,6 +363,29 @@ static void atomicStore(OpBuilder &builder, Location loc,
363
363
builder.create <memref::AtomicYieldOp>(loc, scalarMaskedValue);
364
364
}
365
365
366
+ // / Generate a non-atomic read-modify-write sequence for subbyte storing.
367
+ // / It has similar logic to `atomicStore`, but without the atomicity.
368
+ static void rmwStore (OpBuilder &builder, Location loc,
369
+ MemRefValue linearizedMemref, Value linearizedIndex,
370
+ VectorValue valueToStore, Value mask) {
371
+ assert (valueToStore.getType ().getRank () == 1 && " expected 1-D vector" );
372
+
373
+ // Load the original value from memory, and cast it to the original element
374
+ // type.
375
+ auto oneElemVecType =
376
+ VectorType::get ({1 }, linearizedMemref.getType ().getElementType ());
377
+ Value origVecValue = builder.create <vector::LoadOp>(
378
+ loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
379
+ origVecValue = builder.create <vector::BitCastOp>(loc, valueToStore.getType (),
380
+ origVecValue);
381
+
382
+ // Construct the final masked value and yield it.
383
+ Value maskedValue = selectAndCast (builder, loc, oneElemVecType, mask,
384
+ origVecValue, valueToStore);
385
+ builder.create <vector::StoreOp>(loc, maskedValue, linearizedMemref,
386
+ linearizedIndex);
387
+ }
388
+
366
389
// / Extract `sliceNumElements` from source `vector` at `extractOffset`,
367
390
// / and insert it into an empty vector at `insertOffset`.
368
391
// / Inputs:
@@ -405,6 +428,10 @@ namespace {
405
428
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
406
429
using OpConversionPattern::OpConversionPattern;
407
430
431
+ ConvertVectorStore (MLIRContext *context, bool useAtomicWrites)
432
+ : OpConversionPattern<vector::StoreOp>(context),
433
+ useAtomicWrites_ (useAtomicWrites) {}
434
+
408
435
LogicalResult
409
436
matchAndRewrite (vector::StoreOp op, OpAdaptor adaptor,
410
437
ConversionPatternRewriter &rewriter) const override {
@@ -611,13 +638,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
611
638
auto backMask = rewriter.create <arith::ConstantOp>(
612
639
loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
613
640
614
- atomicStore (rewriter, loc, memrefBase, currentDestIndex,
615
- cast<VectorValue>(subWidthStorePart), backMask.getResult ());
641
+ subEmulatedWidthStore (rewriter, loc, memrefBase, currentDestIndex,
642
+ cast<VectorValue>(subWidthStorePart),
643
+ backMask.getResult ());
616
644
}
617
645
618
646
rewriter.eraseOp (op);
619
647
return success ();
620
648
}
649
+
650
+ // / Store a subbyte-sized value to memory, with a mask. Depending on the
651
+ // / configuration, it could be an atomic store or a non-atomic RMW sequence.
652
+ template <typename ... Args>
653
+ void subEmulatedWidthStore (Args &&...args) const {
654
+ static_assert (
655
+ std::is_same_v<decltype (atomicStore), decltype (rmwStore)> &&
656
+ " `atomicStore` and `rmwStore` must have same signature, as per "
657
+ " the design to keep the code clean, which one to call is "
658
+ " determined by the `useAtomicWrites` flag." );
659
+ std::function<decltype (atomicStore)> storeFunc =
660
+ useAtomicWrites_ ? atomicStore : rmwStore;
661
+ storeFunc (std::forward<Args>(args)...);
662
+ }
663
+
664
+ private:
665
+ const bool useAtomicWrites_;
621
666
};
622
667
623
668
// ===----------------------------------------------------------------------===//
@@ -1930,12 +1975,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1930
1975
1931
1976
void vector::populateVectorNarrowTypeEmulationPatterns (
1932
1977
const arith::NarrowTypeEmulationConverter &typeConverter,
1933
- RewritePatternSet &patterns) {
1978
+ RewritePatternSet &patterns, bool useAtomicWrites ) {
1934
1979
1935
1980
// Populate `vector.*` conversion patterns.
1936
- patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1981
+ // TODO: #119553 support atomicity
1982
+ patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad,
1937
1983
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1938
1984
typeConverter, patterns.getContext ());
1985
+
1986
+ // Populate `vector.*` store conversion patterns. The caller can choose
1987
+ // to avoid emitting atomic operations and reduce it to load-modify-write
1988
+ // sequence for stores if it is known there are no thread contentions.
1989
+ patterns.insert <ConvertVectorStore>(patterns.getContext (), useAtomicWrites);
1939
1990
}
1940
1991
1941
1992
void vector::populateVectorNarrowTypeRewritePatterns (
0 commit comments