@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
33
33
34
34
/// Pattern entry point for rewriting a `ConcatOp`.
///
/// After this change the body asks the op itself for its decomposition via
/// `concatOp.decomposeOperation(rewriter)`. If that call reports failure the
/// pattern reports a match failure with a diagnostic message; otherwise
/// `concatOp` is replaced with the first value of the returned vector and
/// the pattern returns success.
///
/// \param concatOp the concat operation being matched.
/// \param rewriter the rewriter used to create the replacement and to report
///        match failures.
/// \returns `success()` when the op was replaced, a failed result otherwise.
LogicalResult matchAndRewrite (ConcatOp concatOp,
35
35
PatternRewriter &rewriter) const override {
36
// Lines prefixed with `-` below are the previous inline lowering that this
// change removes: it obtained a destination tensor (bailing out unless it
// was defined by a tensor::EmptyOp), computed per-input offsets along the
// concat dimension as partial sums of the input sizes, chained
// tensor::InsertSliceOp ops into the destination, and finally replaced the
// concat with a tensor::CastOp to the concat result type.
// NOTE(review): presumably the equivalent logic now lives in
// ConcatOp::decomposeOperation -- confirm against that implementation.
- Location loc = concatOp.getLoc ();
37
- FailureOr<Value> dest =
38
- tensor::getOrCreateDestination (rewriter, loc, concatOp->getResult (0 ));
39
- if (failed (dest))
40
- return failure ();
41
-
42
- auto empty = dest->getDefiningOp <tensor::EmptyOp>();
43
- if (!empty)
44
- return failure ();
45
-
46
- int64_t dim = concatOp.getDim ();
47
- Value dimValue =
48
- rewriter.create <arith::ConstantOp>(loc, rewriter.getIndexAttr (dim));
49
-
50
- int64_t rank = concatOp.getResultType ().getRank ();
51
- SmallVector<OpFoldResult> strides (rank, rewriter.getIndexAttr (1 ));
52
- SmallVector<OpFoldResult> offsets (rank, rewriter.getIndexAttr (0 ));
53
-
54
- // Compute the partial sums for the slice offsets.
55
- AffineExpr sum = rewriter.getAffineDimExpr (0 );
56
- SmallVector<AffineExpr> partialSums = {sum};
57
- SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr (0 )};
58
- for (auto [idx, input] :
59
- llvm::enumerate (concatOp.getInputs ().drop_back ())) {
60
- sum = sum + rewriter.getAffineDimExpr (idx + 1 );
61
- partialSums.push_back (sum);
62
- offsetStrides.push_back (
63
- rewriter.createOrFold <tensor::DimOp>(loc, input, dimValue));
36
// Added code: delegate the decomposition to the op and propagate failure as
// a match failure carrying a human-readable reason.
+ FailureOr<SmallVector<Value>> decomposed =
37
+ concatOp.decomposeOperation (rewriter);
38
+ if (failed (decomposed)) {
39
+ return rewriter.notifyMatchFailure (
40
+ concatOp, " failed to get the decomposed insert slices" );
64
41
}
65
- auto partialSumMap = AffineMap::get (concatOp.getInputs ().size (), 0 ,
66
- partialSums, rewriter.getContext ());
67
- SmallVector<OpFoldResult> dimOffsets =
68
- affine::makeComposedFoldedMultiResultAffineApply (
69
- rewriter, loc, partialSumMap, offsetStrides);
70
-
71
- // Construct the chain of insert_slice ops into the destination.
72
- Value result = *dest;
73
- for (auto [input, offset] :
74
- llvm::zip_equal (concatOp.getInputs (), dimOffsets)) {
75
- SmallVector<OpFoldResult> sizes =
76
- tensor::getMixedSizes (rewriter, loc, input);
77
- offsets[dim] = offset;
78
- result = rewriter.createOrFold <tensor::InsertSliceOp>(
79
- loc, input, result, offsets, sizes, strides);
80
- }
81
-
82
- rewriter.replaceOpWithNewOp <tensor::CastOp>(
83
- concatOp, concatOp.getResultType (), result);
42
// Only element 0 of the decomposed values is used as the replacement.
// NOTE(review): this indexes the vector without an emptiness check and no
// longer inserts an explicit tensor::CastOp -- assumes decomposeOperation
// returns at least one value of the concat's result type on success; confirm.
+ rewriter.replaceOp (concatOp, decomposed.value ()[0 ]);
84
43
return success ();
85
44
}
86
45
};
0 commit comments