@@ -93,15 +93,11 @@ struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
93
93
}
94
94
SmallVector<int64_t > multiplies (concatType.getRank (), 1 );
95
95
multiplies[concatOp.getAxis ()] = concatOp->getNumOperands ();
96
- const int64_t rank = multiplies.size ();
97
- auto constantShapeOp = rewriter.create <ConstShapeOp>(
98
- concatOp->getLoc (), shapeType::get (concatOp->getContext (), rank),
99
- DenseIntElementsAttr::get (
100
- RankedTensorType::get ({rank}, rewriter.getIndexType ()),
101
- multiplies));
96
+ auto constantShapeValue =
97
+ getTosaConstShape (rewriter, concatOp->getLoc (), multiplies);
102
98
auto tileOp = rewriter.createOrFold <tosa::TileOp>(
103
99
concatOp->getLoc (), concatOp.getType (), concatOp->getOperand (0 ),
104
- constantShapeOp );
100
+ constantShapeValue );
105
101
rewriter.replaceOp (concatOp, {tileOp});
106
102
return success ();
107
103
}
@@ -140,19 +136,15 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
140
136
for (auto [idx, multiplier] : llvm::enumerate (inputTileMultiples)) {
141
137
multiplies[idx] *= multiplier;
142
138
}
143
- auto constantShapeOp = rewriter.create <ConstShapeOp>(
139
+ auto constantShapeValue = getTosaConstShape (
140
+ rewriter,
144
141
rewriter.getFusedLoc (
145
142
{op.getMultiples ().getLoc (), inputTile.getMultiples ().getLoc ()}),
146
- op.getMultiples ().getType (),
147
- DenseIntElementsAttr::get (
148
- RankedTensorType::get (
149
- {cast<shapeType>(op.getMultiples ().getType ()).getRank ()},
150
- rewriter.getIndexType ()),
151
- multiplies));
143
+ multiplies);
152
144
153
145
rewriter.modifyOpInPlace (op, [&]() {
154
146
op.setOperand (0 , inputTile->getOperand (0 ));
155
- op.setOperand (1 , constantShapeOp );
147
+ op.setOperand (1 , constantShapeValue );
156
148
op.getOperation ()->setLoc (
157
149
FusedLoc::get (getContext (), {inputTile->getLoc (), op.getLoc ()}));
158
150
});
@@ -828,16 +820,11 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
828
820
llvm::zip_equal (newTileShape, requiredMultipliers)) {
829
821
newShape *= multiplier;
830
822
}
831
- auto constantShapeOp = rewriter.create <ConstShapeOp>(
832
- tileOp.getMultiples ().getLoc (), tileOp.getMultiples ().getType (),
833
- DenseIntElementsAttr::get (
834
- RankedTensorType::get (
835
- {cast<shapeType>(tileOp.getMultiples ().getType ()).getRank ()},
836
- rewriter.getIndexType ()),
837
- requiredMultipliers));
823
+ auto constantShapeValue = getTosaConstShape (
824
+ rewriter, tileOp.getMultiples ().getLoc (), requiredMultipliers);
838
825
auto newTile = rewriter.create <tosa::TileOp>(
839
826
tileOp->getLoc (), tileOpInputType.clone (newTileShape),
840
- tileOp->getOperand (0 ), constantShapeOp );
827
+ tileOp->getOperand (0 ), constantShapeValue );
841
828
auto newSlice = rewriter.create <tosa::SliceOp>(
842
829
sliceOp->getLoc (), sliceOp.getType (), newTile,
843
830
rewriter.getDenseI64ArrayAttr (newTileStarts), sliceOp.getSizeAttr ());
0 commit comments