@@ -755,7 +755,8 @@ LogicalResult DeallocOp::inferReturnTypes(
755
755
ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
756
756
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
757
757
DeallocOpAdaptor adaptor (operands, attributes, properties, regions);
758
- inferredReturnTypes = SmallVector<Type>(adaptor.getConditions ().getTypes ());
758
+ inferredReturnTypes = SmallVector<Type>(adaptor.getRetained ().size (),
759
+ IntegerType::get (context, 1 ));
759
760
return success ();
760
761
}
761
762
@@ -766,44 +767,46 @@ LogicalResult DeallocOp::verify() {
766
767
return success ();
767
768
}
768
769
770
+ static LogicalResult updateDeallocIfChanged (DeallocOp deallocOp,
771
+ ArrayRef<Value> memrefs,
772
+ ArrayRef<Value> conditions,
773
+ PatternRewriter &rewriter) {
774
+ if (deallocOp.getMemrefs () == memrefs)
775
+ return failure ();
776
+
777
+ rewriter.updateRootInPlace (deallocOp, [&]() {
778
+ deallocOp.getMemrefsMutable ().assign (memrefs);
779
+ deallocOp.getConditionsMutable ().assign (conditions);
780
+ });
781
+ return success ();
782
+ }
783
+
769
784
namespace {
770
785
771
- // / Remove duplicate values in the list of retained memrefs as well as the list
772
- // / of memrefs to be deallocated. For the latter, we need to make sure the
773
- // / corresponding condition values match as well, or otherwise have to combine
774
- // / them (by computing the disjunction of them).
786
+ // / Remove duplicate values in the list of memrefs to be deallocated. We need to
787
+ // / make sure the corresponding condition value is updated accordingly since
788
+ // / the conditions of the duplicates might not cover the same set of cases. In
789
+ // / that case, we have to combine them (by computing the disjunction of them).
775
790
// / Example:
776
791
// / ```mlir
777
- // / %0:2 = bufferization.dealloc (%arg0, %arg0 : ...)
778
- // / if (%arg1, %arg2)
779
- // / retain (%arg3, %arg3 : ...)
792
+ // / bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
780
793
// / ```
781
794
// / is canonicalized to
782
795
// / ```mlir
783
796
// / %0 = arith.ori %arg1, %arg2 : i1
784
- // / %1 = bufferization.dealloc (%arg0 : memref<2xi32>)
785
- // / if (%0)
786
- // / retain (%arg3 : memref<2xi32>)
797
+ // / bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
787
798
// / ```
788
- struct DeallocRemoveDuplicates : public OpRewritePattern <DeallocOp> {
799
+ struct DeallocRemoveDuplicateDeallocMemrefs
800
+ : public OpRewritePattern<DeallocOp> {
789
801
using OpRewritePattern<DeallocOp>::OpRewritePattern;
790
802
791
803
LogicalResult matchAndRewrite (DeallocOp deallocOp,
792
804
PatternRewriter &rewriter) const override {
793
805
// Unique memrefs to be deallocated.
794
- DenseSet<Value> retained (deallocOp.getRetained ().begin (),
795
- deallocOp.getRetained ().end ());
796
806
DenseMap<Value, unsigned > memrefToCondition;
797
- SmallVector<Value> newMemrefs, newConditions, newRetained;
798
- SmallVector<int32_t > resultIndices (deallocOp.getMemrefs ().size (), -1 );
807
+ SmallVector<Value> newMemrefs, newConditions;
799
808
for (auto [i, memref, cond] :
800
809
llvm::enumerate (deallocOp.getMemrefs (), deallocOp.getConditions ())) {
801
- if (retained.contains (memref)) {
802
- rewriter.replaceAllUsesWith (deallocOp.getResult (i),
803
- deallocOp.getConditions ()[i]);
804
- continue ;
805
- }
806
-
807
810
if (memrefToCondition.count (memref)) {
808
811
// If the dealloc conditions don't match, we need to make sure that the
809
812
// dealloc happens on the union of cases.
@@ -816,50 +819,133 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
816
819
newMemrefs.push_back (memref);
817
820
newConditions.push_back (cond);
818
821
}
819
- resultIndices[i] = memrefToCondition[memref];
820
822
}
821
823
824
+ // Return failure if we don't change anything such that we don't run into an
825
+ // infinite loop of pattern applications.
826
+ return updateDeallocIfChanged (deallocOp, newMemrefs, newConditions,
827
+ rewriter);
828
+ }
829
+ };
830
+
831
+ // / Remove duplicate values in the list of retained memrefs. We need to make
832
+ // / sure the corresponding result condition value is replaced properly.
833
+ // / Example:
834
+ // / ```mlir
835
+ // / %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
836
+ // / ```
837
+ // / is canonicalized to
838
+ // / ```mlir
839
+ // / %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
840
+ // / ```
841
+ struct DeallocRemoveDuplicateRetainedMemrefs
842
+ : public OpRewritePattern<DeallocOp> {
843
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
844
+
845
+ LogicalResult matchAndRewrite (DeallocOp deallocOp,
846
+ PatternRewriter &rewriter) const override {
822
847
// Unique retained values
823
- DenseSet<Value> seen;
848
+ DenseMap<Value, unsigned > seen;
849
+ SmallVector<Value> newRetained;
850
+ SmallVector<unsigned > resultReplacementIdx;
851
+ unsigned i = 0 ;
824
852
for (auto retained : deallocOp.getRetained ()) {
825
- if (! seen.contains (retained)) {
826
- seen. insert ( retained);
827
- newRetained. push_back (retained) ;
853
+ if (seen.count (retained)) {
854
+ resultReplacementIdx. push_back (seen[ retained] );
855
+ continue ;
828
856
}
857
+
858
+ seen[retained] = i;
859
+ newRetained.push_back (retained);
860
+ resultReplacementIdx.push_back (i++);
829
861
}
830
862
831
863
// Return failure if we don't change anything such that we don't run into an
832
864
// infinite loop of pattern applications.
833
- if (newConditions.size () == deallocOp.getConditions ().size () &&
834
- newRetained.size () == deallocOp.getRetained ().size ())
865
+ if (newRetained.size () == deallocOp.getRetained ().size ())
835
866
return failure ();
836
867
837
868
// We need to create a new op because the number of results is always the
838
869
// same as the number of retained operands.
839
- auto newDealloc = rewriter. create <DeallocOp>(deallocOp. getLoc (), newMemrefs,
840
- newConditions, newRetained);
841
- for ( auto [i, newIdx] : llvm::enumerate (resultIndices))
842
- if (newIdx != - 1 )
843
- rewriter. replaceAllUsesWith (deallocOp. getResult (i),
844
- newDealloc. getResult (newIdx)) ;
845
-
846
- rewriter.eraseOp (deallocOp);
870
+ auto newDeallocOp =
871
+ rewriter. create <DeallocOp>(deallocOp. getLoc (), deallocOp. getMemrefs (),
872
+ deallocOp. getConditions (), newRetained);
873
+ SmallVector<Value> replacements (
874
+ llvm::map_range (resultReplacementIdx, [&]( unsigned idx) {
875
+ return newDeallocOp. getUpdatedConditions ()[idx] ;
876
+ }));
877
+ rewriter.replaceOp (deallocOp, replacements );
847
878
return success ();
848
879
}
849
880
};
850
881
882
+ // / Remove memrefs to be deallocated that are also present in the retained list
883
+ // / since they will always alias and thus never actually be deallocated. The
+ // / dealloc condition is OR-ed into the uses of the matching retained result.
884
+ // / Example:
885
+ // / ```mlir
886
+ // / %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
887
+ // / ```
888
+ // / is canonicalized to
889
+ // / ```mlir
890
+ // / %0 = bufferization.dealloc retain (%arg0 : ...)
891
+ // / ```
892
+ struct DeallocRemoveDeallocMemrefsContainedInRetained
893
+ : public OpRewritePattern<DeallocOp> {
894
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
895
+
896
+ LogicalResult matchAndRewrite (DeallocOp deallocOp,
897
+ PatternRewriter &rewriter) const override {
898
+ // Map each retained memref to its index in the retain list.
899
+ DenseMap<Value, unsigned > retained;
900
+ for (auto [i, ret] : llvm::enumerate (deallocOp.getRetained ()))
901
+ retained[ret] = i;
902
+
903
+ // There must not be any duplicates in the retain list anymore because we
904
+ // would miss updating one of the result values otherwise.
905
+ if (retained.size () != deallocOp.getRetained ().size ())
906
+ return failure ();
907
+
908
+ SmallVector<Value> newMemrefs, newConditions;
909
+ for (auto [memref, cond] :
910
+ llvm::zip (deallocOp.getMemrefs (), deallocOp.getConditions ())) {
911
+ if (retained.contains (memref)) {
912
+ rewriter.setInsertionPointAfter (deallocOp);
913
+ auto orOp = rewriter.create <arith::OrIOp>(
914
+ deallocOp.getLoc (),
915
+ deallocOp.getUpdatedConditions ()[retained[memref]], cond);
916
+ rewriter.replaceAllUsesExcept (
917
+ deallocOp.getUpdatedConditions ()[retained[memref]],
918
+ orOp.getResult (), orOp);
919
+ continue ;
920
+ }
921
+
922
+ newMemrefs.push_back (memref);
923
+ newConditions.push_back (cond);
924
+ }
925
+
926
+ // Return failure if we don't change anything such that we don't run into an
927
+ // infinite loop of pattern applications.
928
+ return updateDeallocIfChanged (deallocOp, newMemrefs, newConditions,
929
+ rewriter);
930
+ }
931
+ };
932
+
851
933
// / Erase deallocation operations where the variadic list of memrefs to
852
- // / deallocate is emtpy . Example:
934
+ // / deallocate is empty . Example:
853
935
// / ```mlir
854
- // / bufferization.dealloc retain (%arg0: memref<2xi32>)
936
+ // / %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
855
937
// / ```
856
938
struct EraseEmptyDealloc : public OpRewritePattern <DeallocOp> {
857
939
using OpRewritePattern<DeallocOp>::OpRewritePattern;
858
940
859
941
LogicalResult matchAndRewrite (DeallocOp deallocOp,
860
942
PatternRewriter &rewriter) const override {
861
943
if (deallocOp.getMemrefs ().empty ()) {
862
- rewriter.eraseOp (deallocOp);
944
+ Value constFalse = rewriter.create <arith::ConstantOp>(
945
+ deallocOp.getLoc (), rewriter.getBoolAttr (false ));
946
+ rewriter.replaceOp (
947
+ deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions ().size (),
948
+ constFalse));
863
949
return success ();
864
950
}
865
951
return failure ();
@@ -871,55 +957,40 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
871
957
// /
872
958
// / Example:
873
959
// / ```
874
- // / %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
960
+ // / bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
875
961
// / if (%arg2, %false)
876
962
// / ```
877
963
// / becomes
878
964
// / ```
879
- // / %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
965
+ // / bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
880
966
// / ```
881
967
struct EraseAlwaysFalseDealloc : public OpRewritePattern <DeallocOp> {
882
968
using OpRewritePattern<DeallocOp>::OpRewritePattern;
883
969
884
970
LogicalResult matchAndRewrite (DeallocOp deallocOp,
885
971
PatternRewriter &rewriter) const override {
886
972
SmallVector<Value> newMemrefs, newConditions;
887
- SmallVector<Value> replacements;
888
-
889
- for (auto [res, memref, cond] :
890
- llvm::zip (deallocOp.getUpdatedConditions (), deallocOp.getMemrefs (),
891
- deallocOp.getConditions ())) {
892
- if (matchPattern (cond, m_Zero ())) {
893
- replacements.push_back (cond);
894
- continue ;
973
+ for (auto [memref, cond] :
974
+ llvm::zip (deallocOp.getMemrefs (), deallocOp.getConditions ())) {
975
+ if (!matchPattern (cond, m_Zero ())) {
976
+ newMemrefs.push_back (memref);
977
+ newConditions.push_back (cond);
895
978
}
896
- newMemrefs.push_back (memref);
897
- newConditions.push_back (cond);
898
- replacements.push_back ({});
899
979
}
900
980
901
- if (newMemrefs.size () == deallocOp.getMemrefs ().size ())
902
- return failure ();
903
-
904
- auto newDeallocOp = rewriter.create <DeallocOp>(
905
- deallocOp.getLoc (), newMemrefs, newConditions, deallocOp.getRetained ());
906
- unsigned i = 0 ;
907
- for (auto &repl : replacements)
908
- if (!repl)
909
- repl = newDeallocOp.getResult (i++);
910
-
911
- rewriter.replaceOp (deallocOp, replacements);
912
- return success ();
981
+ return updateDeallocIfChanged (deallocOp, newMemrefs, newConditions,
982
+ rewriter);
913
983
}
914
984
};
915
985
916
986
} // anonymous namespace
917
987
918
988
void DeallocOp::getCanonicalizationPatterns (RewritePatternSet &results,
919
989
MLIRContext *context) {
920
- results
921
- .add <DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
922
- context);
990
+ results.add <DeallocRemoveDuplicateDeallocMemrefs,
991
+ DeallocRemoveDuplicateRetainedMemrefs,
992
+ DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
993
+ EraseAlwaysFalseDealloc>(context);
923
994
}
924
995
925
996
// ===----------------------------------------------------------------------===//
0 commit comments