Skip to content

Commit e33c018

Browse files
[mlir][Tensor] Move concat operation decomposition as a method of the concat operation.
Currently the implementation is within a pattern that cannot be used without a pattern rewriter. Move the decomposition as a method of the operation to make it usable outside of pattern rewrites. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bf51a9e commit e33c018

File tree

4 files changed

+84
-69
lines changed

4 files changed

+84
-69
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
178178
int64_t getRank() {
179179
return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
180180
}
181+
182+
// Method to decompose the operation into a sequence of insert_slices.
183+
FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
181184
}];
182185

183186
let hasCanonicalizer = 1;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
615615
return success();
616616
}
617617

618+
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
619+
size_t numInputs = getInputs().size();
620+
uint64_t concatDim = getDim();
621+
622+
SmallVector<SmallVector<OpFoldResult>> inputShapes;
623+
inputShapes.reserve(numInputs);
624+
SmallVector<OpFoldResult> concatOffsets;
625+
concatOffsets.reserve(numInputs);
626+
SmallVector<OpFoldResult> outputShape;
627+
628+
AffineExpr addExpr =
629+
builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
630+
OpFoldResult zero = builder.getIndexAttr(0);
631+
Location loc = getLoc();
632+
for (auto [index, input] : llvm::enumerate(getInputs())) {
633+
SmallVector<OpFoldResult> inputShape =
634+
tensor::getMixedSizes(builder, input.getLoc(), input);
635+
if (index == 0) {
636+
outputShape = inputShape;
637+
concatOffsets.push_back(zero);
638+
} else {
639+
concatOffsets.push_back(outputShape[concatDim]);
640+
outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
641+
builder, loc, addExpr,
642+
{outputShape[concatDim], inputShape[concatDim]});
643+
}
644+
inputShapes.emplace_back(std::move(inputShape));
645+
}
646+
647+
Value replacement = builder.create<tensor::EmptyOp>(
648+
loc, outputShape, getType().getElementType());
649+
650+
int64_t rank = getType().getRank();
651+
OpFoldResult one = builder.getIndexAttr(1);
652+
SmallVector<OpFoldResult> strides(rank, one);
653+
SmallVector<OpFoldResult> offsets(rank, zero);
654+
for (auto [index, input] : llvm::enumerate(getInputs())) {
655+
offsets[concatDim] = concatOffsets[index];
656+
auto insertSlice = builder.create<tensor::InsertSliceOp>(
657+
loc, input, replacement, offsets, inputShapes[index], strides);
658+
replacement = insertSlice.getResult();
659+
}
660+
if (replacement.getType() != getType()) {
661+
replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
662+
}
663+
return SmallVector<Value>{replacement};
664+
}
665+
618666
LogicalResult
619667
ConcatOp::reifyResultShapes(OpBuilder &builder,
620668
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
3333

3434
LogicalResult matchAndRewrite(ConcatOp concatOp,
3535
PatternRewriter &rewriter) const override {
36-
Location loc = concatOp.getLoc();
37-
FailureOr<Value> dest =
38-
tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
39-
if (failed(dest))
40-
return failure();
41-
42-
auto empty = dest->getDefiningOp<tensor::EmptyOp>();
43-
if (!empty)
44-
return failure();
45-
46-
int64_t dim = concatOp.getDim();
47-
Value dimValue =
48-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
49-
50-
int64_t rank = concatOp.getResultType().getRank();
51-
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
52-
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
53-
54-
// Compute the partial sums for the slice offsets.
55-
AffineExpr sum = rewriter.getAffineDimExpr(0);
56-
SmallVector<AffineExpr> partialSums = {sum};
57-
SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
58-
for (auto [idx, input] :
59-
llvm::enumerate(concatOp.getInputs().drop_back())) {
60-
sum = sum + rewriter.getAffineDimExpr(idx + 1);
61-
partialSums.push_back(sum);
62-
offsetStrides.push_back(
63-
rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
36+
FailureOr<SmallVector<Value>> decomposed =
37+
concatOp.decomposeOperation(rewriter);
38+
if (failed(decomposed)) {
39+
return rewriter.notifyMatchFailure(
40+
concatOp, "failed to get the decomposed insert slices");
6441
}
65-
auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
66-
partialSums, rewriter.getContext());
67-
SmallVector<OpFoldResult> dimOffsets =
68-
affine::makeComposedFoldedMultiResultAffineApply(
69-
rewriter, loc, partialSumMap, offsetStrides);
70-
71-
// Construct the chain of insert_slice ops into the destination.
72-
Value result = *dest;
73-
for (auto [input, offset] :
74-
llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
75-
SmallVector<OpFoldResult> sizes =
76-
tensor::getMixedSizes(rewriter, loc, input);
77-
offsets[dim] = offset;
78-
result = rewriter.createOrFold<tensor::InsertSliceOp>(
79-
loc, input, result, offsets, sizes, strides);
80-
}
81-
82-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
83-
concatOp, concatOp.getResultType(), result);
42+
rewriter.replaceOp(concatOp, decomposed.value()[0]);
8443
return success();
8544
}
8645
};

mlir/test/Dialect/Tensor/decompose-concat.mlir

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1-
// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
22

33
func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
44
%0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
55
return %0 : tensor<?x?xf32>
66
}
7-
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
87
// CHECK-LABEL: func @decompose_dynamic_concat(
98
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
109
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
1110

12-
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
1311
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1412
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
15-
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
16-
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
17-
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
18-
// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
19-
// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
20-
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
21-
// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
13+
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
14+
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
15+
// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
16+
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
17+
// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
18+
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
19+
// CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
20+
// CHECK: return %[[CAST]] : tensor<?x?xf32>
2221

2322
func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
2423
%arg1 : tensor<2xf32>,
@@ -42,12 +41,14 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
4241
: (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
4342
return %0 : tensor<1x?x128xf32>
4443
}
45-
// CHECK-LABEL: func @decompose_static_concat_dim
44+
// CHECK-LABEL: func @decompose_static_concat_dim(
45+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
46+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
4647
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
47-
// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
48+
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
49+
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
4850
// CHECK: tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
4951
// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
50-
// CHECK: %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
5152
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
5253
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
5354

@@ -58,19 +59,23 @@ func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
5859
: (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
5960
return %0 : tensor<1x?x128xf32>
6061
}
61-
// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
62+
// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
63+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
64+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
6265
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
6366
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
64-
// CHECK: %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
65-
// CHECK: tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
66-
// CHECK: %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
67+
// CHECK: %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
68+
// CHECK: %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
69+
// CHECK: %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
70+
// CHECK: %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
71+
// CHECK: %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
72+
// CHECK: tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
6773
// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
68-
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
69-
// CHECK: %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
70-
// CHECK: %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
74+
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
7175
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
72-
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
73-
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
76+
// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x?xf32>
77+
// CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
78+
// CHECK: return %[[CAST]] : tensor<1x?x128xf32>
7479

7580
module attributes {transform.with_named_sequence} {
7681
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {

0 commit comments

Comments
 (0)