Skip to content

[mlir][vector] Fix invalid LoadOp indices being created #75519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp that stores into the buffer.
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
auto storeOp = getStoreOp(xferOp);
memref::StoreOp storeOp = getStoreOp(xferOp);
auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
Expand Down Expand Up @@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
if (checkPrepareXferOp(xferOp, options).failed())
return failure();

auto buffers = allocBuffers(rewriter, xferOp);
auto *newXfer = rewriter.clone(*xferOp.getOperation());
BufferAllocs buffers = allocBuffers(rewriter, xferOp);
Operation *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
Expand Down Expand Up @@ -885,8 +885,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
auto maskBuffer = getMaskBuffer(xferOp);
auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
Expand All @@ -897,7 +896,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
auto castedMaskType = *unpackOneDim(maskBufferType);
auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
Expand Down Expand Up @@ -938,11 +938,18 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
b.setInsertionPoint(newXfer); // Insert load before newXfer.

SmallVector<Value, 8> loadIndices;
Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
// In case of broadcast: Use same indices to load from memref
// as before.
if (!xferOp.isBroadcastDim(0))
if (auto memrefType =
castedMaskBuffer.getType().dyn_cast<MemRefType>()) {
// If castedMaskBuffer is a memref, then one dim was
// unpacked; see above.
loadIndices.push_back(iv);
} else {
Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
// In case of broadcast: Use same indices to load from
// memref as before.
if (!xferOp.isBroadcastDim(0))
loadIndices.push_back(iv);
}

auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,8 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
if (getNumOperands() != 1 + getMemRefType().getRank())
return emitOpError("incorrect number of indices for load");
if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
return emitOpError("incorrect number of indices for load, expected ")
<< getMemRefType().getRank() << " but got " << getIndices().size();
}
return success();
}

Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,23 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3

// -----

// Check that the `unpackOneDim` case in the `TransferOpConversion` generates valid indices for the LoadOp.

#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
: memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
return %3 : vector<1x1x1x1xi32>
}
// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>

// -----

// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,15 @@ func.func @bad_alloc_wrong_symbol_count() {

// -----

func.func @load_invalid_memref_indexes() {
%0 = memref.alloca() : memref<10xi32>
%c0 = arith.constant 0 : index
// expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
%1 = memref.load %0[%c0, %c0] : memref<10xi32>
}

// -----

func.func @test_store_zero_results() {
^bb0:
%0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
Expand Down