Skip to content

Commit 8037deb

Browse files
committed
[mlir][memref] Add pass to expand realloc operations, simplify lowering to LLVM
There are two motivations for this change: 1. It considerably simplifies adding support for the realloc operation to the new buffer deallocation pass by lowering the realloc such that no deallocation operation is inserted and the deallocation pass itself can insert that dealloc 2. The lowering is expressed on a higher level and thus easier to understand, and the lowerings of the memref operations it is composed of don't have to be duplicated in the MemRefToLLVM lowering (also see discussion in https://reviews.llvm.org/D133424) Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D159430
1 parent 12a7897 commit 8037deb

File tree

13 files changed

+346
-393
lines changed

13 files changed

+346
-393
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,14 @@ class AffineDialect;
2121
class ModuleOp;
2222

2323
namespace func {
24+
namespace arith {
25+
class ArithDialect;
26+
} // namespace arith
2427
class FuncDialect;
2528
} // namespace func
29+
namespace scf {
30+
class SCFDialect;
31+
} // namespace scf
2632
namespace tensor {
2733
class TensorDialect;
2834
} // namespace tensor
@@ -67,6 +73,10 @@ std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
6773
/// easier to reason about operations.
6874
std::unique_ptr<Pass> createExpandStridedMetadataPass();
6975

76+
/// Creates an operation pass to expand `memref.realloc` operations into their
77+
/// components.
78+
std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
79+
7080
//===----------------------------------------------------------------------===//
7181
// Registration
7282
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,48 @@ def ExpandStridedMetadata : Pass<"expand-strided-metadata"> {
202202
"affine::AffineDialect", "memref::MemRefDialect"
203203
];
204204
}
205+
206+
def ExpandRealloc : Pass<"expand-realloc"> {
207+
let summary = "Expand memref.realloc operations into its components";
208+
let description = [{
209+
The `memref.realloc` operation performs a conditional allocation and copy to
210+
increase the size of a buffer if necessary. This pass converts a `realloc`
211+
operation into this sequence of simpler operations such that other passes
212+
at a later stage in the compilation pipeline do not have to consider the
213+
`realloc` operation anymore (e.g., the buffer deallocation pass and the
214+
conversion pass to LLVM).
215+
216+
Example of an expansion:
217+
```mlir
218+
%realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
219+
```
220+
is expanded to
221+
```mlir
222+
%c0 = arith.constant 0 : index
223+
%dim = memref.dim %alloc, %c0 : memref<?xf32>
224+
%is_old_smaller = arith.cmpi ult, %dim, %arg1
225+
%realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
226+
%new_alloc = memref.alloc(%size) : memref<?xf32>
227+
%subview = memref.subview %new_alloc[0] [%dim] [1]
228+
memref.copy %alloc, %subview
229+
memref.dealloc %alloc
230+
scf.yield %alloc_0 : memref<?xf32>
231+
} else {
232+
%reinterpret_cast = memref.reinterpret_cast %alloc to
233+
offset: [0], sizes: [%size], strides: [1]
234+
scf.yield %reinterpret_cast : memref<?xf32>
235+
}
236+
```
237+
}];
238+
let options = [
239+
Option<"emitDeallocs", "emit-deallocs", "bool", /*default=*/"true",
240+
"Emit deallocation operations for the original MemRef">,
241+
];
242+
let constructor = "mlir::memref::createExpandReallocPass()";
243+
let dependentDialects = [
244+
"arith::ArithDialect", "scf::SCFDialect", "memref::MemRefDialect"
245+
];
246+
}
247+
205248
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
206249

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
6363
/// `memref.extract_strided_metadata` of its source.
6464
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);
6565

66+
/// Appends patterns for expanding `memref.realloc` operations.
67+
void populateExpandReallocPatterns(RewritePatternSet &patterns,
68+
bool emitDeallocs = true);
69+
6670
/// Appends patterns for emulating wide integer memref operations with ops over
6771
/// narrower integer types.
6872
void populateMemRefWideIntEmulationPatterns(

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 2 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -116,165 +116,6 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
116116
}
117117
};
118118

119-
/// The base class for lowering realloc op, to support the implementation of
120-
/// realloc via allocation methods that may or may not support alignment.
121-
/// A derived class should provide an implementation of allocateBuffer using
122-
/// the underline allocation methods.
123-
struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
124-
using OpAdaptor = typename memref::ReallocOp::Adaptor;
125-
126-
ReallocOpLoweringBase(const LLVMTypeConverter &converter)
127-
: AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(),
128-
converter) {}
129-
130-
/// Allocates the new buffer. Returns the allocated pointer and the
131-
/// aligned pointer.
132-
virtual std::tuple<Value, Value>
133-
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
134-
Value sizeBytes, memref::ReallocOp op) const = 0;
135-
136-
LogicalResult
137-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
138-
ConversionPatternRewriter &rewriter) const final {
139-
auto reallocOp = cast<memref::ReallocOp>(op);
140-
return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter);
141-
}
142-
143-
// A `realloc` is converted as follows:
144-
// If new_size > old_size
145-
// 1. allocates a new buffer
146-
// 2. copies the content of the old buffer to the new buffer
147-
// 3. release the old buffer
148-
// 3. updates the buffer pointers in the memref descriptor
149-
// Update the size in the memref descriptor
150-
// Alignment request is handled by allocating `alignment` more bytes than
151-
// requested and shifting the aligned pointer relative to the allocated
152-
// memory.
153-
LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor,
154-
ConversionPatternRewriter &rewriter) const {
155-
OpBuilder::InsertionGuard guard(rewriter);
156-
Location loc = op.getLoc();
157-
158-
auto computeNumElements =
159-
[&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
160-
// Compute number of elements.
161-
Type indexType = ConvertToLLVMPattern::getIndexType();
162-
Value numElements =
163-
type.isDynamicDim(0)
164-
? getDynamicSize()
165-
: createIndexAttrConstant(rewriter, loc, indexType,
166-
type.getDimSize(0));
167-
if (numElements.getType() != indexType)
168-
numElements = typeConverter->materializeTargetConversion(
169-
rewriter, loc, indexType, numElements);
170-
return numElements;
171-
};
172-
173-
MemRefDescriptor desc(adaptor.getSource());
174-
Value oldDesc = desc;
175-
176-
// Split the block right before the current op into two blocks.
177-
Block *currentBlock = rewriter.getInsertionBlock();
178-
Block *block =
179-
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
180-
// Add a block argument by creating an empty block with the argument type
181-
// and then merging the block into the empty block.
182-
Block *endBlock = rewriter.createBlock(
183-
block->getParent(), Region::iterator(block), oldDesc.getType(), loc);
184-
rewriter.mergeBlocks(block, endBlock, {});
185-
// Add a new block for the true branch of the conditional statement we will
186-
// add.
187-
Block *trueBlock = rewriter.createBlock(
188-
currentBlock->getParent(), std::next(Region::iterator(currentBlock)));
189-
190-
rewriter.setInsertionPointToEnd(currentBlock);
191-
Value src = op.getSource();
192-
auto srcType = dyn_cast<MemRefType>(src.getType());
193-
Value srcNumElements = computeNumElements(
194-
srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
195-
auto dstType = cast<MemRefType>(op.getType());
196-
Value dstNumElements = computeNumElements(
197-
dstType, [&]() -> Value { return op.getDynamicResultSize(); });
198-
Value cond = rewriter.create<LLVM::ICmpOp>(
199-
loc, IntegerType::get(rewriter.getContext(), 1),
200-
LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements);
201-
rewriter.create<LLVM::CondBrOp>(loc, cond, trueBlock, ArrayRef<Value>(),
202-
endBlock, ValueRange{oldDesc});
203-
204-
rewriter.setInsertionPointToStart(trueBlock);
205-
Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter);
206-
// Compute total byte size.
207-
auto dstByteSize =
208-
rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
209-
// Since the src and dst memref are guarantee to have the same
210-
// element type by the verifier, it is safe here to reuse the
211-
// type size computed from dst memref.
212-
auto srcByteSize =
213-
rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
214-
// Allocate a new buffer.
215-
auto [dstRawPtr, dstAlignedPtr] =
216-
allocateBuffer(rewriter, loc, dstByteSize, op);
217-
// Copy the data from the old buffer to the new buffer.
218-
Value srcAlignedPtr = desc.alignedPtr(rewriter, loc);
219-
auto toVoidPtr = [&](Value ptr) -> Value {
220-
if (getTypeConverter()->useOpaquePointers())
221-
return ptr;
222-
return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
223-
};
224-
rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
225-
toVoidPtr(srcAlignedPtr), srcByteSize,
226-
/*isVolatile=*/false);
227-
// Deallocate the old buffer.
228-
LLVM::LLVMFuncOp freeFunc =
229-
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
230-
rewriter.create<LLVM::CallOp>(loc, freeFunc,
231-
toVoidPtr(desc.allocatedPtr(rewriter, loc)));
232-
// Replace the old buffer addresses in the MemRefDescriptor with the new
233-
// buffer addresses.
234-
desc.setAllocatedPtr(rewriter, loc, dstRawPtr);
235-
desc.setAlignedPtr(rewriter, loc, dstAlignedPtr);
236-
rewriter.create<LLVM::BrOp>(loc, Value(desc), endBlock);
237-
238-
rewriter.setInsertionPoint(op);
239-
// Update the memref size.
240-
MemRefDescriptor newDesc(endBlock->getArgument(0));
241-
newDesc.setSize(rewriter, loc, 0, dstNumElements);
242-
rewriter.replaceOp(op, {newDesc});
243-
return success();
244-
}
245-
246-
private:
247-
using ConvertToLLVMPattern::matchAndRewrite;
248-
};
249-
250-
struct ReallocOpLowering : public ReallocOpLoweringBase {
251-
ReallocOpLowering(const LLVMTypeConverter &converter)
252-
: ReallocOpLoweringBase(converter) {}
253-
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
254-
Location loc, Value sizeBytes,
255-
memref::ReallocOp op) const override {
256-
return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
257-
getAlignment(rewriter, loc, op));
258-
}
259-
};
260-
261-
struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
262-
AlignedReallocOpLowering(const LLVMTypeConverter &converter)
263-
: ReallocOpLoweringBase(converter) {}
264-
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
265-
Location loc, Value sizeBytes,
266-
memref::ReallocOp op) const override {
267-
Value ptr = allocateBufferAutoAlign(
268-
rewriter, loc, sizeBytes, op, &defaultLayout,
269-
alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout));
270-
return std::make_tuple(ptr, ptr);
271-
}
272-
273-
private:
274-
/// Default layout to use in absence of the corresponding analysis.
275-
DataLayout defaultLayout;
276-
};
277-
278119
struct AllocaScopeOpLowering
279120
: public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
280121
using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
@@ -1899,11 +1740,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
18991740
// clang-format on
19001741
auto allocLowering = converter.getOptions().allocLowering;
19011742
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1902-
patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
1903-
DeallocOpLowering>(converter);
1743+
patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
19041744
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1905-
patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
1906-
converter);
1745+
patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
19071746
}
19081747

19091748
namespace {

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
22
BufferizableOpInterfaceImpl.cpp
33
ComposeSubView.cpp
44
ExpandOps.cpp
5+
ExpandRealloc.cpp
56
ExpandStridedMetadata.cpp
67
EmulateWideInt.cpp
78
EmulateNarrowType.cpp

0 commit comments

Comments
 (0)