Commit 6b21948

[mlir][vector] Fix invalid LoadOp indices being created (#76292)
Fixes #71326.

This is the second PR. The first PR at #75519 was reverted because an integration test failed. The failing integration test was simplified and added to the core MLIR tests. Compared to the first PR, the current PR uses a more reliable approach. In summary, the current PR determines the mask indices by looking up the _mask_ buffer load indices from the previous iteration, whereas `main` looks up the indices for the _data_ buffer. The mask and data indices can differ when using a `permutation_map`.

The cause of the issue was that a new `LoadOp` was created which looked something like:

```mlir
func.func @main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
  return
}
```

which crashed inside `LoadOp::verify`. Note that `%alloca_0` is the mask, as can be seen from the `i1` element type, and that it is zero-dimensional. Next, `%1` has one dimension, but `memref.load` tries to index it with two indices.

This issue occurred in the following code (a simplified version of the bug report):

```mlir
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @main(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
         : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
  return %3 : vector<1x1x1x1xi32>
}
```

After this patch, `-convert-vector-to-scf` lowers it to the following:

```mlir
func.func @main(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
  %alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
  memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
  %0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
  scf.for %arg2 = %c0 to %c1 step %c1 {
    %3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
    scf.for %arg3 = %c0 to %c1 step %c1 {
      %4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
      scf.for %arg4 = %c0 to %c1 step %c1 {
        %5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
        %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
        memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
      }
    }
  }
  %2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
  return %2 : vector<1x1x1x1xi32>
}
```

What was causing the problem is that one dimension of the data buffer `%alloca` (element type `i32`) is unpacked (`vector.type_cast`) inside the outermost loop (loop with index variable `%arg2`) and inside the nested loop (loop with index variable `%arg3`), whereas the mask buffer `%alloca_0` (element type `i1`) is not unpacked in these loops. Before this patch, the load indices were determined by looking up the load indices for the *data* buffer load op. However, as shown in this example, when a permutation map is specified, the load indices from the data buffer load op start to differ from the indices for the mask op.

To fix this, this patch ensures that the load indices for the *mask* buffer are used instead.

---------

Co-authored-by: Mehdi Amini <[email protected]>
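The invariant behind the crash is that `memref.load` must supply exactly as many indices as its memref has dimensions. The following is a minimal, hypothetical sketch (the function and value names are illustrative, not taken from the patch) of a well-formed mask load for the buffer shapes used in the first example above:

```mlir
// Illustrative sketch only: after vector.type_cast unpacks one dimension of
// the 0-D mask buffer, the resulting memref is 1-D, so memref.load takes
// exactly one index.
func.func @mask_load_sketch(%i: index) -> vector<32xi1> {
  %mask_buf = memref.alloca() : memref<vector<1x32xi1>>
  %unpacked = vector.type_cast %mask_buf
      : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %mask = memref.load %unpacked[%i] : memref<1xvector<32xi1>>
  return %mask : vector<32xi1>
}
```

Supplying a second index to that `memref.load`, as the old lowering did when a permutation map was present, is exactly what the verifier rejected.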
1 parent d09315d commit 6b21948

File tree: 2 files changed (+71, -14 lines)

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 34 additions & 14 deletions
```diff
@@ -866,16 +866,41 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVectorImpl<Value> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // Add load indices from the previous iteration.
+    // The mask buffer depends on the permutation map, which makes determining
+    // the indices quite complex, so this is why we need to "look back" to the
+    // previous iteration to find the right indices.
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (Operation *user : maskBuffer.getUsers()) {
+      // If there is no previous load op, then the indices are empty.
+      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // In case of broadcast: Use same indices to load from memref
+    // as before.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
       return failure();
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
       // If old transfer op has a mask: Set mask on new transfer op.
       // Special case: If the mask of the old transfer op is 1D and
-      // the
-      // unpacked dim is not a broadcast, no mask is
-      // needed on the new transfer op.
+      // the unpacked dim is not a broadcast, no mask is needed on
+      // the new transfer op.
       if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                xferOp.getMaskType().getRank() > 1)) {
        OpBuilder::InsertionGuard guard(b);
        b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
        SmallVector<Value, 8> loadIndices;
-       Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-       // In case of broadcast: Use same indices to load from memref
-       // as before.
-       if (!xferOp.isBroadcastDim(0))
-         loadIndices.push_back(iv);
-
+       getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                loadIndices, iv);
        auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                             loadIndices);
        rewriter.updateRootInPlace(newXfer, [&]() {
```

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 37 additions & 0 deletions
```diff
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 // -----
 
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+    : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+// -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract
```
