Skip to content

Commit e332c22

Browse files
[mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.
Differential revision: https://reviews.llvm.org/D96488
1 parent 19b4d3c commit e332c22

File tree

5 files changed

+228
-109
lines changed

5 files changed

+228
-109
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//===- FunctionCallUtils.h - Utilities for C function calls -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares helper functions to call common simple C functions in
10+
// LLVMIR (e.g. among others to support printing and debugging).
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
15+
#define MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
16+
17+
#include "mlir/IR/Operation.h"
18+
#include "mlir/Support/LLVM.h"
19+
20+
namespace mlir {
21+
class Location;
22+
class ModuleOp;
23+
class OpBuilder;
24+
class Operation;
25+
class Type;
26+
class ValueRange;
27+
28+
namespace LLVM {
29+
class LLVMFuncOp;
30+
31+
/// Helper functions to lookup or create the declaration for commonly used
32+
/// external C function calls. Such ops can then be invoked by creating a CallOp
33+
/// with the proper arguments via `createLLVMCall`.
34+
/// The list of functions provided here must be implemented separately (e.g. as
35+
/// part of a support runtime library or as part of the libc).
36+
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
37+
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
38+
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
39+
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
40+
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
41+
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
42+
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
43+
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
44+
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
45+
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
46+
Type indexType);
47+
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
48+
49+
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
50+
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
51+
ArrayRef<Type> paramTypes = {},
52+
Type resultType = {});
53+
54+
/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
55+
Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
56+
LLVM::LLVMFuncOp fn,
57+
ValueRange args = {},
58+
ArrayRef<Type> resultTypes = {});
59+
60+
} // namespace LLVM
61+
} // namespace mlir
62+
63+
#endif // MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 19 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "../PassDetail.h"
1515
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
1616
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
17+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1819
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1920
#include "mlir/IR/Attributes.h"
@@ -1793,31 +1794,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
17931794
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
17941795
}
17951796

1796-
// Creates a call to an allocation function with params and casts the
1797-
// resulting void pointer to ptrType.
1798-
Value createAllocCall(Location loc, StringRef name, Type ptrType,
1799-
ArrayRef<Value> params, ModuleOp module,
1800-
ConversionPatternRewriter &rewriter) const {
1801-
SmallVector<Type, 2> paramTypes;
1802-
auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1803-
if (!allocFuncOp) {
1804-
for (Value param : params)
1805-
paramTypes.push_back(param.getType());
1806-
auto allocFuncType =
1807-
LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
1808-
OpBuilder::InsertionGuard guard(rewriter);
1809-
rewriter.setInsertionPointToStart(module.getBody());
1810-
allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
1811-
name, allocFuncType);
1812-
}
1813-
auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
1814-
auto allocatedPtr = rewriter
1815-
.create<LLVM::CallOp>(loc, getVoidPtrType(),
1816-
allocFuncSymbol, params)
1817-
.getResult(0);
1818-
return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
1819-
}
1820-
18211797
/// Allocates the underlying buffer. Returns the allocated pointer and the
18221798
/// aligned pointer.
18231799
virtual std::tuple<Value, Value>
@@ -1909,9 +1885,12 @@ struct AllocOpLowering : public AllocLikeOpLowering {
19091885
// Allocate the underlying buffer and store a pointer to it in the MemRef
19101886
// descriptor.
19111887
Type elementPtrType = this->getElementPtrType(memRefType);
1888+
auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
1889+
allocOp->getParentOfType<ModuleOp>(), getIndexType());
1890+
auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
1891+
getVoidPtrType());
19121892
Value allocatedPtr =
1913-
createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
1914-
allocOp->getParentOfType<ModuleOp>(), rewriter);
1893+
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
19151894

19161895
Value alignedPtr = allocatedPtr;
19171896
if (alignment) {
@@ -1991,9 +1970,13 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
19911970
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
19921971

19931972
Type elementPtrType = this->getElementPtrType(memRefType);
1994-
Value allocatedPtr = createAllocCall(
1995-
loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
1996-
allocOp->getParentOfType<ModuleOp>(), rewriter);
1973+
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
1974+
allocOp->getParentOfType<ModuleOp>(), getIndexType());
1975+
auto results =
1976+
createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
1977+
getVoidPtrType());
1978+
Value allocatedPtr =
1979+
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
19971980

19981981
return std::make_tuple(allocatedPtr, allocatedPtr);
19991982
}
@@ -2056,31 +2039,17 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
20562039

20572040
// Get frequently used types.
20582041
MLIRContext *context = builder.getContext();
2059-
auto voidType = LLVM::LLVMVoidType::get(context);
20602042
Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
20612043
auto i1Type = IntegerType::get(context, 1);
20622044
Type indexType = typeConverter.getIndexType();
20632045

20642046
// Find the malloc and free, or declare them if necessary.
20652047
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
2066-
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
2067-
if (!mallocFunc && toDynamic) {
2068-
OpBuilder::InsertionGuard guard(builder);
2069-
builder.setInsertionPointToStart(module.getBody());
2070-
mallocFunc = builder.create<LLVM::LLVMFuncOp>(
2071-
builder.getUnknownLoc(), "malloc",
2072-
LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType),
2073-
/*isVarArg=*/false));
2074-
}
2075-
auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
2076-
if (!freeFunc && !toDynamic) {
2077-
OpBuilder::InsertionGuard guard(builder);
2078-
builder.setInsertionPointToStart(module.getBody());
2079-
freeFunc = builder.create<LLVM::LLVMFuncOp>(
2080-
builder.getUnknownLoc(), "free",
2081-
LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType),
2082-
/*isVarArg=*/false));
2083-
}
2048+
LLVM::LLVMFuncOp freeFunc, mallocFunc;
2049+
if (toDynamic)
2050+
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
2051+
if (!toDynamic)
2052+
freeFunc = LLVM::lookupOrCreateFreeFn(module);
20842053

20852054
// Initialize shared constants.
20862055
Value zero =
@@ -2217,17 +2186,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
22172186
DeallocOp::Adaptor transformed(operands);
22182187

22192188
// Insert the `free` declaration if it is not already present.
2220-
auto freeFunc =
2221-
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
2222-
if (!freeFunc) {
2223-
OpBuilder::InsertionGuard guard(rewriter);
2224-
rewriter.setInsertionPointToStart(
2225-
op->getParentOfType<ModuleOp>().getBody());
2226-
freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
2227-
rewriter.getUnknownLoc(), "free",
2228-
LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType()));
2229-
}
2230-
2189+
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
22312190
MemRefDescriptor memref(transformed.memref());
22322191
Value casted = rewriter.create<LLVM::BitcastOp>(
22332192
op.getLoc(), getVoidPtrType(),

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 20 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
1212
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
13+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1314
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1415
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1516
#include "mlir/Dialect/Vector/VectorOps.h"
@@ -1311,11 +1312,14 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
13111312
Type eltType = vectorType ? vectorType.getElementType() : printType;
13121313
Operation *printer;
13131314
if (eltType.isF32()) {
1314-
printer = getPrintFloat(printOp);
1315+
printer =
1316+
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
13151317
} else if (eltType.isF64()) {
1316-
printer = getPrintDouble(printOp);
1318+
printer =
1319+
LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
13171320
} else if (eltType.isIndex()) {
1318-
printer = getPrintU64(printOp);
1321+
printer =
1322+
LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
13191323
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
13201324
// Integers need a zero or sign extension on the operand
13211325
// (depending on the source type) as well as a signed or
@@ -1325,7 +1329,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
13251329
if (width <= 64) {
13261330
if (width < 64)
13271331
conversion = PrintConversion::ZeroExt64;
1328-
printer = getPrintU64(printOp);
1332+
printer = LLVM::lookupOrCreatePrintU64Fn(
1333+
printOp->getParentOfType<ModuleOp>());
13291334
} else {
13301335
return failure();
13311336
}
@@ -1338,7 +1343,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
13381343
conversion = PrintConversion::ZeroExt64;
13391344
else if (width < 64)
13401345
conversion = PrintConversion::SignExt64;
1341-
printer = getPrintI64(printOp);
1346+
printer = LLVM::lookupOrCreatePrintI64Fn(
1347+
printOp->getParentOfType<ModuleOp>());
13421348
} else {
13431349
return failure();
13441350
}
@@ -1351,7 +1357,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
13511357
int64_t rank = vectorType ? vectorType.getRank() : 0;
13521358
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
13531359
conversion);
1354-
emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1360+
emitCall(rewriter, printOp->getLoc(),
1361+
LLVM::lookupOrCreatePrintNewlineFn(
1362+
printOp->getParentOfType<ModuleOp>()));
13551363
rewriter.eraseOp(printOp);
13561364
return success();
13571365
}
@@ -1386,8 +1394,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
13861394
return;
13871395
}
13881396

1389-
emitCall(rewriter, loc, getPrintOpen(op));
1390-
Operation *printComma = getPrintComma(op);
1397+
emitCall(rewriter, loc,
1398+
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1399+
Operation *printComma =
1400+
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
13911401
int64_t dim = vectorType.getDimSize(0);
13921402
for (int64_t d = 0; d < dim; ++d) {
13931403
auto reducedType =
@@ -1401,7 +1411,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
14011411
if (d != dim - 1)
14021412
emitCall(rewriter, loc, printComma);
14031413
}
1404-
emitCall(rewriter, loc, getPrintClose(op));
1414+
emitCall(rewriter, loc,
1415+
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
14051416
}
14061417

14071418
// Helper to emit a call.
@@ -1410,46 +1421,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
14101421
rewriter.create<LLVM::CallOp>(loc, TypeRange(),
14111422
rewriter.getSymbolRefAttr(ref), params);
14121423
}
1413-
1414-
// Helper for printer method declaration (first hit) and lookup.
1415-
static Operation *getPrint(Operation *op, StringRef name,
1416-
ArrayRef<Type> params) {
1417-
auto module = op->getParentOfType<ModuleOp>();
1418-
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1419-
if (func)
1420-
return func;
1421-
OpBuilder moduleBuilder(module.getBodyRegion());
1422-
return moduleBuilder.create<LLVM::LLVMFuncOp>(
1423-
op->getLoc(), name,
1424-
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
1425-
params));
1426-
}
1427-
1428-
// Helpers for method names.
1429-
Operation *getPrintI64(Operation *op) const {
1430-
return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
1431-
}
1432-
Operation *getPrintU64(Operation *op) const {
1433-
return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
1434-
}
1435-
Operation *getPrintFloat(Operation *op) const {
1436-
return getPrint(op, "printF32", Float32Type::get(op->getContext()));
1437-
}
1438-
Operation *getPrintDouble(Operation *op) const {
1439-
return getPrint(op, "printF64", Float64Type::get(op->getContext()));
1440-
}
1441-
Operation *getPrintOpen(Operation *op) const {
1442-
return getPrint(op, "printOpen", {});
1443-
}
1444-
Operation *getPrintClose(Operation *op) const {
1445-
return getPrint(op, "printClose", {});
1446-
}
1447-
Operation *getPrintComma(Operation *op) const {
1448-
return getPrint(op, "printComma", {});
1449-
}
1450-
Operation *getPrintNewline(Operation *op) const {
1451-
return getPrint(op, "printNewline", {});
1452-
}
14531424
};
14541425

14551426
/// Progressive lowering of ExtractStridedSliceOp to either:

mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_subdirectory(Transforms)
22

33
add_mlir_dialect_library(MLIRLLVMIR
4+
IR/FunctionCallUtils.cpp
45
IR/LLVMDialect.cpp
56
IR/LLVMTypes.cpp
67
IR/LLVMTypeSyntax.cpp

0 commit comments

Comments
 (0)