[mlir][Tensor] Move concat operation decomposition as a method of the concat operation. #116004


Merged · 1 commit · Nov 13, 2024
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
int64_t getRank() {
return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
}

// Decomposes the operation into a sequence of tensor.insert_slice ops.
FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
}];

let hasCanonicalizer = 1;
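For illustration, here is roughly what the new decomposeOperation method produces for a dynamic concat along dim 1, reconstructed from the updated test file below (SSA value names are illustrative):

// Input:
%concat = tensor.concat dim(1) %arg0, %arg1
    : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

// After decomposition:
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim0 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%size = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%dim0]
%empty = tensor.empty(%size) : tensor<8x?xf32>
%slice0 = tensor.insert_slice %arg0 into %empty[0, 0] [8, 4] [1, 1]
    : tensor<8x4xf32> into tensor<8x?xf32>
%slice1 = tensor.insert_slice %arg1 into %slice0[0, 4] [%dim, %dim0] [1, 1]
    : tensor<?x?xf32> into tensor<8x?xf32>
%cast = tensor.cast %slice1 : tensor<8x?xf32> to tensor<?x?xf32>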
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
return success();
}

FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
size_t numInputs = getInputs().size();
uint64_t concatDim = getDim();

SmallVector<SmallVector<OpFoldResult>> inputShapes;
inputShapes.reserve(numInputs);
SmallVector<OpFoldResult> concatOffsets;
concatOffsets.reserve(numInputs);
SmallVector<OpFoldResult> outputShape;

AffineExpr addExpr =
builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
OpFoldResult zero = builder.getIndexAttr(0);
Location loc = getLoc();
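// Compute each input's shape, its offset along the concat dimension, and
// the folded output shape.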
for (auto [index, input] : llvm::enumerate(getInputs())) {
SmallVector<OpFoldResult> inputShape =
tensor::getMixedSizes(builder, input.getLoc(), input);
if (index == 0) {
outputShape = inputShape;
concatOffsets.push_back(zero);
} else {
concatOffsets.push_back(outputShape[concatDim]);
outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
builder, loc, addExpr,
{outputShape[concatDim], inputShape[concatDim]});
}
inputShapes.emplace_back(std::move(inputShape));
}

Value replacement = builder.create<tensor::EmptyOp>(
loc, outputShape, getType().getElementType());

int64_t rank = getType().getRank();
OpFoldResult one = builder.getIndexAttr(1);
SmallVector<OpFoldResult> strides(rank, one);
SmallVector<OpFoldResult> offsets(rank, zero);
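// Chain tensor.insert_slice ops that write each input into the result at
// its offset along the concat dimension.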
for (auto [index, input] : llvm::enumerate(getInputs())) {
offsets[concatDim] = concatOffsets[index];
auto insertSlice = builder.create<tensor::InsertSliceOp>(
loc, input, replacement, offsets, inputShapes[index], strides);
replacement = insertSlice.getResult();
}
if (replacement.getType() != getType()) {
replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
}
Comment on lines +660 to +662
Contributor:

I think the main difference is that getOrCreateDestination "infers" the static shape when possible. We can get rid of the tensor.cast ops if we use that method. Why do you like the current way better?

Contributor (Author):

Because the earlier implementation relied on that method creating a tensor.empty, which is a weird dependence on an implementation detail of that method.

Contributor:

I see, good point. Have you considered using ConcatOp::reifyResultShapes to create the outputShape for the tensor.empty op? Though that way we might create more operations and pay the cost. (I don't have a preference; I just want to make sure the idea is evaluated.)

The current implementation looks okay to me because the root issue is that the op does not infer static shapes when possible. We'll end up with these tensor.cast ops even if the shape inference is implemented.

Contributor (Author):

(Sorry, my phone was misbehaving and I was AFK.) That's exactly right: reifyResultShapes would create a lot of operations. We need proper cast propagation to resolve the static information, and that is outside the scope of this decomposition.
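To make the discussed trade-off concrete, here is a sketch of the tail of the decomposition for the dynamic-into-static case from the updated test below (names illustrative): the sum of the dynamic concat-dim sizes yields a dynamic tensor.empty, so a trailing cast reconciles it with the op's more static result type:

%sum = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%t0_dim2, %t1_dim2]
%empty = tensor.empty(%t0_dim1, %sum) : tensor<1x?x?xf32>
// ... chained tensor.insert_slice ops into %empty ...
%cast = tensor.cast %concat : tensor<1x?x?xf32> to tensor<1x?x128xf32>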

return SmallVector<Value>{replacement};
}

LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
53 changes: 6 additions & 47 deletions mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {

LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
Location loc = concatOp.getLoc();
FailureOr<Value> dest =
tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
if (failed(dest))
return failure();

auto empty = dest->getDefiningOp<tensor::EmptyOp>();
if (!empty)
return failure();

int64_t dim = concatOp.getDim();
Value dimValue =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));

int64_t rank = concatOp.getResultType().getRank();
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));

// Compute the partial sums for the slice offsets.
AffineExpr sum = rewriter.getAffineDimExpr(0);
SmallVector<AffineExpr> partialSums = {sum};
SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
for (auto [idx, input] :
llvm::enumerate(concatOp.getInputs().drop_back())) {
sum = sum + rewriter.getAffineDimExpr(idx + 1);
partialSums.push_back(sum);
offsetStrides.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
FailureOr<SmallVector<Value>> decomposed =
concatOp.decomposeOperation(rewriter);
if (failed(decomposed)) {
return rewriter.notifyMatchFailure(
concatOp, "failed to get the decomposed insert slices");
}
auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
partialSums, rewriter.getContext());
SmallVector<OpFoldResult> dimOffsets =
affine::makeComposedFoldedMultiResultAffineApply(
rewriter, loc, partialSumMap, offsetStrides);

// Construct the chain of insert_slice ops into the destination.
Value result = *dest;
for (auto [input, offset] :
llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, input);
offsets[dim] = offset;
result = rewriter.createOrFold<tensor::InsertSliceOp>(
loc, input, result, offsets, sizes, strides);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
concatOp, concatOp.getResultType(), result);
rewriter.replaceOp(concatOp, decomposed.value()[0]);
return success();
}
};
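For reference, the updated test below drives this pattern through the transform interpreter. A minimal sketch of such a transform script, assuming the upstream transform op apply_patterns.tensor.decompose_concat (the actual sequence is truncated in this diff):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.decompose_concat
    } : !transform.any_op
    transform.yield
  }
}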
49 changes: 27 additions & 22 deletions mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -1,24 +1,23 @@
// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s

func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-LABEL: func @decompose_dynamic_concat(
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>

// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
// CHECK: return %[[CAST]] : tensor<?x?xf32>

func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
%arg1 : tensor<2xf32>,
@@ -42,12 +41,14 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
: (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
return %0 : tensor<1x?x128xf32>
}
// CHECK-LABEL: func @decompose_static_concat_dim
// CHECK-LABEL: func @decompose_static_concat_dim(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
// CHECK: tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
// CHECK: %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>

@@ -58,19 +59,23 @@ func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
: (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
return %0 : tensor<1x?x128xf32>
}
// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
// CHECK: tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
// CHECK: %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
// CHECK: %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
// CHECK: %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
// CHECK: %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
// CHECK: %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
// CHECK: %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
// CHECK: tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
// CHECK: %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
// CHECK: %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
// CHECK: return %[[CAST]] : tensor<1x?x128xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {