Commit fd15e2b

Author: MaheshRavishankar
[mlir][Linalg] Use rank-reduced versions of subtensor and subtensor insert when possible.
Convert subtensor and subtensor_insert operations to use their rank-reduced versions to drop unit dimensions.

Differential Revision: https://reviews.llvm.org/D101495
1 parent 63f8226 commit fd15e2b
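As a sketch of the rewrite this commit introduces (the value names and shapes below are hypothetical, adapted from the pattern documentation visible in the diff), a slice whose unit dimensions have offset 0 and size 1, such as

    %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
        : tensor<1x?x1xf32> to tensor<1x?x1xf32>

is now emitted directly in its rank-reduced form, with a linalg.tensor_reshape restoring the original result type:

    %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
        : tensor<1x?x1xf32> to tensor<?xf32>
    %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
        : tensor<?xf32> into tensor<1x?x1xf32>

The extra reshape is expected to cancel against other unit-dimension-dropping reshapes during canonicalization.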

6 files changed: +270 additions, -200 deletions

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 3 additions & 1 deletion
@@ -18,7 +18,9 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
     from/to the original memref.
   }];
   let constructor = "mlir::memref::createFoldSubViewOpsPass()";
-  let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
+  let dependentDialects = [
+      "AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
 }


mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 74 additions & 62 deletions
@@ -544,77 +544,87 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
     return success();
   }
 };
+} // namespace

-/// Pattern to fold subtensors that are just taking a slice of unit-dimension
-/// tensor. For example
-///
-///   %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
-///       : tensor<1x?x1xf32> to tensor<1x?x1xf32>
-///
-/// can be replaced with
-///
-///   %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///       : tensor<1x?x1xf32> into tensor<?xf32>
-///   %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
-///   %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-///       : tensor<?xf32> into tensor<1x?x1xf32>
-///
-/// The additional tensor_reshapes will hopefully get canonicalized away with
-/// other reshapes that drop unit dimensions. Three condiitions to fold a
-/// dimension
-/// - The offset must be 0
-/// - The size must be 1
-/// - The dimension of the source type must be 1.
-struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
+/// Get the reassociation maps to fold the result of a subtensor (or source of a
+/// subtensor_insert) operation with given offsets, and sizes to its
+/// rank-reduced version. This is only done for the cases where the size is 1
+/// and offset is 0. Strictly speaking the offset 0 is not required in general,
+/// but non-zero offsets are not handled by SPIR-V backend at this point (and
+/// potentially cannot be handled).
+static Optional<SmallVector<ReassociationIndices>>
+getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices curr;
+  for (auto it : llvm::enumerate(mixedSizes)) {
+    auto dim = it.index();
+    auto size = it.value();
+    curr.push_back(dim);
+    auto attr = size.dyn_cast<Attribute>();
+    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+      continue;
+    reassociation.emplace_back(ReassociationIndices{});
+    std::swap(reassociation.back(), curr);
+  }
+  if (!curr.empty())
+    reassociation.back().append(curr.begin(), curr.end());
+  return reassociation;
+}
+
+namespace {
+/// Convert `subtensor` operations to rank-reduced versions.
+struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
   using OpRewritePattern<SubTensorOp>::OpRewritePattern;

   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
-    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
-      auto attr = valueOrAttr.dyn_cast<Attribute>();
-      return attr && attr.cast<IntegerAttr>().getInt() == val;
-    };
-
-    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
-          return !hasValue(valueOrAttr, 1);
-        }))
+    RankedTensorType resultType = subTensorOp.getType();
+    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
+    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
+    auto rankReducedType =
+        SubTensorOp::inferRankReducedResultType(reassociation->size(),
+                                                subTensorOp.getSourceType(),
+                                                offsets, sizes, strides)
+            .cast<RankedTensorType>();
+
+    Location loc = subTensorOp.getLoc();
+    Value newSubTensor = rewriter.create<SubTensorOp>(
+        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
+                                                 newSubTensor, *reassociation);
+    return success();
+  }
+};

-    // Find the expanded unit dimensions.
-    SmallVector<ReassociationIndices> reassociation;
-    SmallVector<OpFoldResult> newOffsets, newSizes;
-    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
-    ReassociationIndices curr;
-    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
-      curr.push_back(dim);
-      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
-          hasValue(mixedSizes[dim], 1)) {
-        continue;
-      }
-      newOffsets.push_back(mixedOffsets[dim]);
-      newSizes.push_back(mixedSizes[dim]);
-      reassociation.emplace_back(ReassociationIndices{});
-      std::swap(reassociation.back(), curr);
-    }
-    if (newOffsets.size() == mixedOffsets.size())
+/// Convert `subtensor_insert` operations to rank-reduced versions.
+struct UseRankReducedSubTensorInsertOp
+    : public OpRewritePattern<SubTensorInsertOp> {
+  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = insertOp.getSourceType();
+    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    if (!reassociation ||
+        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
-    reassociation.back().append(curr.begin(), curr.end());
-    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
-                                         rewriter.getI64IntegerAttr(1));
-    Location loc = subTensorOp->getLoc();
-    auto srcReshape = rewriter.create<TensorReshapeOp>(
-        loc, subTensorOp.source(), reassociation);
-    auto newSubTensorOp = rewriter.create<SubTensorOp>(
-        loc, srcReshape, newOffsets, newSizes, newStrides);
-    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
-        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
+    Location loc = insertOp.getLoc();
+    auto reshapedSource = rewriter.create<TensorReshapeOp>(
+        loc, insertOp.source(), *reassociation);
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
     return success();
   }
 };
-
 } // namespace

 /// Patterns that are used to canonicalize the use of unit-extent dims for
@@ -623,8 +633,10 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
-               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
-               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
+               ReplaceUnitExtentTensors<GenericOp>,
+               ReplaceUnitExtentTensors<IndexedGenericOp>,
+               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
+      context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldReshapeOpWithUnitExtent>(context);
 }
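The new helper groups dimensions by appending every unit-size dimension to the current reassociation group and closing the group at each non-unit dimension; trailing unit dimensions join the last group. For the sizes [1, %s1, %s2, 1, %s3, 1, 1] used in the updated test below, this yields the reassociation [[0, 1], [2], [3, 4, 5, 6]], i.e. the three maps #MAP0, #MAP1 and #MAP2 that test checks for. Applied to subtensor_insert, the same reassociation collapses the source before the insert; a minimal sketch with hypothetical values (%src, %dst, %o1, %s1) and shapes:

    %collapsed = linalg.tensor_reshape %src [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
        : tensor<1x?x1xf32> into tensor<?xf32>
    %0 = subtensor_insert %collapsed into %dst[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
        : tensor<?xf32> into tensor<1x8x1xf32>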

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRMemRefPassIncGen

   LINK_LIBS PUBLIC
+  MLIRAffine
   MLIRMemRef
   MLIRPass
   MLIRStandard

mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp

Lines changed: 47 additions & 20 deletions
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -41,27 +42,53 @@ static LogicalResult
 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                      memref::SubViewOp subViewOp, ValueRange indices,
                      SmallVectorImpl<Value> &sourceIndices) {
-  // TODO: Aborting when the offsets are static. There might be a way to fold
-  // the subview op with load even if the offsets have been canonicalized
-  // away.
-  SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
-  if (opRanges.size() != indices.size()) {
-    // For the rank-reduced cases, we can only handle the folding when the
-    // offset is zero, size is 1 and stride is 1.
-    return failure();
+  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
+  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
+  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
+
+  SmallVector<Value> useIndices;
+  // Check if this is rank-reducing case. Then for every unit-dim size add a
+  // zero to the indices.
+  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
+  unsigned resultDim = 0;
+  for (auto size : llvm::enumerate(mixedSizes)) {
+    auto attr = size.value().dyn_cast<Attribute>();
+    // Check if this dimension has been dropped, i.e. the size is 1, but the
+    // associated dimension is not 1.
+    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
+        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+      useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
+    else if (resultDim < resultShape.size()) {
+      useIndices.push_back(indices[resultDim++]);
+    }
   }
-  auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
-  auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-
-  // New indices for the load are the current indices * subview_stride +
-  // subview_offset.
-  sourceIndices.resize(indices.size());
-  for (auto index : llvm::enumerate(indices)) {
-    auto offset = *(opOffsets.begin() + index.index());
-    auto stride = *(opStrides.begin() + index.index());
-    auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
-    sourceIndices[index.index()] =
-        rewriter.create<AddIOp>(loc, offset, mul).getResult();
+  if (useIndices.size() != mixedOffsets.size())
+    return failure();
+  sourceIndices.resize(useIndices.size());
+  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
+    SmallVector<Value> dynamicOperands;
+    AffineExpr expr = rewriter.getAffineDimExpr(0);
+    unsigned numSymbols = 0;
+    dynamicOperands.push_back(useIndices[index]);
+
+    // Multiply the stride;
+    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
+      expr = expr * attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedStrides[index].get<Value>());
+      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+
+    // Add the offset.
+    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
+      expr = expr + attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
+      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+    Location loc = subViewOp.getLoc();
+    sourceIndices[index] = rewriter.create<AffineApplyOp>(
+        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
   }
   return success();
 }
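With this change, each source index of the folded load/store is computed as index * stride + offset through a single affine.apply, with dynamic strides and offsets passed as affine symbols. A hypothetical instance of what the rewrite emits when both the stride and the offset are dynamic values (%i, %stride, %offset are placeholder names):

    %src_index = affine.apply affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>(%i)[%stride, %offset]

When the stride or offset is a constant it is folded into the map directly, e.g. a static stride of 2 with a dynamic offset gives affine_map<(d0)[s0] -> (d0 * 2 + s0)>. This is why the pass now depends on the Affine dialect (see the Passes.td and CMakeLists.txt changes above).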

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 17 additions & 52 deletions
@@ -476,67 +476,32 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 // -----

 func @fold_subtensor(
-    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index,
-    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-    -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0]
-      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : tensor<1x?x?x?x?x1x1xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+    %arg6 : index, %arg7 : index) -> (tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>) {
+  %0 = subtensor %arg0[0, %arg2, %arg3, 0, %arg4, 0, 0]
+      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
       tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
+  %1 = subtensor %arg1[%arg2, 0, %arg3, 0, 0, %arg4, 0]
+      [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+  return %0, %1 : tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
 // CHECK: func @fold_subtensor
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
+// CHECK: %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: to tensor<?x?x?xf32>
+// CHECK: %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[ARG3]]]
-// CHECK-SAME: [%[[ARG4]], %[[ARG5]], %[[ARG6]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+// CHECK: %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
+// CHECK-SAME: to tensor<?x?x?xf32>
+// CHECK: %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: return %[[RESULT_RESHAPE]]
-
-// -----
-
-func @no_fold_subtensor(
-    %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index,
-    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
-    -> tensor<1x?x?x1x?x1x1xf32> {
-  %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0]
-      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
-      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
-  return %0 : tensor<1x?x?x1x?x1x1xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)>
-// CHECK: func @no_fold_subtensor
-// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME: [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]]
-// CHECK-SAME: [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK: return %[[RESULT1]], %[[RESULT2]]

 // -----
