15
15
#include " mlir/IR/TypeUtilities.h"
16
16
#include " mlir/Pass/Pass.h"
17
17
#include " mlir/Support/LogicalResult.h"
18
- #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
18
+ #include " mlir/Transforms/WalkPatternRewriteDriver .h"
19
19
20
20
namespace mlir {
21
21
#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
@@ -36,17 +36,16 @@ using namespace mlir;
36
36
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
37
37
/// Note: these conditions mostly come from TransferReadToVectorLoadLowering
38
38
/// pass.
39
- static LogicalResult
40
- transferPreconditions (PatternRewriter &rewriter,
41
- VectorTransferOpInterface xferOp,
42
- SmallVector<unsigned > &broadcastedDims,
43
- VectorType &unbroadcastedVectorType) {
39
+ static LogicalResult transferPreconditions (
40
+ PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
41
+ bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
44
42
if (!xferOp.getMask ())
45
43
return rewriter.notifyMatchFailure (xferOp, " Only support masked transfer" );
46
44
47
45
// Permutations are handled by VectorToSCF or
48
46
// populateVectorTransferPermutationMapLoweringPatterns.
49
47
// We let the 0-d corner case pass-through as it is supported.
48
+ SmallVector<unsigned > broadcastedDims;
50
49
if (!xferOp.getPermutationMap ().isMinorIdentityWithBroadcasting (
51
50
&broadcastedDims))
52
51
return rewriter.notifyMatchFailure (xferOp, " not minor identity + bcast" );
@@ -56,9 +55,8 @@ transferPreconditions(PatternRewriter &rewriter,
56
55
return rewriter.notifyMatchFailure (xferOp, " not a memref source" );
57
56
58
57
Attribute addrSpace = memRefType.getMemorySpace ();
59
- if (!addrSpace ||
60
- llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue () !=
61
- amdgpu::AddressSpace::FatRawBuffer)
58
+ if (!addrSpace || dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue () !=
59
+ amdgpu::AddressSpace::FatRawBuffer)
62
60
return rewriter.notifyMatchFailure (xferOp, " not in buffer address space" );
63
61
64
62
// Non-unit strides are handled by VectorToSCF.
@@ -73,6 +71,7 @@ transferPreconditions(PatternRewriter &rewriter,
73
71
unbroadcastedVectorShape[i] = 1 ;
74
72
unbroadcastedVectorType = xferOp.getVectorType ().cloneWith (
75
73
unbroadcastedVectorShape, xferOp.getVectorType ().getElementType ());
74
+ requiresBroadcasting = !broadcastedDims.empty ();
76
75
77
76
// `vector.load` supports vector types as memref's elements only when the
78
77
// resulting vector type is the same as the element type.
@@ -98,31 +97,31 @@ transferPreconditions(PatternRewriter &rewriter,
98
97
return success ();
99
98
}
100
99
101
- struct TransferReadLowering : public OpRewritePattern <vector::TransferReadOp> {
102
- using OpRewritePattern<vector::TransferReadOp> ::OpRewritePattern;
100
+ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
101
+ using OpRewritePattern::OpRewritePattern;
103
102
104
103
LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
105
104
PatternRewriter &rewriter) const override {
106
105
107
- SmallVector< unsigned > broadcastedDims ;
106
+ bool requiresBroadcasting = false ;
108
107
VectorType unbroadcastedVectorType;
109
- if (failed (transferPreconditions (rewriter, readOp, broadcastedDims ,
108
+ if (failed (transferPreconditions (rewriter, readOp, requiresBroadcasting ,
110
109
unbroadcastedVectorType))) {
111
110
return failure ();
112
111
}
113
112
114
- Value fill = rewriter.create <vector::SplatOp>(
115
- readOp.getLoc (), unbroadcastedVectorType, readOp.getPadding ());
113
+ Location loc = readOp.getLoc ();
114
+ Value fill = rewriter.create <vector::SplatOp>(loc, unbroadcastedVectorType,
115
+ readOp.getPadding ());
116
116
Value load = rewriter.create <vector::LoadOp>(
117
- readOp.getLoc (), unbroadcastedVectorType, readOp.getSource (),
118
- readOp.getIndices ());
119
- Value res = rewriter.create <arith::SelectOp>(
120
- readOp.getLoc (), unbroadcastedVectorType, readOp.getMask (), load, fill);
117
+ loc, unbroadcastedVectorType, readOp.getSource (), readOp.getIndices ());
118
+ Value res = rewriter.create <arith::SelectOp>(loc, unbroadcastedVectorType,
119
+ readOp.getMask (), load, fill);
121
120
122
121
// Insert a broadcasting op if required.
123
- if (!broadcastedDims. empty () ) {
124
- res = rewriter.create <vector::BroadcastOp>(readOp.getLoc (),
125
- readOp. getVectorType (), res);
122
+ if (requiresBroadcasting ) {
123
+ res = rewriter.create <vector::BroadcastOp>(loc, readOp.getVectorType (),
124
+ res);
126
125
}
127
126
128
127
rewriter.replaceOp (readOp, res);
@@ -136,12 +135,11 @@ void mlir::populateVectorToAMDGPUConversionPatterns(
136
135
patterns.add <TransferReadLowering>(patterns.getContext ());
137
136
}
138
137
139
// Pass wrapper that applies the Vector-to-AMDGPU conversion patterns to the
// operation the pass runs on. In this diff the struct gains `final` and the
// pattern driver call is swapped (see the removed/added lines below).
- struct ConvertVectorToAMDGPUPass
138
+ struct ConvertVectorToAMDGPUPass final
140
139
: public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
141
140
// Pass entry point: builds the pattern set and hands it to the driver.
void runOnOperation () override {
142
141
// Pattern set bound to this pass's MLIRContext.
RewritePatternSet patterns (&getContext ());
143
142
// Fill the set via populateVectorToAMDGPUConversionPatterns.
populateVectorToAMDGPUConversionPatterns (patterns);
144
// Removed version: ran applyPatternsGreedily and called signalPassFailure()
// when it reported failure.
- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
145
- return signalPassFailure ();
143
// Added version: the result of walkAndApplyPatterns is not inspected here, so
// this code path never calls signalPassFailure().
// NOTE(review): presumably the walk driver visits ops once instead of
// iterating to a fixpoint - confirm against WalkPatternRewriteDriver.h.
+ walkAndApplyPatterns (getOperation (), std::move (patterns));
146
144
}
147
145
};
0 commit comments