Skip to content

Commit 4bde084

Browse files
committed
[mlir][bufferization] Change semantics of DeallocOp result values
This change allows supporting operations for which we don't get precise aliasing information without the need to insert clone operations. E.g., `arith.select`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D156992
1 parent 58066ed commit 4bde084

File tree

7 files changed

+533
-277
lines changed

7 files changed

+533
-277
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -485,22 +485,41 @@ def Bufferization_DeallocOp : Bufferization_Op<"dealloc", [
485485
deallocating that memref). If two memrefs alias each other, only one will be
486486
deallocated to avoid double free situations.
487487

488-
The memrefs to be deallocated must be the originally allocated memrefs,
489-
however, the memrefs to be retained may be arbitrary memrefs.
490-
491-
Returns a list of conditions corresponding to the list of memrefs which
492-
indicates the new ownerships, i.e., if the memref was deallocated the
493-
ownership was dropped (set to 'false') and otherwise will be the same as the
494-
input condition.
488+
The number of variadic `memref` operands (the memrefs to be deallocated)
489+
must equal the number of variadic `condition` operands and correspond to
490+
each other element-wise.
491+
492+
The `memref` operands must be the originally allocated memrefs, however, the
493+
`retained` memref operands may be arbitrary memrefs.
494+
495+
This operation returns a variadic number of `updatedConditions` operands,
496+
one updated condition per retained memref. An updated condition indicates
497+
the ownership of the respective retained memref. It is computed as the
498+
disjunction of all `conditions` operands where the corresponding to
499+
`memrefs` operand aliases with the retained memref. If the retained memref
500+
has no aliases among `memrefs`, the resulting updated condition is 'false'.
501+
This is because all memrefs that need to be deallocated within one basic
502+
block should be added to the same `bufferization.dealloc` operation at the
503+
end of the block; if no aliasing memref is present, then it does not have to
504+
be deallocated and thus we don't need to claim ownership. If the memrefs to
505+
be deallocated are split over multiple dealloc operations (e.g., to avoid
506+
aliasing checks at runtime between the `memref` operands), then the results
507+
have to be manually combined using an `arith.ori` operation and all of them
508+
still require the same list of `retained` memref operands unless the
509+
(potentially empty) set of aliasing memrefs can be determined statically. In
510+
that case, the `updatedCondition` operand can be replaced accordingly (e.g.,
511+
by a canonicalizer).
495512

496513
Example:
497514
```mlir
498-
%0:2 = bufferization.dealloc %a0, %a1 if %cond0, %cond1 retain %r0, %r1 :
499-
memref<2xf32>, memref<4xi32> retain memref<?xf32>, memref<f64>
515+
%0:3 = bufferization.dealloc (%a0, %a1 : memref<2xf32>, memref<4xi32>)
516+
if (%cond0, %cond1) retain (%r0, %r1, %r2 : memref<?xf32>, memref<f64>,
517+
memref<2xi32>)
500518
```
501-
Deallocation will be called on `%a0` if `%cond0` is 'true' and neither `%r0`
502-
or `%r1` are aliases of `%a0`. `%a1` will be deallocated when `%cond1` is
503-
set to 'true' and none of `%r0`, %r1` and `%a0` are aliases.
519+
Deallocation will be called on `%a0` if `%cond0` is 'true' and neither
520+
`%r0`, `%r1`, or `%r2` are aliases of `%a0`. `%a1` will be deallocated when
521+
`%cond1` is set to 'true' and none of `%r0`, %r1`, `%r2`, and `%a0` are
522+
aliases.
504523
}];
505524

506525
let arguments = (ins Variadic<AnyRankedOrUnrankedMemRef>:$memrefs,

mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Lines changed: 269 additions & 139 deletions
Large diffs are not rendered by default.

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 139 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,8 @@ LogicalResult DeallocOp::inferReturnTypes(
755755
ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
756756
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
757757
DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
758-
inferredReturnTypes = SmallVector<Type>(adaptor.getConditions().getTypes());
758+
inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
759+
IntegerType::get(context, 1));
759760
return success();
760761
}
761762

@@ -766,44 +767,46 @@ LogicalResult DeallocOp::verify() {
766767
return success();
767768
}
768769

770+
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
771+
ArrayRef<Value> memrefs,
772+
ArrayRef<Value> conditions,
773+
PatternRewriter &rewriter) {
774+
if (deallocOp.getMemrefs() == memrefs)
775+
return failure();
776+
777+
rewriter.updateRootInPlace(deallocOp, [&]() {
778+
deallocOp.getMemrefsMutable().assign(memrefs);
779+
deallocOp.getConditionsMutable().assign(conditions);
780+
});
781+
return success();
782+
}
783+
769784
namespace {
770785

771-
/// Remove duplicate values in the list of retained memrefs as well as the list
772-
/// of memrefs to be deallocated. For the latter, we need to make sure the
773-
/// corresponding condition values match as well, or otherwise have to combine
774-
/// them (by computing the disjunction of them).
786+
/// Remove duplicate values in the list of memrefs to be deallocated. We need to
787+
/// make sure the corresponding condition value is updated accordingly since
788+
/// their two conditions might not cover the same set of cases. In that case, we
789+
/// have to combine them (by computing the disjunction of them).
775790
/// Example:
776791
/// ```mlir
777-
/// %0:2 = bufferization.dealloc (%arg0, %arg0 : ...)
778-
/// if (%arg1, %arg2)
779-
/// retain (%arg3, %arg3 : ...)
792+
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
780793
/// ```
781794
/// is canonicalized to
782795
/// ```mlir
783796
/// %0 = arith.ori %arg1, %arg2 : i1
784-
/// %1 = bufferization.dealloc (%arg0 : memref<2xi32>)
785-
/// if (%0)
786-
/// retain (%arg3 : memref<2xi32>)
797+
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
787798
/// ```
788-
struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
799+
struct DeallocRemoveDuplicateDeallocMemrefs
800+
: public OpRewritePattern<DeallocOp> {
789801
using OpRewritePattern<DeallocOp>::OpRewritePattern;
790802

791803
LogicalResult matchAndRewrite(DeallocOp deallocOp,
792804
PatternRewriter &rewriter) const override {
793805
// Unique memrefs to be deallocated.
794-
DenseSet<Value> retained(deallocOp.getRetained().begin(),
795-
deallocOp.getRetained().end());
796806
DenseMap<Value, unsigned> memrefToCondition;
797-
SmallVector<Value> newMemrefs, newConditions, newRetained;
798-
SmallVector<int32_t> resultIndices(deallocOp.getMemrefs().size(), -1);
807+
SmallVector<Value> newMemrefs, newConditions;
799808
for (auto [i, memref, cond] :
800809
llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
801-
if (retained.contains(memref)) {
802-
rewriter.replaceAllUsesWith(deallocOp.getResult(i),
803-
deallocOp.getConditions()[i]);
804-
continue;
805-
}
806-
807810
if (memrefToCondition.count(memref)) {
808811
// If the dealloc conditions don't match, we need to make sure that the
809812
// dealloc happens on the union of cases.
@@ -816,50 +819,133 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
816819
newMemrefs.push_back(memref);
817820
newConditions.push_back(cond);
818821
}
819-
resultIndices[i] = memrefToCondition[memref];
820822
}
821823

824+
// Return failure if we don't change anything such that we don't run into an
825+
// infinite loop of pattern applications.
826+
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
827+
rewriter);
828+
}
829+
};
830+
831+
/// Remove duplicate values in the list of retained memrefs. We need to make
832+
/// sure the corresponding result condition value is replaced properly.
833+
/// Example:
834+
/// ```mlir
835+
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
836+
/// ```
837+
/// is canonicalized to
838+
/// ```mlir
839+
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
840+
/// ```
841+
struct DeallocRemoveDuplicateRetainedMemrefs
842+
: public OpRewritePattern<DeallocOp> {
843+
using OpRewritePattern<DeallocOp>::OpRewritePattern;
844+
845+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
846+
PatternRewriter &rewriter) const override {
822847
// Unique retained values
823-
DenseSet<Value> seen;
848+
DenseMap<Value, unsigned> seen;
849+
SmallVector<Value> newRetained;
850+
SmallVector<unsigned> resultReplacementIdx;
851+
unsigned i = 0;
824852
for (auto retained : deallocOp.getRetained()) {
825-
if (!seen.contains(retained)) {
826-
seen.insert(retained);
827-
newRetained.push_back(retained);
853+
if (seen.count(retained)) {
854+
resultReplacementIdx.push_back(seen[retained]);
855+
continue;
828856
}
857+
858+
seen[retained] = i;
859+
newRetained.push_back(retained);
860+
resultReplacementIdx.push_back(i++);
829861
}
830862

831863
// Return failure if we don't change anything such that we don't run into an
832864
// infinite loop of pattern applications.
833-
if (newConditions.size() == deallocOp.getConditions().size() &&
834-
newRetained.size() == deallocOp.getRetained().size())
865+
if (newRetained.size() == deallocOp.getRetained().size())
835866
return failure();
836867

837868
// We need to create a new op because the number of results is always the
838869
// same as the number of condition operands.
839-
auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
840-
newConditions, newRetained);
841-
for (auto [i, newIdx] : llvm::enumerate(resultIndices))
842-
if (newIdx != -1)
843-
rewriter.replaceAllUsesWith(deallocOp.getResult(i),
844-
newDealloc.getResult(newIdx));
845-
846-
rewriter.eraseOp(deallocOp);
870+
auto newDeallocOp =
871+
rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
872+
deallocOp.getConditions(), newRetained);
873+
SmallVector<Value> replacements(
874+
llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
875+
return newDeallocOp.getUpdatedConditions()[idx];
876+
}));
877+
rewriter.replaceOp(deallocOp, replacements);
847878
return success();
848879
}
849880
};
850881

882+
/// Remove memrefs to be deallocated that are also present in the retained list
883+
/// since they will always alias and thus never actually be deallocated.
884+
/// Example:
885+
/// ```mlir
886+
/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
887+
/// ```
888+
/// is canonicalized to
889+
/// ```mlir
890+
/// %0 = bufferization.dealloc retain (%arg0 : ...)
891+
/// ```
892+
struct DeallocRemoveDeallocMemrefsContainedInRetained
893+
: public OpRewritePattern<DeallocOp> {
894+
using OpRewritePattern<DeallocOp>::OpRewritePattern;
895+
896+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
897+
PatternRewriter &rewriter) const override {
898+
// Unique memrefs to be deallocated.
899+
DenseMap<Value, unsigned> retained;
900+
for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
901+
retained[ret] = i;
902+
903+
// There must not be any duplicates in the retain list anymore because we
904+
// would miss updating one of the result values otherwise.
905+
if (retained.size() != deallocOp.getRetained().size())
906+
return failure();
907+
908+
SmallVector<Value> newMemrefs, newConditions;
909+
for (auto [memref, cond] :
910+
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
911+
if (retained.contains(memref)) {
912+
rewriter.setInsertionPointAfter(deallocOp);
913+
auto orOp = rewriter.create<arith::OrIOp>(
914+
deallocOp.getLoc(),
915+
deallocOp.getUpdatedConditions()[retained[memref]], cond);
916+
rewriter.replaceAllUsesExcept(
917+
deallocOp.getUpdatedConditions()[retained[memref]],
918+
orOp.getResult(), orOp);
919+
continue;
920+
}
921+
922+
newMemrefs.push_back(memref);
923+
newConditions.push_back(cond);
924+
}
925+
926+
// Return failure if we don't change anything such that we don't run into an
927+
// infinite loop of pattern applications.
928+
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
929+
rewriter);
930+
}
931+
};
932+
851933
/// Erase deallocation operations where the variadic list of memrefs to
852-
/// deallocate is emtpy. Example:
934+
/// deallocate is empty. Example:
853935
/// ```mlir
854-
/// bufferization.dealloc retain (%arg0: memref<2xi32>)
936+
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
855937
/// ```
856938
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
857939
using OpRewritePattern<DeallocOp>::OpRewritePattern;
858940

859941
LogicalResult matchAndRewrite(DeallocOp deallocOp,
860942
PatternRewriter &rewriter) const override {
861943
if (deallocOp.getMemrefs().empty()) {
862-
rewriter.eraseOp(deallocOp);
944+
Value constFalse = rewriter.create<arith::ConstantOp>(
945+
deallocOp.getLoc(), rewriter.getBoolAttr(false));
946+
rewriter.replaceOp(
947+
deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
948+
constFalse));
863949
return success();
864950
}
865951
return failure();
@@ -871,55 +957,40 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
871957
///
872958
/// Example:
873959
/// ```
874-
/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
960+
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
875961
/// if (%arg2, %false)
876962
/// ```
877963
/// becomes
878964
/// ```
879-
/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
965+
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
880966
/// ```
881967
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
882968
using OpRewritePattern<DeallocOp>::OpRewritePattern;
883969

884970
LogicalResult matchAndRewrite(DeallocOp deallocOp,
885971
PatternRewriter &rewriter) const override {
886972
SmallVector<Value> newMemrefs, newConditions;
887-
SmallVector<Value> replacements;
888-
889-
for (auto [res, memref, cond] :
890-
llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(),
891-
deallocOp.getConditions())) {
892-
if (matchPattern(cond, m_Zero())) {
893-
replacements.push_back(cond);
894-
continue;
973+
for (auto [memref, cond] :
974+
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
975+
if (!matchPattern(cond, m_Zero())) {
976+
newMemrefs.push_back(memref);
977+
newConditions.push_back(cond);
895978
}
896-
newMemrefs.push_back(memref);
897-
newConditions.push_back(cond);
898-
replacements.push_back({});
899979
}
900980

901-
if (newMemrefs.size() == deallocOp.getMemrefs().size())
902-
return failure();
903-
904-
auto newDeallocOp = rewriter.create<DeallocOp>(
905-
deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
906-
unsigned i = 0;
907-
for (auto &repl : replacements)
908-
if (!repl)
909-
repl = newDeallocOp.getResult(i++);
910-
911-
rewriter.replaceOp(deallocOp, replacements);
912-
return success();
981+
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
982+
rewriter);
913983
}
914984
};
915985

916986
} // anonymous namespace
917987

918988
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
919989
MLIRContext *context) {
920-
results
921-
.add<DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
922-
context);
990+
results.add<DeallocRemoveDuplicateDeallocMemrefs,
991+
DeallocRemoveDuplicateRetainedMemrefs,
992+
DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
993+
EraseAlwaysFalseDealloc>(context);
923994
}
924995

925996
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)