@@ -116,165 +116,6 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
116
116
}
117
117
};
118
118
119
- // / The base class for lowering realloc op, to support the implementation of
120
- // / realloc via allocation methods that may or may not support alignment.
121
- // / A derived class should provide an implementation of allocateBuffer using
122
- // / the underline allocation methods.
123
- struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
124
- using OpAdaptor = typename memref::ReallocOp::Adaptor;
125
-
126
- ReallocOpLoweringBase (const LLVMTypeConverter &converter)
127
- : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(),
128
- converter) {}
129
-
130
- // / Allocates the new buffer. Returns the allocated pointer and the
131
- // / aligned pointer.
132
- virtual std::tuple<Value, Value>
133
- allocateBuffer (ConversionPatternRewriter &rewriter, Location loc,
134
- Value sizeBytes, memref::ReallocOp op) const = 0 ;
135
-
136
- LogicalResult
137
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
138
- ConversionPatternRewriter &rewriter) const final {
139
- auto reallocOp = cast<memref::ReallocOp>(op);
140
- return matchAndRewrite (reallocOp, OpAdaptor (operands, reallocOp), rewriter);
141
- }
142
-
143
- // A `realloc` is converted as follows:
144
- // If new_size > old_size
145
- // 1. allocates a new buffer
146
- // 2. copies the content of the old buffer to the new buffer
147
- // 3. release the old buffer
148
- // 3. updates the buffer pointers in the memref descriptor
149
- // Update the size in the memref descriptor
150
- // Alignment request is handled by allocating `alignment` more bytes than
151
- // requested and shifting the aligned pointer relative to the allocated
152
- // memory.
153
- LogicalResult matchAndRewrite (memref::ReallocOp op, OpAdaptor adaptor,
154
- ConversionPatternRewriter &rewriter) const {
155
- OpBuilder::InsertionGuard guard (rewriter);
156
- Location loc = op.getLoc ();
157
-
158
- auto computeNumElements =
159
- [&](MemRefType type, function_ref<Value ()> getDynamicSize) -> Value {
160
- // Compute number of elements.
161
- Type indexType = ConvertToLLVMPattern::getIndexType ();
162
- Value numElements =
163
- type.isDynamicDim (0 )
164
- ? getDynamicSize ()
165
- : createIndexAttrConstant (rewriter, loc, indexType,
166
- type.getDimSize (0 ));
167
- if (numElements.getType () != indexType)
168
- numElements = typeConverter->materializeTargetConversion (
169
- rewriter, loc, indexType, numElements);
170
- return numElements;
171
- };
172
-
173
- MemRefDescriptor desc (adaptor.getSource ());
174
- Value oldDesc = desc;
175
-
176
- // Split the block right before the current op into two blocks.
177
- Block *currentBlock = rewriter.getInsertionBlock ();
178
- Block *block =
179
- rewriter.splitBlock (currentBlock, rewriter.getInsertionPoint ());
180
- // Add a block argument by creating an empty block with the argument type
181
- // and then merging the block into the empty block.
182
- Block *endBlock = rewriter.createBlock (
183
- block->getParent (), Region::iterator (block), oldDesc.getType (), loc);
184
- rewriter.mergeBlocks (block, endBlock, {});
185
- // Add a new block for the true branch of the conditional statement we will
186
- // add.
187
- Block *trueBlock = rewriter.createBlock (
188
- currentBlock->getParent (), std::next (Region::iterator (currentBlock)));
189
-
190
- rewriter.setInsertionPointToEnd (currentBlock);
191
- Value src = op.getSource ();
192
- auto srcType = dyn_cast<MemRefType>(src.getType ());
193
- Value srcNumElements = computeNumElements (
194
- srcType, [&]() -> Value { return desc.size (rewriter, loc, 0 ); });
195
- auto dstType = cast<MemRefType>(op.getType ());
196
- Value dstNumElements = computeNumElements (
197
- dstType, [&]() -> Value { return op.getDynamicResultSize (); });
198
- Value cond = rewriter.create <LLVM::ICmpOp>(
199
- loc, IntegerType::get (rewriter.getContext (), 1 ),
200
- LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements);
201
- rewriter.create <LLVM::CondBrOp>(loc, cond, trueBlock, ArrayRef<Value>(),
202
- endBlock, ValueRange{oldDesc});
203
-
204
- rewriter.setInsertionPointToStart (trueBlock);
205
- Value sizeInBytes = getSizeInBytes (loc, dstType.getElementType (), rewriter);
206
- // Compute total byte size.
207
- auto dstByteSize =
208
- rewriter.create <LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
209
- // Since the src and dst memref are guarantee to have the same
210
- // element type by the verifier, it is safe here to reuse the
211
- // type size computed from dst memref.
212
- auto srcByteSize =
213
- rewriter.create <LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
214
- // Allocate a new buffer.
215
- auto [dstRawPtr, dstAlignedPtr] =
216
- allocateBuffer (rewriter, loc, dstByteSize, op);
217
- // Copy the data from the old buffer to the new buffer.
218
- Value srcAlignedPtr = desc.alignedPtr (rewriter, loc);
219
- auto toVoidPtr = [&](Value ptr) -> Value {
220
- if (getTypeConverter ()->useOpaquePointers ())
221
- return ptr;
222
- return rewriter.create <LLVM::BitcastOp>(loc, getVoidPtrType (), ptr);
223
- };
224
- rewriter.create <LLVM::MemcpyOp>(loc, toVoidPtr (dstAlignedPtr),
225
- toVoidPtr (srcAlignedPtr), srcByteSize,
226
- /* isVolatile=*/ false );
227
- // Deallocate the old buffer.
228
- LLVM::LLVMFuncOp freeFunc =
229
- getFreeFn (getTypeConverter (), op->getParentOfType <ModuleOp>());
230
- rewriter.create <LLVM::CallOp>(loc, freeFunc,
231
- toVoidPtr (desc.allocatedPtr (rewriter, loc)));
232
- // Replace the old buffer addresses in the MemRefDescriptor with the new
233
- // buffer addresses.
234
- desc.setAllocatedPtr (rewriter, loc, dstRawPtr);
235
- desc.setAlignedPtr (rewriter, loc, dstAlignedPtr);
236
- rewriter.create <LLVM::BrOp>(loc, Value (desc), endBlock);
237
-
238
- rewriter.setInsertionPoint (op);
239
- // Update the memref size.
240
- MemRefDescriptor newDesc (endBlock->getArgument (0 ));
241
- newDesc.setSize (rewriter, loc, 0 , dstNumElements);
242
- rewriter.replaceOp (op, {newDesc});
243
- return success ();
244
- }
245
-
246
- private:
247
- using ConvertToLLVMPattern::matchAndRewrite;
248
- };
249
-
250
- struct ReallocOpLowering : public ReallocOpLoweringBase {
251
- ReallocOpLowering (const LLVMTypeConverter &converter)
252
- : ReallocOpLoweringBase(converter) {}
253
- std::tuple<Value, Value> allocateBuffer (ConversionPatternRewriter &rewriter,
254
- Location loc, Value sizeBytes,
255
- memref::ReallocOp op) const override {
256
- return allocateBufferManuallyAlign (rewriter, loc, sizeBytes, op,
257
- getAlignment (rewriter, loc, op));
258
- }
259
- };
260
-
261
- struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
262
- AlignedReallocOpLowering (const LLVMTypeConverter &converter)
263
- : ReallocOpLoweringBase(converter) {}
264
- std::tuple<Value, Value> allocateBuffer (ConversionPatternRewriter &rewriter,
265
- Location loc, Value sizeBytes,
266
- memref::ReallocOp op) const override {
267
- Value ptr = allocateBufferAutoAlign (
268
- rewriter, loc, sizeBytes, op, &defaultLayout,
269
- alignedAllocationGetAlignment (rewriter, loc, op, &defaultLayout));
270
- return std::make_tuple (ptr, ptr);
271
- }
272
-
273
- private:
274
- // / Default layout to use in absence of the corresponding analysis.
275
- DataLayout defaultLayout;
276
- };
277
-
278
119
struct AllocaScopeOpLowering
279
120
: public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
280
121
using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
@@ -1899,11 +1740,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1899
1740
// clang-format on
1900
1741
auto allocLowering = converter.getOptions ().allocLowering ;
1901
1742
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1902
- patterns.add <AlignedAllocOpLowering, AlignedReallocOpLowering,
1903
- DeallocOpLowering>(converter);
1743
+ patterns.add <AlignedAllocOpLowering, DeallocOpLowering>(converter);
1904
1744
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1905
- patterns.add <AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
1906
- converter);
1745
+ patterns.add <AllocOpLowering, DeallocOpLowering>(converter);
1907
1746
}
1908
1747
1909
1748
namespace {
0 commit comments