Commit 2b983a2

[MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (#135014)
Motivation: the amdgpu buffer load instruction returns all zeros when loading sub-word values that cross a word boundary at the edge of the buffer. For example, assuming the buffer size is exactly one word and we invoke `llvm.amdgcn.raw.ptr.buffer.load.v2i32` starting from byte 2 of that word, we do not receive the actual contents of the buffer but all zeros for the first word, because that access crosses the boundary of the first word. This PR fixes the problem by creating a bounds check around the buffer load: it compares the offset plus the vector size against the buffer size to determine whether the upper bound of the access exceeds the buffer. If the access stays within bounds (or remains word aligned), the masked transfer read is optimized to `vector.load` + `arith.select`; otherwise it falls back to the default lowering of the masked vector load.
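To make the shape of the rewrite concrete, here is a simplified sketch distilled from the tests added below; the value names and the exact guard arithmetic are illustrative, not the literal IR the pass emits. A masked transfer_read from a fat_raw_buffer memref is wrapped in an scf.if: the slow path keeps the original masked read, the fast path uses a plain vector.load plus arith.select.

// Input (sketch): a masked read from a fat_raw_buffer memref.
%cf0 = arith.constant 0.0 : f16
%res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]}
    : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>

// After the pass (sketch): %is_oob and %not_word_aligned stand in for the
// bounds/alignment arithmetic the pass materializes.
%cond = arith.andi %is_oob, %not_word_aligned : i1
%guarded = scf.if %cond -> (vector<4xf16>) {
  // Fallback: keep the masked transfer_read (tagged so the pattern skips it);
  // it later takes the default masked-load lowering.
  %r = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask
      {amdgpu.buffer_transfer_read_needs_mask, in_bounds = [true]}
      : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
  scf.yield %r : vector<4xf16>
} else {
  // Fast path: plain load, then select the padding value where the mask is off.
  %pad = vector.splat %cf0 : vector<4xf16>
  %ld = vector.load %mem[%idx0, %idx1]
      : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
  %sel = arith.select %mask, %ld, %pad : vector<4xi1>, vector<4xf16>
  scf.yield %sel : vector<4xf16>
}

The `amdgpu.buffer_transfer_read_needs_mask` unit attribute is how the pattern marks the cloned read so it is not rewritten a second time.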
1 parent 85eb44e commit 2b983a2

File tree

5 files changed, +233 -35 lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 9 additions & 4 deletions
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
 def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
   let summary = "Lower the operations from the vector transfer_read to vector load";
   let description = [{
-    This pass creates a transfer read op lowering. A vector trasfer read op
-    will be lowered to a combination of vector.load, arith.select and
-    vector.broadcast.
+    This pass creates a transfer read op lowering optimization. The lowering
+    will produce a conditional check at runtime. If within bounds, a vector
+    trasfer read op will be lowered to a combination of vector.load, arith.select
+    and vector.broadcast. If not, it will fallback to the default lowering
+    of the transfer_read op.

     This pattern will make it possible for masked transfer_read to be lowered
     towards buffer load with bounds check, allowing a more optimized global
     load accessing pattern compared with existing implementation of
     llvm.intr.masked.load on vectors.
   }];
-  let dependentDialects = [];
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "memref::MemRefDialect"
+  ];
 }
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
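For orientation, the registered pass is driven by the name given in the `Pass<...>` template above; a minimal input that exercises it might look like the sketch below (the `RUN` line is an assumption derived from that pass name, not copied from this commit's test file).

// RUN: mlir-opt %s --amdgpu-transfer-read-to-load
func.func @example(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                   %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}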

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRMemRefDialect
+  MLIRSCFDialect
   MLIRVectorDialect
   MLIRControlFlowDialect
   MLIRFuncDialect

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 145 additions & 14 deletions
@@ -9,13 +9,22 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"

 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -67,6 +76,9 @@ static LogicalResult transferPreconditions(
   if (!memRefType.isLastDimUnitStride())
     return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

+  if (memRefType.getElementTypeBitWidth() < 8)
+    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
   // If there is broadcasting involved then we first load the unbroadcasted
   // vector, and then broadcast it with `vector.broadcast`.
   ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -101,13 +113,35 @@ static LogicalResult transferPreconditions(
   return success();
 }

+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::TransferReadOp readOp,
+                                           bool requiresBroadcasting,
+                                           VectorType unbroadcastedVectorType) {
+  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                               readOp.getPadding());
+  Value load = builder.create<vector::LoadOp>(
+      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                              readOp.getMask(), load, fill);
+  // Insert a broadcasting op if required.
+  if (requiresBroadcasting) {
+    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+  }
+  return res;
+}
+
+static constexpr char kTransferReadNeedsMask[] =
+    "amdgpu.buffer_transfer_read_needs_mask";
+
 namespace {

 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;

   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp->hasAttr(kTransferReadNeedsMask))
+      return failure();

     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
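As a rough illustration of what `createVectorLoadForMaskedLoad` builds on the fast path, here is a sketch for the broadcasting case (based on the transfer_broadcasting test further down; not the verbatim output): the padding value is splatted, the unbroadcasted vector is loaded and selected against the mask, and a `vector.broadcast` is appended only when the read broadcasts.

// Fast path for a broadcasting read of vector<1xf32> into vector<4xf32> (sketch):
%pad = vector.splat %padding : vector<1xf32>
%ld = vector.load %mem[%idx, %idx]
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
%sel = arith.select %mask, %ld, %pad : vector<1xi1>, vector<1xf32>
%bc = vector.broadcast %sel : vector<1xf32> to vector<4xf32>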
@@ -117,20 +151,115 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }

     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+    // Note below doesn't give the correct result for the linearized size.
+    // Value totalSize = getValueOrCreateConstantIndexOp(
+    //     rewriter, loc, linearizedInfo.linearizedSize);
+    // It computes the multiplied sizes of all dimensions instead of taking
+    // the maximum of each dimension size * stride.
+    SmallVector<AffineExpr> productExpressions;
+    SmallVector<Value> productResults;
+    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
+
+    SmallVector<AffineExpr> symbols(2 * sourceRank);
+    SmallVector<Value> offsetValues;
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+    size_t symbolIndex = 0;
+    for (size_t i = 0; i < sourceRank; ++i) {
+      AffineExpr strideExpr, sizeExpr;
+      OpFoldResult stride = strides[i];
+      OpFoldResult size = sizes[i];
+      if (auto constantStride = getConstantIntValue(stride)) {
+        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+      } else {
+        strideExpr = symbols[symbolIndex++];
+        offsetValues.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, stride));
+      }
+
+      if (auto constantSize = getConstantIntValue(size)) {
+        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+      } else {
+        sizeExpr = symbols[symbolIndex++];
+        offsetValues.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, size));
+      }
+
+      productExpressions.push_back(strideExpr * sizeExpr);
     }

-    rewriter.replaceOp(readOp, res);
+    AffineMap maxMap = AffineMap::get(
+        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+        rewriter.getContext());
+    Value totalSize =
+        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
+    Value deltaBytes = rewriter.create<arith::MulIOp>(
+        loc, delta,
+        rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We take the fallback of transfer_read default lowering only it is both
+    // out-of-bounds and not word aligned. The fallback ensures correct results
+    // when loading at the boundary of the buffer since buffer load returns
+    // inconsistent zeros for the whole word when boundary is crossed.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(
+          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);

     return success();
   }
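Spelled out in IR form, the guard computed above looks roughly like the following sketch for the static 8x8xf16 case from the tests below (constants are illustrative; the affine.max over the size*stride products folds to 64 here, and for dynamic shapes it stays as an affine.max over memref.extract_strided_metadata results):

%c0  = arith.constant 0 : index
%c2  = arith.constant 2 : index    // elementBitWidth / 8 for f16
%c4  = arith.constant 4 : index    // vector size in elements
%c64 = arith.constant 64 : index   // max_i(size_i * stride_i) for memref<8x8xf16>
// delta = bufferSize - linearizedIndex
%linear = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%idx0, %idx1]
%delta = arith.subi %c64, %linear : index
// 1) out of bounds if fewer than vectorSize elements remain
%oob = arith.cmpi ult, %delta, %c4 : index
// 2) not word aligned if the remaining bytes are not a multiple of ceil(32 / elementBitWidth)
%elems_per_word = arith.constant 2 : index
%delta_bytes = arith.muli %delta, %c2 : index
%rem = arith.remui %delta_bytes, %elems_per_word : index
%unaligned = arith.cmpi ne, %rem, %c0 : index
// Fall back to the masked transfer_read only when both conditions hold.
%cond = arith.andi %oob, %unaligned : i1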
@@ -149,6 +278,8 @@ struct AmdgpuTransferReadToLoadPass final
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateAmdgpuTransferReadToLoadPatterns(patterns);
-    walkAndApplyPatterns(getOperation(), std::move(patterns));
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 };

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 77 additions & 17 deletions
@@ -9,11 +9,71 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+
+// CHECK: } else {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16(
+// CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
+// CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
+// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>)
+func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> {
+  %cf0 = arith.constant 0.0 : f16
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
+  return %res : vector<4xf16>
+}
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
+// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
+// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
+
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
+
+// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
+// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
+// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
+
+// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: return %[[IF]] : vector<4xf16>
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> {
+  %cf0 = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+  return %res : vector<4xi8>
+}
+
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK: %[[IF:.*]] = scf.if
+// CHECK: return

 // -----

@@ -26,8 +86,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
 // CHECK: return %[[RES]] : vector<4xf32>

 // -----
@@ -41,8 +101,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
 // CHECK: return %[[RES]] : vector<4xf32>

 // -----
@@ -59,12 +119,12 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>

 // -----

@@ -79,8 +139,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
   return %res : vector<1xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -1569,6 +1569,7 @@ cc_library(
         ":IR",
         ":MemRefDialect",
         ":Pass",
+        ":SCFDialect",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",
