Commit 847048f

[mlir][Vector] Fix bug in vector xfer op flattening transformation (#81964)
The affine map generated to compute the index into the collapsed dimensions used the wrong dim size. For indices `[idx0][idx1]` into a shape `[size0, size1]`, we computed the collapsed index as `idx0 * size0 + idx1` instead of `idx0 * size1 + idx1`: in row-major order, each index must be scaled by the product of the dim sizes that follow it, not by its own dim size. This led to correctness issues in convolution tests when the transformation was enabled internally.
1 parent 66f6929 commit 847048f
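To make the wrong-stride bug concrete, here is a minimal standalone C++ sketch of row-major linearization via a suffix product of the shape. It is illustrative only; `suffixProduct` is a hypothetical helper mirroring the semantics of MLIR's `computeSuffixProduct`, not the MLIR API itself:

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides are the suffix product of the shape:
// strides[i] = shape[i+1] * shape[i+2] * ... * 1.
static std::vector<int64_t> suffixProduct(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

int main() {
  // Shape [size0, size1] = [4, 6], indices [idx0, idx1] = [2, 3].
  std::vector<int64_t> strides = suffixProduct({4, 6}); // {6, 1}
  // Correct: idx0 * size1 + idx1 = 2 * 6 + 3 = 15.
  assert(2 * strides[0] + 3 * strides[1] == 15);
  // The buggy map computed idx0 * size0 + idx1 = 2 * 4 + 3 = 11 instead.
  return 0;
}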

4 files changed: +65 −22 lines

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 3 additions & 0 deletions
@@ -257,6 +257,9 @@ SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
                    ArrayRef<OpFoldResult> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+                   ArrayRef<Value> indices);
 
 //===----------------------------------------------------------------------===//
 // Utilities for decomposing larger shapes

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 9 additions & 2 deletions
@@ -7,13 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
-
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "llvm/ADT/STLExtras.h"
-
 #include <numeric>
 #include <optional>
 

@@ -306,6 +305,14 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
   return {expr, values};
 }
 
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+                         ArrayRef<Value> indices) {
+  return computeLinearIndex(
+      sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
+      getAsOpFoldResult(ValueRange(indices)));
+}
+
 //===----------------------------------------------------------------------===//
 // TileOffsetRange
 //===----------------------------------------------------------------------===//
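The new overload is a thin bridge from static `int64_t` strides and SSA `Value` indices to the existing `OpFoldResult`-based entry point. A usage sketch mirroring how the flattening pattern below calls it (assuming an `OpBuilder b`, a `Location loc`, a static `shape`, and SSA `indices`; the variable names are illustrative):

// Linearize dynamic indices against static row-major strides, then
// fold the result into a single affine.apply (or a constant).
SmallVector<int64_t> strides = computeSuffixProduct(shape);
OpFoldResult offset = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
auto &&[expr, vals] = computeLinearIndex(offset, strides, indices);
OpFoldResult linear =
    affine::makeComposedFoldedAffineApply(b, loc, expr, vals);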

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 23 additions & 18 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

@@ -577,7 +578,6 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices;
     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
     // 1. Collapse the source memref

@@ -599,12 +599,14 @@ class FlattenContiguousRowMajorTransferReadPattern
     // 2.2 New indices
     // If all the collapsed indices are zero then no extra logic is needed.
     // Otherwise, a new offset/index has to be computed.
+    SmallVector<Value> collapsedIndices;
     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                 firstDimToCollapse,
                                                 collapsedIndices))) {
-      // Copy all the leading indices
-      collapsedIndices = transferReadOp.getIndices();
-      collapsedIndices.resize(firstDimToCollapse);
+      // Copy all the leading indices.
+      SmallVector<Value> indices = transferReadOp.getIndices();
+      collapsedIndices.append(indices.begin(),
+                              indices.begin() + firstDimToCollapse);
 
       // Compute the remaining trailing index/offset required for reading from
       // the collapsed memref:

@@ -621,24 +623,26 @@
       //    memref<1x86xi32>, vector<2xi32>
       // one would get the following offset:
       //    %offset = %arg0 * 43
-      AffineExpr offsetExpr, idxExpr;
-      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
-
-      int64_t outputRank = transferReadOp.getIndices().size();
-      OpFoldResult offset =
+      OpFoldResult collapsedOffset =
           rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
 
-      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
-        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
-        offset = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, offsetExpr + dim * idxExpr,
-            {offset, transferReadOp.getIndices()[i]});
-      }
-      if (offset.is<Value>()) {
-        collapsedIndices.push_back(offset.get<Value>());
+      auto sourceShape = sourceType.getShape();
+      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
+          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
+
+      // Compute the collapsed offset.
+      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                        indices.end());
+      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
+          collapsedOffset, collapsedStrides, indicesToCollapse);
+      collapsedOffset = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, collapsedExpr, collapsedVals);
+
+      if (collapsedOffset.is<Value>()) {
+        collapsedIndices.push_back(collapsedOffset.get<Value>());
       } else {
         collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(offset)));
+            loc, *getConstantIntValue(collapsedOffset)));
       }
     }
 

@@ -710,6 +714,7 @@ class FlattenContiguousRowMajorTransferWritePattern
                                                 firstContiguousInnerDim,
                                                 collapsedIndices)))
       return failure();
+
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
     MemRefType collapsedSourceType =
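The tests below exercise the fixed pattern. To apply the flattening rewrites outside the test pass, a minimal driver sketch might look like this (assuming `vector::populateFlattenVectorTransferPatterns` as the entry point that registers these patterns, with default options; the pass boilerplate, `ctx`, and `funcOp` are illustrative):

// Illustrative driver: greedily apply the transfer-op flattening patterns.
RewritePatternSet patterns(ctx);
vector::populateFlattenVectorTransferPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();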

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 30 additions & 2 deletions
@@ -83,7 +83,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,

@@ -92,7 +92,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK: %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
 // CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>

@@ -459,3 +459,31 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
 // CHECK-128B-NOT: memref.collapse_shape
 
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                              %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  return %8 : vector<2x2xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+                                                %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                                %idx0 : index, %idx1 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
+// CHECK-NOT: memref.collapse_shape
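The updated CHECK lines follow directly from the suffix-product strides. For memref<1x43x4x6xi32> with firstDimToCollapse = 1, the collapsed dims [43, 4, 6] yield strides suffixProduct([43, 4, 6]) = [24, 6, 1], so the collapsed index for [%idx_1, %idx_2, %c0] is idx_1 * 24 + idx_2 * 6 + 0 * 1, i.e. the new map s0 * 24 + s1 * 6 applied to [%[[IDX_1]], %[[IDX_2]]]. The old code scaled each index by its own dim size instead, producing the incorrect idx_1 * 43 + idx_2 * 4 (the old map s0 * 4 + s1 * 43 with the symbols in reversed order).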
