
Commit 2eb9e33

[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (#73523)
Updates patterns for flattening `vector.transfer_read` by relaxing the requirement that the "collapsed" indices are all zero. This enables collapsing cases like this one:

```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... : memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case (all indices are 0) would be considered for collapsing:

```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... : memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Also adds some new comments and renames the `firstContiguousInnerDim` parameter to `firstDimToCollapse` (the latter better matches its actual meaning). Similar updates for `vector.transfer_write` will be implemented in a follow-up patch.
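For illustration, here is a sketch of what the first example above collapses to, based on the CHECK lines of the new test added in this patch (SSA value names are illustrative):

```mlir
// Collapse the trailing dims of the source into a single contiguous dim ...
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
// ... compute a single linearized offset from the non-zero indices ...
%offset = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>()[%arg1, %arg0]
// ... then read a flat 1-D vector and cast it back to the original vector shape.
%flat = vector.transfer_read %collapsed[%c0, %offset], %c0_i32 {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%2 = vector.shape_cast %flat : vector<12xi32> to vector<1x2x6xi32>
```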
1 parent e8dbe94 · commit 2eb9e33

File tree: 4 files changed (+129, -11 lines)


mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 68 additions & 11 deletions
```diff
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -542,45 +544,100 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //     memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //     memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
```
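To make the offset computation above concrete, here is a sketch of how the loop unfolds for the new test case added in this patch (`memref<1x43x4x6xi32>`, indices `[%c0, %idx_1, %idx_2, %c0]`, so `firstDimToCollapse = 1`); the resulting affine map matches the CHECK lines of the new test:

```mlir
// Hypothetical trace of the offset loop (each step folded by
// makeComposedFoldedAffineApply):
//   i = 1: offset = 0 + 43 * %idx_1
//   i = 2: offset = 43 * %idx_1 + 4 * %idx_2
//   i = 3: offset = 43 * %idx_1 + 4 * %idx_2 + 6 * 0
// which materializes as a single affine.apply:
%offset = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>()[%idx_2, %idx_1]
```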

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -265,6 +265,11 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;
   auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
 
+  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
+  // with leading dynamic dimensions are already supported.
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   // Cond 1: A contiguous memref will always have a unit trailing stride.
   if (strides.back() != 1)
     return false;
```
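As a sketch of the restriction added above (using the same shapes as the new dynamic-shape test below), a source with any dynamic dimension is now conservatively treated as non-contiguous, so the flattening pattern leaves reads like this one unchanged:

```mlir
// The dynamic dim means isContiguousSlice() returns false, so this
// transfer_read is not collapsed (see the *_dynamic_shapes test below).
%1 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32
    {in_bounds = [true, true, true]} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
```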

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 55 additions & 0 deletions
```diff
@@ -41,6 +41,61 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+                      %idx_1: index,
+                      %idx_2: index,
+                      %m_in: memref<1x43x4x6xi32>,
+                      %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+                      %idx_1: index,
+                      %idx_2: index,
+                      %m_in: memref<1x?x4x6xi32>,
+                      %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
   %c0 = arith.constant 0 : index
```
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
+    registry.insert<affine::AffineDialect>();
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
```