|
14 | 14 | #include "../PassDetail.h"
|
15 | 15 | #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
16 | 16 | #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
| 17 | +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
17 | 18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
18 | 19 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
19 | 20 | #include "mlir/IR/Attributes.h"
|
@@ -1793,31 +1794,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
|
1793 | 1794 | return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
|
1794 | 1795 | }
|
1795 | 1796 |
|
1796 |
| - // Creates a call to an allocation function with params and casts the |
1797 |
| - // resulting void pointer to ptrType. |
1798 |
| - Value createAllocCall(Location loc, StringRef name, Type ptrType, |
1799 |
| - ArrayRef<Value> params, ModuleOp module, |
1800 |
| - ConversionPatternRewriter &rewriter) const { |
1801 |
| - SmallVector<Type, 2> paramTypes; |
1802 |
| - auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name); |
1803 |
| - if (!allocFuncOp) { |
1804 |
| - for (Value param : params) |
1805 |
| - paramTypes.push_back(param.getType()); |
1806 |
| - auto allocFuncType = |
1807 |
| - LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); |
1808 |
| - OpBuilder::InsertionGuard guard(rewriter); |
1809 |
| - rewriter.setInsertionPointToStart(module.getBody()); |
1810 |
| - allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), |
1811 |
| - name, allocFuncType); |
1812 |
| - } |
1813 |
| - auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); |
1814 |
| - auto allocatedPtr = rewriter |
1815 |
| - .create<LLVM::CallOp>(loc, getVoidPtrType(), |
1816 |
| - allocFuncSymbol, params) |
1817 |
| - .getResult(0); |
1818 |
| - return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr); |
1819 |
| - } |
1820 |
| - |
1821 | 1797 | /// Allocates the underlying buffer. Returns the allocated pointer and the
|
1822 | 1798 | /// aligned pointer.
|
1823 | 1799 | virtual std::tuple<Value, Value>
|
@@ -1909,9 +1885,12 @@ struct AllocOpLowering : public AllocLikeOpLowering {
|
1909 | 1885 | // Allocate the underlying buffer and store a pointer to it in the MemRef
|
1910 | 1886 | // descriptor.
|
1911 | 1887 | Type elementPtrType = this->getElementPtrType(memRefType);
|
| 1888 | + auto allocFuncOp = LLVM::lookupOrCreateMallocFn( |
| 1889 | + allocOp->getParentOfType<ModuleOp>(), getIndexType()); |
| 1890 | + auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, |
| 1891 | + getVoidPtrType()); |
1912 | 1892 | Value allocatedPtr =
|
1913 |
| - createAllocCall(loc, "malloc", elementPtrType, {sizeBytes}, |
1914 |
| - allocOp->getParentOfType<ModuleOp>(), rewriter); |
| 1893 | + rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); |
1915 | 1894 |
|
1916 | 1895 | Value alignedPtr = allocatedPtr;
|
1917 | 1896 | if (alignment) {
|
@@ -1991,9 +1970,13 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
|
1991 | 1970 | sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
|
1992 | 1971 |
|
1993 | 1972 | Type elementPtrType = this->getElementPtrType(memRefType);
|
1994 |
| - Value allocatedPtr = createAllocCall( |
1995 |
| - loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes}, |
1996 |
| - allocOp->getParentOfType<ModuleOp>(), rewriter); |
| 1973 | + auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( |
| 1974 | + allocOp->getParentOfType<ModuleOp>(), getIndexType()); |
| 1975 | + auto results = |
| 1976 | + createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes}, |
| 1977 | + getVoidPtrType()); |
| 1978 | + Value allocatedPtr = |
| 1979 | + rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); |
1997 | 1980 |
|
1998 | 1981 | return std::make_tuple(allocatedPtr, allocatedPtr);
|
1999 | 1982 | }
|
@@ -2056,31 +2039,17 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
|
2056 | 2039 |
|
2057 | 2040 | // Get frequently used types.
|
2058 | 2041 | MLIRContext *context = builder.getContext();
|
2059 |
| - auto voidType = LLVM::LLVMVoidType::get(context); |
2060 | 2042 | Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
|
2061 | 2043 | auto i1Type = IntegerType::get(context, 1);
|
2062 | 2044 | Type indexType = typeConverter.getIndexType();
|
2063 | 2045 |
|
2064 | 2046 | // Find the malloc and free, or declare them if necessary.
|
2065 | 2047 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
|
2066 |
| - auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc"); |
2067 |
| - if (!mallocFunc && toDynamic) { |
2068 |
| - OpBuilder::InsertionGuard guard(builder); |
2069 |
| - builder.setInsertionPointToStart(module.getBody()); |
2070 |
| - mallocFunc = builder.create<LLVM::LLVMFuncOp>( |
2071 |
| - builder.getUnknownLoc(), "malloc", |
2072 |
| - LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType), |
2073 |
| - /*isVarArg=*/false)); |
2074 |
| - } |
2075 |
| - auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free"); |
2076 |
| - if (!freeFunc && !toDynamic) { |
2077 |
| - OpBuilder::InsertionGuard guard(builder); |
2078 |
| - builder.setInsertionPointToStart(module.getBody()); |
2079 |
| - freeFunc = builder.create<LLVM::LLVMFuncOp>( |
2080 |
| - builder.getUnknownLoc(), "free", |
2081 |
| - LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType), |
2082 |
| - /*isVarArg=*/false)); |
2083 |
| - } |
| 2048 | + LLVM::LLVMFuncOp freeFunc, mallocFunc; |
| 2049 | + if (toDynamic) |
| 2050 | + mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); |
| 2051 | + if (!toDynamic) |
| 2052 | + freeFunc = LLVM::lookupOrCreateFreeFn(module); |
2084 | 2053 |
|
2085 | 2054 | // Initialize shared constants.
|
2086 | 2055 | Value zero =
|
@@ -2217,17 +2186,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
|
2217 | 2186 | DeallocOp::Adaptor transformed(operands);
|
2218 | 2187 |
|
2219 | 2188 | // Insert the `free` declaration if it is not already present.
|
2220 |
| - auto freeFunc = |
2221 |
| - op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free"); |
2222 |
| - if (!freeFunc) { |
2223 |
| - OpBuilder::InsertionGuard guard(rewriter); |
2224 |
| - rewriter.setInsertionPointToStart( |
2225 |
| - op->getParentOfType<ModuleOp>().getBody()); |
2226 |
| - freeFunc = rewriter.create<LLVM::LLVMFuncOp>( |
2227 |
| - rewriter.getUnknownLoc(), "free", |
2228 |
| - LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType())); |
2229 |
| - } |
2230 |
| - |
| 2189 | + auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); |
2231 | 2190 | MemRefDescriptor memref(transformed.memref());
|
2232 | 2191 | Value casted = rewriter.create<LLVM::BitcastOp>(
|
2233 | 2192 | op.getLoc(), getVoidPtrType(),
|
|
0 commit comments