@@ -573,6 +573,38 @@ static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
573
573
return recip;
574
574
}
575
575
576
+ // / Fold reshapes. This is similar to ReshapeOp::fold, but also allows
577
+ // / to fold with multiple users.
578
+ struct TosaFoldConstantReshape
579
+ : public TosaFoldConstantUnaryElementwise<TosaFoldConstantReshape,
580
+ ReshapeOp> {
581
+ using TosaFoldConstantUnaryElementwise::TosaFoldConstantUnaryElementwise;
582
+
583
+ LogicalResult matchAndRewrite (ReshapeOp op,
584
+ PatternRewriter &rewriter) const override {
585
+ auto inputTensor = op.getOperand ();
586
+ // Check that we can apply folding
587
+ auto preCondCheck =
588
+ notifyIfNoTosaDenseConstantTensor (inputTensor, op, rewriter);
589
+ if (failed (preCondCheck)) {
590
+ return preCondCheck;
591
+ }
592
+
593
+ // Extract the tensor values
594
+ DenseElementsAttr inputValues;
595
+ matchPattern (inputTensor, m_Constant (&inputValues));
596
+
597
+ // Check whether this should be folded.
598
+ if (!constantUnaryOpShouldBeFolded (op, inputValues)) {
599
+ return rewriter.notifyMatchFailure (
600
+ op, " expected reshape op to have a single user" );
601
+ }
602
+ DenseElementsAttr newTensor = inputValues.reshape (op.getType ());
603
+ rewriter.replaceOpWithNewOp <ConstOp>(op, newTensor.getType (), newTensor);
604
+ return success ();
605
+ }
606
+ };
607
+
576
608
struct TosaFoldConstantReciprocal
577
609
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal, ReciprocalOp> {
578
610
using TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal,
@@ -1723,6 +1755,8 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
1723
1755
MLIRContext *ctx, RewritePatternSet &patterns,
1724
1756
bool foldSplatOrSingleUseOnly,
1725
1757
bool enableIntCastFolding) {
1758
+
1759
+ patterns.add <TosaFoldConstantReshape>(ctx, foldSplatOrSingleUseOnly);
1726
1760
patterns.add <TosaFoldConstantTranspose>(ctx, foldSplatOrSingleUseOnly);
1727
1761
patterns.add <TosaFoldConstantReciprocal>(ctx, foldSplatOrSingleUseOnly);
1728
1762
patterns.add <TosaFoldConstantRSQRT>(ctx, foldSplatOrSingleUseOnly);
0 commit comments