Commit 3a1ae2f

Authored by rikhuijzer and MacDue
[mlir][vector] Fix invalid LoadOp indices being created (#75519)
Fixes #71326.

The cause of the issue was that a new `LoadOp` was created which looked something like:

```mlir
func.func main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
  return
}
```

which crashed inside `LoadOp::verify`. Note here that `%alloca_0` is 0-dimensional and `%1` has one dimension, yet `memref.load` tries to index `%1` with two indices.

This is now fixed by using the fact that `unpackOneDim` always unpacks exactly one dim

https://github.com/llvm/llvm-project/blob/1bce61e6b01b38e04260be4f422bbae59c34c766/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp#L897-L903

and so the `loadOp` should index only one dimension.

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
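A minimal sketch of the equivalent IR once the fix is in place, assuming the same buffer types as the snippet above; the function name `@main` and the single index `%arg1` are illustrative, not compiler output:

```mlir
// Illustrative IR shape after the fix (assumed, not taken from the patch):
// vector.type_cast unpacks exactly one dimension, so the resulting rank-1
// memref is loaded with exactly one index.
func.func @main(%arg1 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1] : memref<1xvector<32xi1>>
  return
}
```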
1 parent a3952b4 commit 3a1ae2f

4 files changed, +47 −12 lines

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 17 additions & 10 deletions
Lines changed: 17 additions & 10 deletions

```diff
@@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
   /// Retrieve the indices of the current StoreOp that stores into the buffer.
   static void getBufferIndices(TransferReadOp xferOp,
                                SmallVector<Value, 8> &indices) {
-    auto storeOp = getStoreOp(xferOp);
+    memref::StoreOp storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
     indices.append(prevIndices.begin(), prevIndices.end());
   }
@@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
     if (checkPrepareXferOp(xferOp, options).failed())
       return failure();

-    auto buffers = allocBuffers(rewriter, xferOp);
-    auto *newXfer = rewriter.clone(*xferOp.getOperation());
+    BufferAllocs buffers = allocBuffers(rewriter, xferOp);
+    Operation *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.getMask()) {
       dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
@@ -885,8 +885,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +896,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -938,11 +938,18 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
           b.setInsertionPoint(newXfer); // Insert load before newXfer.

           SmallVector<Value, 8> loadIndices;
-          Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-          // In case of broadcast: Use same indices to load from memref
-          // as before.
-          if (!xferOp.isBroadcastDim(0))
+          if (auto memrefType =
+                  castedMaskBuffer.getType().dyn_cast<MemRefType>()) {
+            // If castedMaskBuffer is a memref, then one dim was
+            // unpacked; see above.
             loadIndices.push_back(iv);
+          } else {
+            Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+            // In case of broadcast: Use same indices to load from
+            // memref as before.
+            if (!xferOp.isBroadcastDim(0))
+              loadIndices.push_back(iv);
+          }

           auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                loadIndices);
```

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -1615,8 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//

 LogicalResult LoadOp::verify() {
-  if (getNumOperands() != 1 + getMemRefType().getRank())
-    return emitOpError("incorrect number of indices for load");
+  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
+    return emitOpError("incorrect number of indices for load, expected ")
+           << getMemRefType().getRank() << " but got " << getIndices().size();
+  }
   return success();
 }
```

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 17 additions & 0 deletions
```diff
@@ -740,6 +740,23 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3

 // -----

+// Check that the `unpackOneDim` case in the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+    : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract
```

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 9 additions & 0 deletions
```diff
@@ -896,6 +896,15 @@ func.func @bad_alloc_wrong_symbol_count() {

 // -----

+func.func @load_invalid_memref_indexes() {
+  %0 = memref.alloca() : memref<10xi32>
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
+  %1 = memref.load %0[%c0, %c0] : memref<10xi32>
+}
+
+// -----
+
 func.func @test_store_zero_results() {
 ^bb0:
   %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
```
