
Commit ee5e355

[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N)
Updates patterns for flattening `vector.transfer_read` by relaxing the requirement that the "collapsed" indices are all zero. This enables collapsing cases like this one:

```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case would be considered for collapsing:

```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

The pattern itself, `FlattenContiguousRowMajorTransferReadPattern`, was also lightly refactored:
* added comments,
* renamed `firstContiguousInnerDim` to `firstDimToCollapse` (the latter better matches the meaning and is already used consistently in the helper methods).

A similar update for `vector.transfer_write` will be implemented in a follow-up patch.
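For reference, below is a rough sketch of the flattened IR this pattern is expected to produce for the first example above. SSA names and the index constants `%c43`, `%c4`, and the padding value `%pad` are illustrative (not part of the patch); the offset arithmetic mirrors the test added by this commit.

```mlir
// Illustrative sketch only: collapse the source memref and linearise the
// originally non-zero indices %arg0 and %arg1 into a single trailing offset.
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
%mul0 = arith.muli %arg0, %c43 : index   // index * dim-size, per the new logic
%mul1 = arith.muli %arg1, %c4 : index
%offset = arith.addi %mul1, %mul0 : index
%flat = vector.transfer_read %collapsed[%c0, %offset], %pad {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%2 = vector.shape_cast %flat : vector<12xi32> to vector<1x2x6xi32>
```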
Parent: 3b6d63c

2 files changed: +94 / -10 lines


mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 62 additions & 10 deletions
```diff
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -544,43 +546,93 @@ class FlattenContiguousRowMajorTransferReadPattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //      memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //      memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      int64_t outputRank = transferReadOp.getIndices().size();
+      Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        auto sourceDimSize =
+            rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+        offset = rewriter.create<arith::AddIOp>(
+            loc,
+            rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+                                           sourceDimSize),
+            offset);
+      }
+      collapsedIndices.push_back(offset);
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
```
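To exercise the new offset logic in isolation, a file like the sketch below could be fed through `mlir-opt`. This is a minimal reproducer, not part of the patch; it assumes the `-test-vector-transfer-flatten-patterns` flag that the existing `vector-transfer-flatten.mlir` test is normally run with, and the expected types follow the test added by this commit.

```mlir
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns | FileCheck %s
// (Assumed invocation, mirroring the RUN line of
//  mlir/test/Dialect/Vector/vector-transfer-flatten.mlir.)

func.func @collapse_read_with_non_zero_index(%idx: index,
    %mem: memref<1x43x4x6xi32>) -> vector<1x2x6xi32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i32
  // The non-zero leading index %idx triggers the new offset computation.
  %v = vector.transfer_read %mem[%c0, %idx, %c0, %c0], %pad
      {in_bounds = [true, true, true]}
      : memref<1x43x4x6xi32>, vector<1x2x6xi32>
  return %v : vector<1x2x6xi32>
}

// CHECK-LABEL: func.func @collapse_read_with_non_zero_index
// CHECK:         memref.collapse_shape
// CHECK:         vector.transfer_read {{.*}} : memref<1x1032xi32>, vector<12xi32>
// CHECK:         vector.shape_cast
```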

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 32 additions & 0 deletions
```diff
@@ -41,6 +41,38 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+                     %idx_1: index,
+                     %idx_2: index,
+                     %m_in: memref<1x43x4x6xi32>,
+                     %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[VAL_4:.*]] = arith.constant 43 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
   %c0 = arith.constant 0 : index
```
