Skip to content

Commit fd2b089

Browse files
[mlir][Vector] Lowering of transfer_read/write to vector.load/store
This patch introduces progressive lowering patterns for rewriting vector.transfer_read/write to vector.load/store and vector.broadcast in certain supported cases. Reviewed By: dcaballe, nicolasvasilache Differential Revision: https://reviews.llvm.org/D97822
1 parent 5eaeb0f commit fd2b089

File tree

6 files changed

+383
-0
lines changed

6 files changed

+383
-0
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
8585
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
8686
MLIRContext *context);
8787

88+
/// Collect a set of transfer read/write lowering patterns.
89+
///
90+
/// These patterns lower transfer ops to simpler ops like `vector.load`,
91+
/// `vector.store` and `vector.broadcast`.
92+
void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
93+
MLIRContext *context);
94+
8895
/// An attribute that specifies the combining function for `vector.contract`,
8996
/// and `vector.reduction`.
9097
class CombiningKindAttr

mlir/include/mlir/IR/AffineMap.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ class AffineMap {
104104
/// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
105105
bool isMinorIdentity() const;
106106

107+
/// Returns true if this affine map is a minor identity up to broadcasted
108+
/// dimensions which are indicated by value 0 in the result. If
109+
/// `broadcastedDims` is not null, it will be populated with the indices of
110+
/// the broadcasted dimensions in the result array.
111+
/// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
112+
/// (`broadcastedDims` will contain [0, 2])
113+
bool isMinorIdentityWithBroadcasting(
114+
SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;
115+
107116
/// Returns true if this affine map is an empty map, i.e., () -> ().
108117
bool isEmpty() const;
109118

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "mlir/IR/Types.h"
3838
#include "mlir/Interfaces/VectorInterfaces.h"
3939

40+
#include "llvm/ADT/STLExtras.h"
4041
#include "llvm/Support/CommandLine.h"
4142
#include "llvm/Support/Debug.h"
4243
#include "llvm/Support/raw_ostream.h"
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
27292730
}
27302731
};
27312732

2733+
/// Progressive lowering of transfer_read. This pattern supports lowering of
2734+
/// `vector.transfer_read` to a combination of `vector.load` and
2735+
/// `vector.broadcast` if all of the following hold:
2736+
/// - The op reads from a memref with the default layout.
2737+
/// - Masking is not required.
2738+
/// - If the memref's element type is a vector type then it coincides with the
2739+
/// result type.
2740+
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
2741+
struct TransferReadToVectorLoadLowering
2742+
: public OpRewritePattern<vector::TransferReadOp> {
2743+
TransferReadToVectorLoadLowering(MLIRContext *context)
2744+
: OpRewritePattern<vector::TransferReadOp>(context) {}
2745+
LogicalResult matchAndRewrite(vector::TransferReadOp read,
2746+
PatternRewriter &rewriter) const override {
2747+
SmallVector<unsigned, 4> broadcastedDims;
2748+
// TODO: Support permutations.
2749+
if (!read.permutation_map().isMinorIdentityWithBroadcasting(
2750+
&broadcastedDims))
2751+
return failure();
2752+
auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
2753+
if (!memRefType)
2754+
return failure();
2755+
2756+
// If there is broadcasting involved then we first load the unbroadcasted
2757+
// vector, and then broadcast it with `vector.broadcast`.
2758+
ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
2759+
SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
2760+
vectorShape.end());
2761+
for (unsigned i : broadcastedDims)
2762+
unbroadcastedVectorShape[i] = 1;
2763+
VectorType unbroadcastedVectorType = VectorType::get(
2764+
unbroadcastedVectorShape, read.getVectorType().getElementType());
2765+
2766+
// `vector.load` supports vector types as memref's elements only when the
2767+
// resulting vector type is the same as the element type.
2768+
if (memRefType.getElementType().isa<VectorType>() &&
2769+
memRefType.getElementType() != unbroadcastedVectorType)
2770+
return failure();
2771+
// Only the default layout is supported by `vector.load`.
2772+
// TODO: Support non-default layouts.
2773+
if (!memRefType.getAffineMaps().empty())
2774+
return failure();
2775+
// TODO: When masking is required, we can create a MaskedLoadOp
2776+
if (read.hasMaskedDim())
2777+
return failure();
2778+
2779+
Operation *loadOp;
2780+
if (!broadcastedDims.empty() &&
2781+
unbroadcastedVectorType.getNumElements() == 1) {
2782+
// If broadcasting is required and the number of loaded elements is 1 then
2783+
// we can create `std.load` instead of `vector.load`.
2784+
loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
2785+
read.indices());
2786+
} else {
2787+
// Otherwise create `vector.load`.
2788+
loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
2789+
unbroadcastedVectorType,
2790+
read.source(), read.indices());
2791+
}
2792+
2793+
// Insert a broadcasting op if required.
2794+
if (!broadcastedDims.empty()) {
2795+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2796+
read, read.getVectorType(), loadOp->getResult(0));
2797+
} else {
2798+
rewriter.replaceOp(read, loadOp->getResult(0));
2799+
}
2800+
2801+
return success();
2802+
}
2803+
};
2804+
2805+
/// Progressive lowering of transfer_write. This pattern supports lowering of
2806+
/// `vector.transfer_write` to `vector.store` if all of the following hold:
2807+
/// - The op writes to a memref with the default layout.
2808+
/// - Masking is not required.
2809+
/// - If the memref's element type is a vector type then it coincides with the
2810+
/// type of the written value.
2811+
/// - The permutation map is the minor identity map (neither permutation nor
2812+
/// broadcasting is allowed).
2813+
struct TransferWriteToVectorStoreLowering
2814+
: public OpRewritePattern<vector::TransferWriteOp> {
2815+
TransferWriteToVectorStoreLowering(MLIRContext *context)
2816+
: OpRewritePattern<vector::TransferWriteOp>(context) {}
2817+
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2818+
PatternRewriter &rewriter) const override {
2819+
// TODO: Support non-minor-identity maps
2820+
if (!write.permutation_map().isMinorIdentity())
2821+
return failure();
2822+
auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
2823+
if (!memRefType)
2824+
return failure();
2825+
// `vector.store` supports vector types as memref's elements only when the
2826+
// type of the vector value being written is the same as the element type.
2827+
if (memRefType.getElementType().isa<VectorType>() &&
2828+
memRefType.getElementType() != write.getVectorType())
2829+
return failure();
2830+
// Only the default layout is supported by `vector.store`.
2831+
// TODO: Support non-default layouts.
2832+
if (!memRefType.getAffineMaps().empty())
2833+
return failure();
2834+
// TODO: When masking is required, we can create a MaskedStoreOp
2835+
if (write.hasMaskedDim())
2836+
return failure();
2837+
rewriter.replaceOpWithNewOp<vector::StoreOp>(
2838+
write, write.vector(), write.source(), write.indices());
2839+
return success();
2840+
}
2841+
};
2842+
27322843
// Trims leading one dimensions from `oldType` and returns the result type.
27332844
// Returns `vector<1xT>` if `oldType` only has one element.
27342845
static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
32013312
ContractionOpToOuterProductOpLowering>(parameters, context);
32023313
// clang-format on
32033314
}
3315+
3316+
void mlir::vector::populateVectorTransferLoweringPatterns(
3317+
OwningRewritePatternList &patterns, MLIRContext *context) {
3318+
patterns.insert<TransferReadToVectorLoadLowering,
3319+
TransferWriteToVectorStoreLowering>(context);
3320+
}

mlir/lib/IR/AffineMap.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,35 @@ bool AffineMap::isMinorIdentity() const {
110110
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
111111
}
112112

113+
/// Returns true if this affine map is a minor identity up to broadcasted
114+
/// dimensions which are indicated by value 0 in the result.
115+
bool AffineMap::isMinorIdentityWithBroadcasting(
116+
SmallVectorImpl<unsigned> *broadcastedDims) const {
117+
if (broadcastedDims)
118+
broadcastedDims->clear();
119+
if (getNumDims() < getNumResults())
120+
return false;
121+
unsigned suffixStart = getNumDims() - getNumResults();
122+
for (auto idxAndExpr : llvm::enumerate(getResults())) {
123+
unsigned resIdx = idxAndExpr.index();
124+
AffineExpr expr = idxAndExpr.value();
125+
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
126+
// Each result may be either a constant 0 (broadcasted dimension).
127+
if (constExpr.getValue() != 0)
128+
return false;
129+
if (broadcastedDims)
130+
broadcastedDims->push_back(resIdx);
131+
} else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
132+
// Or it may be the input dimension corresponding to this result position.
133+
if (dimExpr.getPosition() != suffixStart + resIdx)
134+
return false;
135+
} else {
136+
return false;
137+
}
138+
}
139+
return true;
140+
}
141+
113142
/// Returns an AffineMap representing a permutation.
114143
AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
115144
MLIRContext *context) {

0 commit comments

Comments
 (0)