Commit 53ddc87

[mlir][vector] Improve flattening vector.transfer_write ops. (#94051)
We can flatten transfer ops even when the collapsed indices are not all zero: the required trailing offset can be computed from the original indices. This is already supported for vector.transfer_read; this revision refactors that logic into a shared helper and reuses it for vector.transfer_write.
1 parent 0e743ec commit 53ddc87
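
For illustration, a sketch of the new behavior, adapted from the regression test updated below (SSA value names are illustrative, and the exact ops the pattern emits may differ in detail). A transfer_write whose collapsed dimensions carry a non-zero index, e.g.

    vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]}
        : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>

was previously rejected by the flattening pattern; it can now be rewritten against the collapsed memref, with %idx1 folded into a linearized trailing offset (the collapsed group has shape 3x2, so %idx1 is scaled by its suffix-product stride 2):

    %offset = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx1]
    %collapsed = memref.collapse_shape %subview [[0], [1], [2, 3]]
        : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
    %flat = vector.shape_cast %value : vector<2x2xf32> to vector<4xf32>
    vector.transfer_write %flat, %collapsed[%c0, %idx0, %offset] {in_bounds = [true]}
        : vector<4xf32>, memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>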

2 files changed: +102, -84 lines
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 67 additions & 76 deletions
@@ -505,25 +505,61 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
 }
 
-/// Checks that the indices corresponding to dimensions starting at
-/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
-/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
-/// TODO: Extract the logic that writes to outIndices so that this method
-/// simply checks one pre-condition.
-static LogicalResult
-checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
-                                 SmallVector<Value> &outIndices) {
-  int64_t rank = indices.size();
-  if (firstDimToCollapse >= rank)
-    return failure();
-  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
-    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
-    if (!cst || cst.value() != 0)
-      return failure();
+/// Returns the new indices that collapses the inner dimensions starting from
+/// the `firstDimToCollapse` dimension.
+static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
+                                              Location loc,
+                                              ArrayRef<int64_t> shape,
+                                              ValueRange indices,
+                                              int64_t firstDimToCollapse) {
+  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
+
+  // If all the collapsed indices are zero then no extra logic is needed.
+  // Otherwise, a new offset/index has to be computed.
+  SmallVector<Value> indicesAfterCollapsing(
+      indices.begin(), indices.begin() + firstDimToCollapse);
+  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                       indices.end());
+  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
+    return indicesAfterCollapsing;
+  }
+
+  // Compute the remaining trailing index/offset required for reading from
+  // the collapsed memref:
+  //
+  //    offset = 0
+  //    for (i = firstDimToCollapse; i < outputRank; ++i)
+  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+  //
+  // For this example:
+  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
+  //     memref<1x43x2xi32>, vector<1x2xi32>
+  // which would be collapsed to:
+  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
+  //     memref<1x86xi32>, vector<2xi32>
+  // one would get the following offset:
+  //   %offset = %arg0 * 43
+  OpFoldResult collapsedOffset =
+      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+  auto collapsedStrides = computeSuffixProduct(
+      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
+
+  // Compute the collapsed offset.
+  auto &&[collapsedExpr, collapsedVals] =
+      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
+  collapsedOffset = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, collapsedExpr, collapsedVals);
+
+  if (collapsedOffset.is<Value>()) {
+    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
+  } else {
+    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
+        loc, *getConstantIntValue(collapsedOffset)));
   }
-  outIndices = indices;
-  outIndices.resize(firstDimToCollapse + 1);
-  return success();
+
+  return indicesAfterCollapsing;
 }
 
 namespace {
@@ -594,54 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
     // 2.2 New indices
-    // If all the collapsed indices are zero then no extra logic is needed.
-    // Otherwise, a new offset/index has to be computed.
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstDimToCollapse,
-                                                collapsedIndices))) {
-      // Copy all the leading indices.
-      SmallVector<Value> indices = transferReadOp.getIndices();
-      collapsedIndices.append(indices.begin(),
-                              indices.begin() + firstDimToCollapse);
-
-      // Compute the remaining trailing index/offset required for reading from
-      // the collapsed memref:
-      //
-      //    offset = 0
-      //    for (i = firstDimToCollapse; i < outputRank; ++i)
-      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
-      //
-      // For this example:
-      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
-      //     memref<1x43x2xi32>, vector<1x2xi32>
-      // which would be collapsed to:
-      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
-      //     memref<1x86xi32>, vector<2xi32>
-      // one would get the following offset:
-      //   %offset = %arg0 * 43
-      OpFoldResult collapsedOffset =
-          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
-
-      auto sourceShape = sourceType.getShape();
-      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
-          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
-
-      // Compute the collapsed offset.
-      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
-                                        indices.end());
-      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
-          collapsedOffset, collapsedStrides, indicesToCollapse);
-      collapsedOffset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, collapsedExpr, collapsedVals);
-
-      if (collapsedOffset.is<Value>()) {
-        collapsedIndices.push_back(collapsedOffset.get<Value>());
-      } else {
-        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(collapsedOffset)));
-      }
-    }
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferReadOp.getIndices(), firstDimToCollapse);
 
     // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
@@ -697,31 +688,31 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
 
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
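
In short, getCollapsedIndices keeps the leading indices untouched and folds the trailing ones into a single offset using the suffix product of the collapsed dimension sizes (via computeLinearIndex and affine::makeComposedFoldedAffineApply). A minimal sketch of the two outcomes, using the shapes from the regression test below (illustrative IR, not verbatim pattern output):

    // Dynamic trailing index: for memref<1x3x3x2xf32> collapsed from dim 2, the
    // suffix product of [3, 2] is [2, 1], so [%c0, %idx0, %idx1, %c0] becomes
    // [%c0, %idx0, %offset] with
    %offset = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx1]

    // Constant trailing indices fold away entirely; e.g. [%c0, %idx0, %c1, %c0]
    // would instead get a materialized constant offset (1 * 2 + 0):
    %c2 = arith.constant 2 : index

If the trailing indices are all zero, the helper simply reuses the existing zero index and no affine arithmetic is created.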

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 35 additions & 8 deletions
@@ -471,25 +471,52 @@ func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, str
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
 // CHECK-128B: memref.collapse_shape
 
 // -----
 
-func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
   %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
   %idx0 : index, %idx1 : index) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
   return
 }
 
-// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
-// CHECK-NOT: memref.collapse_shape
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_write(
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 
-// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
-// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_read(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
+    memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+// CHECK: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_write(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
+    vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
+}
+// CHECK: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
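
For reference, these cases run under the file's existing RUN configurations; the pass name and option below are recalled from the test file rather than part of this commit, so treat them as an assumption:

    // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
    // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B

The CHECK-128B prefix corresponds to the bitwidth-limited configuration, which is why the renamed @regression_non_contiguous_dim_write test carries both CHECK and CHECK-128B lines.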
