Skip to content

Commit 3e2e4b5

Browse files
committed
First commit
1 parent 654b763 commit 3e2e4b5

File tree

3 files changed

+66
-7
lines changed

3 files changed

+66
-7
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
364364
PatternBenefit benefit = 1);
365365

366366
/// Appends patterns for emulating vector operations over narrow types with ops
367-
/// over wider types.
367+
/// over wider types. The `useAtomicWrites` flag selects how sub-byte stores
368+
/// are emitted: with `memref.generic_atomic_rmw` (atomic) when true, or with a
369+
/// plain non-atomic read-modify-write sequence when false.
368370
void populateVectorNarrowTypeEmulationPatterns(
369371
const arith::NarrowTypeEmulationConverter &typeConverter,
370-
RewritePatternSet &patterns);
372+
RewritePatternSet &patterns, bool useAtomicWrites = true);
371373

372374
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
373375
/// vector operations comprising `shuffle` and `bitwise` ops.

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,29 @@ static void atomicStore(OpBuilder &builder, Location loc,
363363
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
364364
}
365365

366+
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
367+
/// It has similar logic to `atomicStore`, but without the atomicity.
368+
static void rmwStore(OpBuilder &builder, Location loc,
369+
MemRefValue linearizedMemref, Value linearizedIndex,
370+
VectorValue valueToStore, Value mask) {
371+
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
372+
373+
// Load the original value from memory, and cast it to the original element
374+
// type.
375+
auto oneElemVecType =
376+
VectorType::get({1}, linearizedMemref.getType().getElementType());
377+
Value origVecValue = builder.create<vector::LoadOp>(
378+
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
379+
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
380+
origVecValue);
381+
382+
// Construct the final masked value and yield it.
383+
Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
384+
origVecValue, valueToStore);
385+
builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
386+
linearizedIndex);
387+
}
388+
366389
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
367390
/// and insert it into an empty vector at `insertOffset`.
368391
/// Inputs:
@@ -405,6 +428,10 @@ namespace {
405428
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
406429
using OpConversionPattern::OpConversionPattern;
407430

431+
ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
432+
: OpConversionPattern<vector::StoreOp>(context),
433+
useAtomicWrites_(useAtomicWrites) {}
434+
408435
LogicalResult
409436
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
410437
ConversionPatternRewriter &rewriter) const override {
@@ -611,13 +638,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
611638
auto backMask = rewriter.create<arith::ConstantOp>(
612639
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
613640

614-
atomicStore(rewriter, loc, memrefBase, currentDestIndex,
615-
cast<VectorValue>(subWidthStorePart), backMask.getResult());
641+
subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
642+
cast<VectorValue>(subWidthStorePart),
643+
backMask.getResult());
616644
}
617645

618646
rewriter.eraseOp(op);
619647
return success();
620648
}
649+
650+
/// Store a subbyte-sized value to memory, with a mask. Depending on the
651+
/// configuration, it could be an atomic store or a non-atomic RMW sequence.
652+
template <typename... Args>
653+
void subEmulatedWidthStore(Args &&...args) const {
654+
static_assert(
655+
std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
656+
"`atomicStore` and `rmwStore` must have same signature, as per "
657+
"the design to keep the code clean, which one to call is "
658+
"determined by the `useAtomicWrites` flag.");
659+
std::function<decltype(atomicStore)> storeFunc =
660+
useAtomicWrites_ ? atomicStore : rmwStore;
661+
storeFunc(std::forward<Args>(args)...);
662+
}
663+
664+
private:
665+
const bool useAtomicWrites_;
621666
};
622667

623668
//===----------------------------------------------------------------------===//
@@ -1930,12 +1975,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
19301975

19311976
void vector::populateVectorNarrowTypeEmulationPatterns(
19321977
const arith::NarrowTypeEmulationConverter &typeConverter,
1933-
RewritePatternSet &patterns) {
1978+
RewritePatternSet &patterns, bool useAtomicWrites) {
19341979

19351980
// Populate `vector.*` conversion patterns.
1936-
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1981+
// TODO: #119553 support atomicity
1982+
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
19371983
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
19381984
typeConverter, patterns.getContext());
1985+
1986+
// Populate `vector.*` store conversion patterns. The caller can choose
1987+
// to avoid emitting atomic operations and instead lower stores to a plain
1988+
// load-modify-write sequence when it is known there is no thread contention.
1989+
patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
19391990
}
19401991

19411992
void vector::populateVectorNarrowTypeRewritePatterns(

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
9999

100100
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
101101
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
102-
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
102+
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
103+
atomicStore);
103104

104105
if (failed(applyPartialConversion(op, target, std::move(patterns))))
105106
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
118119
*this, "skip-memref-type-conversion",
119120
llvm::cl::desc("disable memref type conversion (to test failures)"),
120121
llvm::cl::init(false)};
122+
123+
Option<bool> atomicStore{
124+
*this, "atomic-store",
125+
llvm::cl::desc("use atomic store instead of load-modify-write"),
126+
llvm::cl::init(true)};
121127
};
122128
} // namespace
123129

0 commit comments

Comments
 (0)