Skip to content

Commit 4610b01

Browse files
committed
Addressing review feedback
1 parent 7f6d6ef commit 4610b01

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "mlir/IR/TypeUtilities.h"
1616
#include "mlir/Pass/Pass.h"
1717
#include "mlir/Support/LogicalResult.h"
18-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1919

2020
namespace mlir {
2121
#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
@@ -36,17 +36,16 @@ using namespace mlir;
3636
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
3737
/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
3838
/// pass.
39-
static LogicalResult
40-
transferPreconditions(PatternRewriter &rewriter,
41-
VectorTransferOpInterface xferOp,
42-
SmallVector<unsigned> &broadcastedDims,
43-
VectorType &unbroadcastedVectorType) {
39+
static LogicalResult transferPreconditions(
40+
PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
41+
bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
4442
if (!xferOp.getMask())
4543
return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
4644

4745
// Permutations are handled by VectorToSCF or
4846
// populateVectorTransferPermutationMapLoweringPatterns.
4947
// We let the 0-d corner case pass-through as it is supported.
48+
SmallVector<unsigned> broadcastedDims;
5049
if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
5150
&broadcastedDims))
5251
return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
@@ -56,9 +55,8 @@ transferPreconditions(PatternRewriter &rewriter,
5655
return rewriter.notifyMatchFailure(xferOp, "not a memref source");
5756

5857
Attribute addrSpace = memRefType.getMemorySpace();
59-
if (!addrSpace ||
60-
llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
61-
amdgpu::AddressSpace::FatRawBuffer)
58+
if (!addrSpace || dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
59+
amdgpu::AddressSpace::FatRawBuffer)
6260
return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
6361

6462
// Non-unit strides are handled by VectorToSCF.
@@ -73,6 +71,7 @@ transferPreconditions(PatternRewriter &rewriter,
7371
unbroadcastedVectorShape[i] = 1;
7472
unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
7573
unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
74+
requiresBroadcasting = !broadcastedDims.empty();
7675

7776
// `vector.load` supports vector types as memref's elements only when the
7877
// resulting vector type is the same as the element type.
@@ -98,31 +97,31 @@ transferPreconditions(PatternRewriter &rewriter,
9897
return success();
9998
}
10099

101-
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
102-
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
100+
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
101+
using OpRewritePattern::OpRewritePattern;
103102

104103
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
105104
PatternRewriter &rewriter) const override {
106105

107-
SmallVector<unsigned> broadcastedDims;
106+
bool requiresBroadcasting = false;
108107
VectorType unbroadcastedVectorType;
109-
if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
108+
if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
110109
unbroadcastedVectorType))) {
111110
return failure();
112111
}
113112

114-
Value fill = rewriter.create<vector::SplatOp>(
115-
readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
113+
Location loc = readOp.getLoc();
114+
Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
115+
readOp.getPadding());
116116
Value load = rewriter.create<vector::LoadOp>(
117-
readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
118-
readOp.getIndices());
119-
Value res = rewriter.create<arith::SelectOp>(
120-
readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
117+
loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
118+
Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
119+
readOp.getMask(), load, fill);
121120

122121
// Insert a broadcasting op if required.
123-
if (!broadcastedDims.empty()) {
124-
res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
125-
readOp.getVectorType(), res);
122+
if (requiresBroadcasting) {
123+
res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
124+
res);
126125
}
127126

128127
rewriter.replaceOp(readOp, res);
@@ -136,12 +135,11 @@ void mlir::populateVectorToAMDGPUConversionPatterns(
136135
patterns.add<TransferReadLowering>(patterns.getContext());
137136
}
138137

139-
struct ConvertVectorToAMDGPUPass
138+
struct ConvertVectorToAMDGPUPass final
140139
: public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
141140
void runOnOperation() override {
142141
RewritePatternSet patterns(&getContext());
143142
populateVectorToAMDGPUConversionPatterns(patterns);
144-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
145-
return signalPassFailure();
143+
walkAndApplyPatterns(getOperation(), std::move(patterns));
146144
}
147145
};

mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --convert-vector-to-amdgpu --split-input-file | FileCheck %s
22

33
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
44
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
@@ -9,9 +9,10 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
99
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
1010
return %res : vector<4xf32>
1111
}
12-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
12+
// CHECK: %[[CST:.*]] = arith.constant 0.0
13+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
1314
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
14-
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
15+
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
1516
// CHECK: return %[[SELECT]] : vector<4xf32>
1617

1718
// -----
@@ -43,9 +44,10 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
4344
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
4445
return %res : vector<4xf32>
4546
}
46-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
47+
// CHECK: %[[CST:.*]] = arith.constant 0.0
48+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
4749
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
48-
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
50+
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
4951
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
5052
// CHECK: return %[[BROADCAST]] : vector<4xf32>
5153

@@ -62,7 +64,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
6264
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
6365
return %res : vector<1xf32>
6466
}
65-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
67+
// CHECK: %[[CST:.*]] = arith.constant 0.0
68+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
6669
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
67-
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
70+
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
6871
// CHECK: return %[[SELECT]] : vector<1xf32>

0 commit comments

Comments (0)