Skip to content

Commit 6ceb8f8

Browse files
[mlir][Tensor] NFC: Move the concat operation decomposition into a method of the concat operation.
Currently the implementation lives inside a rewrite pattern, so it cannot be used without a pattern rewriter. Move the decomposition into a method of the operation so that it is usable outside of pattern rewrites. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bf51a9e commit 6ceb8f8

File tree

3 files changed

+54
-47
lines changed

3 files changed

+54
-47
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
178178
// Returns the rank of the concatenated result tensor.
int64_t getRank() {
  auto resultType = ::llvm::cast<RankedTensorType>(getResult().getType());
  return resultType.getRank();
}
181+
182+
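// Decomposes the operation into a `tensor.empty` followed by a chain of
// `tensor.insert_slice` ops (one per input) created with `builder`. Returns
// the replacement value(s); the caller replaces this operation with them.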
183+
FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
181184
}];
182185

183186
let hasCanonicalizer = 1;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,51 @@ LogicalResult ConcatOp::verify() {
615615
return success();
616616
}
617617

618+
/// Decomposes this concat into a `tensor.empty` destination followed by one
/// `tensor.insert_slice` per input, chained along the concatenation dimension.
///
/// All new ops are created with `builder`. This op itself is left in place, so
/// the caller is responsible for replacing it with the returned value.
///
/// \returns a single-element vector holding the value that replaces this op's
/// result. The value always has exactly this op's result type: a `tensor.cast`
/// is appended when the type of the insert_slice chain differs from it.
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  // s0 + s1: used to accumulate the result size along the concat dimension.
  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      // The first input seeds the output shape and is placed at offset 0.
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      // Each later input starts where the accumulated extent currently ends;
      // the extent then grows by this input's size along the concat dimension.
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  // Destination tensor sized from the reified input shapes.
  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    // Only the offset along the concat dimension varies between inputs.
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }

  // The destination type is derived from the input sizes and may be more or
  // less static than the declared result type of this op. Cast back so the
  // returned value can be used as a drop-in replacement for the result.
  if (replacement.getType() != getType())
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  return SmallVector<Value>{replacement};
}
662+
618663
LogicalResult
619664
ConcatOp::reifyResultShapes(OpBuilder &builder,
620665
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
3333

3434
LogicalResult matchAndRewrite(ConcatOp concatOp,
3535
PatternRewriter &rewriter) const override {
36-
Location loc = concatOp.getLoc();
37-
FailureOr<Value> dest =
38-
tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
39-
if (failed(dest))
40-
return failure();
41-
42-
auto empty = dest->getDefiningOp<tensor::EmptyOp>();
43-
if (!empty)
44-
return failure();
45-
46-
int64_t dim = concatOp.getDim();
47-
Value dimValue =
48-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
49-
50-
int64_t rank = concatOp.getResultType().getRank();
51-
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
52-
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
53-
54-
// Compute the partial sums for the slice offsets.
55-
AffineExpr sum = rewriter.getAffineDimExpr(0);
56-
SmallVector<AffineExpr> partialSums = {sum};
57-
SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
58-
for (auto [idx, input] :
59-
llvm::enumerate(concatOp.getInputs().drop_back())) {
60-
sum = sum + rewriter.getAffineDimExpr(idx + 1);
61-
partialSums.push_back(sum);
62-
offsetStrides.push_back(
63-
rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
36+
FailureOr<SmallVector<Value>> decomposed =
37+
concatOp.decomposeOperation(rewriter);
38+
if (failed(decomposed)) {
39+
return rewriter.notifyMatchFailure(
40+
concatOp, "failed to get the decomposed insert slices");
6441
}
65-
auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
66-
partialSums, rewriter.getContext());
67-
SmallVector<OpFoldResult> dimOffsets =
68-
affine::makeComposedFoldedMultiResultAffineApply(
69-
rewriter, loc, partialSumMap, offsetStrides);
70-
71-
// Construct the chain of insert_slice ops into the destination.
72-
Value result = *dest;
73-
for (auto [input, offset] :
74-
llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
75-
SmallVector<OpFoldResult> sizes =
76-
tensor::getMixedSizes(rewriter, loc, input);
77-
offsets[dim] = offset;
78-
result = rewriter.createOrFold<tensor::InsertSliceOp>(
79-
loc, input, result, offsets, sizes, strides);
80-
}
81-
82-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
83-
concatOp, concatOp.getResultType(), result);
42+
rewriter.replaceOp(concatOp, decomposed.value()[0]);
8443
return success();
8544
}
8645
};

0 commit comments

Comments
 (0)