@@ -573,6 +573,37 @@ static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
573
573
return recip;
574
574
}
575
575
576
+ // / Fold reshapes. This is similar to ReshapeOp::fold, but also allows
577
+ // / to fold with multiple users.
578
+ struct TosaFoldConstantReshape
579
+ : public TosaFoldConstantUnaryElementwise<TosaFoldConstantReshape,
580
+ ReshapeOp> {
581
+ using TosaFoldConstantUnaryElementwise::TosaFoldConstantUnaryElementwise;
582
+
583
+ LogicalResult matchAndRewrite (ReshapeOp op,
584
+ PatternRewriter &rewriter) const override {
585
+ auto inputTensor = op.getOperand ();
586
+ // Check that we can apply folding
587
+ auto preCondCheck =
588
+ notifyIfNoTosaDenseConstantTensor (inputTensor, op, rewriter);
589
+ if (failed (preCondCheck))
590
+ return preCondCheck;
591
+
592
+ // Extract the tensor values
593
+ DenseElementsAttr inputValues;
594
+ matchPattern (inputTensor, m_Constant (&inputValues));
595
+
596
+ // Check whether this should be folded.
597
+ if (!constantUnaryOpShouldBeFolded (op, inputValues)) {
598
+ return rewriter.notifyMatchFailure (
599
+ op, " expected reshape op to have a single user" );
600
+ }
601
+ DenseElementsAttr newTensor = inputValues.reshape (op.getType ());
602
+ rewriter.replaceOpWithNewOp <ConstOp>(op, newTensor.getType (), newTensor);
603
+ return success ();
604
+ }
605
+ };
606
+
576
607
struct TosaFoldConstantReciprocal
577
608
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal, ReciprocalOp> {
578
609
using TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal,
@@ -1723,8 +1754,10 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
1723
1754
MLIRContext *ctx, RewritePatternSet &patterns,
1724
1755
bool foldSplatOrSingleUseOnly,
1725
1756
bool enableIntCastFolding) {
1757
+
1726
1758
patterns.add <TosaFoldConstantTranspose>(ctx, foldSplatOrSingleUseOnly);
1727
1759
patterns.add <TosaFoldConstantReciprocal>(ctx, foldSplatOrSingleUseOnly);
1760
+ patterns.add <TosaFoldConstantReshape>(ctx, foldSplatOrSingleUseOnly);
1728
1761
patterns.add <TosaFoldConstantRSQRT>(ctx, foldSplatOrSingleUseOnly);
1729
1762
patterns.add <TosaFoldConstantLogicalNot>(ctx, foldSplatOrSingleUseOnly);
1730
1763
patterns.add <TosaFoldConstantPow>(ctx, foldSplatOrSingleUseOnly);
0 commit comments