
[mlir][tensor] add tensor insert/extract op folders #142458


Merged: 1 commit, Jun 3, 2025
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});

/// Patterns to fold a tensor.extract of a tensor.collapse_shape result into an
/// extract of the source tensor.
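/// Apply them with applyPatternsGreedily; see TestTensorTransforms.cpp in this
/// change for an example driver.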
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);

} // namespace tensor
} // namespace mlir

1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [

let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
165 changes: 165 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
}
};

/// Canonicalizes a pattern of the form
///
///   %val = tensor.collapse_shape %src [[0, 1]]
///       : tensor<3x4xf64> into tensor<12xf64>
///   %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
///
/// to
///
///   %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
///
/// since index 10 delinearizes over the source shape (3, 4) to
/// (10 floordiv 4, 10 mod 4) = (2, 2).
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
PatternRewriter &rewriter) const final {
auto collapseOp =
extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp)
return failure();
if (!collapseOp.getSrcType().hasStaticShape())
return failure();

auto sourceSizes = collapseOp.getSrcType().getShape();

SmallVector<Value> indices(extractOp.getIndices().begin(),
extractOp.getIndices().end());
SmallVector<Value> sourceIndices;
for (auto [index, group] :
llvm::zip(indices, collapseOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
auto groupSize = group.size();

if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}

// Delinearize the collapsed index over the static sizes of the source
// dimensions in this reassociation group (row-major).
SmallVector<int64_t> basis =
    llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
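// An empty reassociation means the source collapsed to a 0-d tensor; all
// source dimensions then have static size 1, so read the source at an
// all-zero index.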
if (collapseOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(), zeroAffineMap,
ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
}
}

rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, collapseOp.getSrc(), sourceIndices);
return success();
}
};

} // namespace

void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
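/// For example, given
///   %inserted = tensor.insert %scalar into %dest[%c0] : tensor<4xf32>
///   %value = tensor.extract %inserted[%c0] : tensor<4xf32>
/// the extract folds to %scalar.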
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

// Compare indices as OpFoldResults so that equal constants match even when
// they are defined by different constant ops.
auto isSame = [](Value a, Value b) {
return getAsOpFoldResult(a) == getAsOpFoldResult(b);
};
if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
return insertOp.getScalar();

return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (Attribute tensor = adaptor.getTensor()) {
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return elementsAttr.getValues<Attribute>()[indices];
}

if (Value result = foldExtractAfterInsert(*this))
return result;

return {};
}

@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//

namespace {

/// Pattern to fold a tensor.insert of a constant scalar into a constant
/// destination tensor, producing an updated constant.
///
/// Example:
/// ```
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// %c0 = arith.constant 0 : index
/// %c4_f32 = arith.constant 4.0 : f32
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
/// ```
/// is rewritten into:
/// ```
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// ```
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern<InsertOp>::OpRewritePattern;

LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
// Requires a ranked tensor type.
auto destType =
llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
if (!destType)
return failure();

// Pattern requires constant indices
SmallVector<uint64_t, 8> indices;
Review comment (Collaborator), suggesting:

    SmallVector<uint64_t> indices;

Nit: don't use a specific size for SmallVector without a good reason to.

for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
auto indiceAttr = dyn_cast<Attribute>(indice);
if (!indiceAttr)
return failure();
indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
}

// Requires a constant scalar to insert
OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
Attribute scalarAttr = dyn_cast<Attribute>(scalar);
if (!scalarAttr)
return failure();

if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
insertOp.getDest().getDefiningOp())) {
if (auto sourceAttr =
llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
// Update the attribute at the inserted index.
auto sourceValues = sourceAttr.getValues<Attribute>();
auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
std::vector<Attribute> updatedValues;
Review comment (Collaborator):

That is an unfortunate slow path for what will be converted ultimately to a vector of int or float.

Reply (Contributor, author):

Are you suggesting to copy the source values instead of looping? Or to switch on the int / float type and then convert back to attributes?

Reply (Collaborator):

I'm saying we should indeed be able to avoid using individual attributes per element in the common case.
We may be lacking (possibly templated?) helpers to do this conveniently.

Reply (Contributor):

Yeah, in retrospect it makes sense to type-switch on the attribute; the options are DenseIntElementsAttr and DenseFPElementsAttr.
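
For illustration, a rough sketch of that fast path (not part of this PR; it handles only i32/f32 elements, and the helper name updateConstant is hypothetical — a real version would switch over all supported element widths):

    // Hypothetical helper: rebuild a DenseElementsAttr with one element
    // replaced, staying on raw C++ values instead of boxing every element
    // as an Attribute.
    static Attribute updateConstant(DenseElementsAttr src, uint64_t flatIndex,
                                    Attribute scalar) {
      Type elemType = src.getElementType();
      if (elemType.isInteger(32)) {
        SmallVector<int32_t> values = llvm::to_vector(src.getValues<int32_t>());
        values[flatIndex] =
            static_cast<int32_t>(cast<IntegerAttr>(scalar).getInt());
        return DenseElementsAttr::get(src.getType(), ArrayRef<int32_t>(values));
      }
      if (elemType.isF32()) {
        SmallVector<float> values = llvm::to_vector(src.getValues<float>());
        values[flatIndex] = cast<FloatAttr>(scalar).getValue().convertToFloat();
        return DenseElementsAttr::get(src.getType(), ArrayRef<float>(values));
      }
      return {}; // Fall back to the generic per-Attribute path.
    }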

updatedValues.reserve(sourceAttr.getNumElements());
for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
updatedValues.push_back(i == flattenedIndex ? scalarAttr
: sourceValues[i]);
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
insertOp, sourceAttr.getType(),
DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
updatedValues));
return success();
}
}

return failure();
}
};

} // namespace

void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertOpConstantFold>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
29 changes: 26 additions & 3 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
// -----

// CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>

-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // Fold an extract after an insert.
+  // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
+  %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
+  %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
+
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
}

// -----
Expand Down Expand Up @@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}


// -----

func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
// Fold an insert of a constant scalar into a constant tensor.
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
// CHECK-NEXT: return %[[C4]]
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4_i32 = arith.constant 4 : i32
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
return %inserted : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @extract_from_tensor.cast
31 changes: 31 additions & 0 deletions mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
@@ -0,0 +1,31 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s

// CHECK-LABEL: @extract_from_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
%extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
%extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
func.return %extracted, %extracted_0 : i8, i8
}

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8

// -----

// CHECK-LABEL: @extract_from_static_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
%1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};

Option<bool> testFoldExtractFromCollapseShape{
*this, "test-fold-extract-from-collapse-shape",
llvm::cl::desc("Test folding of extract from collapse_shape"),
llvm::cl::init(false)};

Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}

static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateFoldCollapseExtractPatterns(patterns);
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
if (testFoldExtractFromCollapseShape)
applyFoldExtractFromCollapseShapePatterns(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();