Skip to content

Commit 5ec38d6

Browse files
committed
[mlir] Wrapped return value of function lookup in FailureOr for error handling
1 parent d26a77d commit 5ec38d6

File tree

10 files changed

+196
-144
lines changed

10 files changed

+196
-144
lines changed

mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace LLVM {
2323
/// Generate IR that prints the given string to stdout.
2424
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
2525
/// have the signature void(char const*). The default function is `printString`.
26-
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
26+
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
2727
StringRef symbolName, StringRef string,
2828
const LLVMTypeConverter &typeConverter,
2929
bool addNewline = true,

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Support/LLVM.h"
19-
#include <optional>
2019

2120
namespace mlir {
2221
class Location;
@@ -29,40 +28,42 @@ class ValueRange;
2928
namespace LLVM {
3029
class LLVMFuncOp;
3130

32-
/// Helper functions to lookup or create the declaration for commonly used
31+
/// Helper functions to look up or create the declaration for commonly used
3332
/// external C function calls. The list of functions provided here must be
3433
/// implemented separately (e.g. as part of a support runtime library or as part
3534
/// of the libc).
36-
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
37-
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
38-
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
39-
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40-
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
35+
/// Failure if an unexpected version of function is found.
36+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
37+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
38+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
39+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
41+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
4242
/// Declares a function to print a C-string.
4343
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
4444
/// have the signature void(char const*). The default function is `printString`.
45-
LLVM::LLVMFuncOp
45+
FailureOr<LLVM::LLVMFuncOp>
4646
lookupOrCreatePrintStringFn(Operation *moduleOp,
4747
std::optional<StringRef> runtimeFunctionName = {});
48-
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
49-
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
50-
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
51-
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52-
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53-
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
48+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
49+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
50+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
51+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
5454
Type indexType);
55-
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
56-
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
55+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
56+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
5757
Type indexType);
58-
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
58+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
5959
Type indexType);
60-
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
61-
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
60+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
61+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
6262
Type unrankedDescriptorType);
6363

6464
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
65-
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
65+
/// Return a failure if the FuncOp found has unexpected signature.
66+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
6667
ArrayRef<Type> paramTypes = {},
6768
Type resultType = {}, bool isVarArg = false,
6869
bool isReserved = false);

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
396396
// Allocate memory for the coroutine frame.
397397
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398398
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399+
if (failed(allocFuncOp))
400+
return failure();
399401
auto coroAlloc = rewriter.create<LLVM::CallOp>(
400-
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
402+
loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
401403

402404
// Begin a coroutine: @llvm.coro.begin.
403405
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
431433
// Free the memory.
432434
auto freeFuncOp =
433435
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
434-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
436+
if (failed(freeFuncOp))
437+
return failure();
438+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
435439
ValueRange(coroMem.getResult()));
436440

437441
return success();

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
6161

6262
// Failed block: Generate IR to print the message and call `abort`.
6363
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64-
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
64+
if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
6565
*getTypeConverter(), /*addNewLine=*/false,
66-
/*runtimeFunctionName=*/"puts");
66+
/*runtimeFunctionName=*/"puts").failed()) {
67+
return failure();
68+
}
6769
if (abortOnFailedAssert) {
6870
// Insert the `abort` declaration if necessary.
6971
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
276276

277277
// Find the malloc and free, or declare them if necessary.
278278
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
279-
LLVM::LLVMFuncOp freeFunc, mallocFunc;
280-
if (toDynamic)
279+
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
280+
if (toDynamic) {
281281
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
282-
if (!toDynamic)
282+
if (failed(mallocFunc))
283+
return failure();
284+
}
285+
if (!toDynamic) {
283286
freeFunc = LLVM::lookupOrCreateFreeFn(module);
287+
if (failed(freeFunc))
288+
return failure();
289+
}
284290

285291
unsigned unrankedMemrefPos = 0;
286292
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
293299
// Allocate memory, copy, and free the source if necessary.
294300
Value memory =
295301
toDynamic
296-
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
302+
? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
297303
.getResult()
298304
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
299305
IntegerType::get(getContext(), 8),
@@ -302,7 +308,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
302308
Value source = desc.memRefDescPtr(builder, loc);
303309
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
304310
if (!toDynamic)
305-
builder.create<LLVM::CallOp>(loc, freeFunc, source);
311+
builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
306312

307313
// Create a new descriptor. The same descriptor can be returned multiple
308314
// times, attempting to modify its pointer can lead to memory leaks

mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
2727
return uniqueName;
2828
}
2929

30-
void mlir::LLVM::createPrintStrCall(
30+
LogicalResult mlir::LLVM::createPrintStrCall(
3131
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
3232
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
3333
std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,12 @@ void mlir::LLVM::createPrintStrCall(
5959
SmallVector<LLVM::GEPArg> indices(1, 0);
6060
Value gep =
6161
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62-
Operation *printer =
63-
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
64-
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
65-
gep);
62+
if (auto printer =
63+
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
64+
builder.create<LLVM::CallOp>(loc, TypeRange(),
65+
SymbolRefAttr::get(printer.value()), gep);
66+
} else {
67+
return failure();
68+
}
69+
return success();
6670
}

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using namespace mlir;
1616

1717
namespace {
18-
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
18+
FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
1919
Operation *module, Type indexType) {
2020
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
2121
if (useGenericFn)
@@ -24,7 +24,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
2424
return LLVM::lookupOrCreateMallocFn(module, indexType);
2525
}
2626

27-
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
27+
FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
2828
Operation *module, Type indexType) {
2929
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
3030

@@ -80,10 +80,11 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
8080
<< " to integer address space "
8181
"failed. Consider adding memory space conversions.";
8282
}
83-
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
83+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
8484
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
8585
getIndexType());
86-
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
86+
if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
87+
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
8788

8889
Value allocatedPtr =
8990
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +147,12 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
146147
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
147148

148149
Type elementPtrType = this->getElementPtrType(memRefType);
149-
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
150+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
150151
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
151152
getIndexType());
153+
if (failed(allocFuncOp)) return Value();
152154
auto results = rewriter.create<LLVM::CallOp>(
153-
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
155+
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
154156

155157
return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
156158
elementPtrType, *getTypeConverter());

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
4242
return !ShapedType::isDynamic(strideOrOffset);
4343
}
4444

45-
LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
46-
ModuleOp module) {
45+
FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
46+
ModuleOp module) {
4747
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
4848

4949
if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
220220
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
221221
ConversionPatternRewriter &rewriter) const override {
222222
// Insert the `free` declaration if it is not already present.
223-
LLVM::LLVMFuncOp freeFunc =
223+
auto freeFunc =
224224
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225+
if (failed(freeFunc))
226+
return failure();
225227
Value allocatedPtr;
226228
if (auto unrankedTy =
227229
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
236238
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
237239
.allocatedPtr(rewriter, op.getLoc());
238240
}
239-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
241+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
242+
allocatedPtr);
240243
return success();
241244
}
242245
};
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
838841
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
839842
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
840843
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
841-
rewriter.create<LLVM::CallOp>(loc, copyFn,
844+
if (failed(copyFn))
845+
return failure();
846+
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
842847
ValueRange{elemSize, sourcePtr, targetPtr});
843848

844849
// Restore stack used for descriptors

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15461546

15471547
auto punct = printOp.getPunctuation();
15481548
if (auto stringLiteral = printOp.getStringLiteral()) {
1549-
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550-
*stringLiteral, *getTypeConverter(),
1551-
/*addNewline=*/false);
1549+
if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550+
*stringLiteral, *getTypeConverter(),
1551+
/*addNewline=*/false)
1552+
.failed()) {
1553+
return failure();
1554+
}
15521555
} else if (punct != PrintPunctuation::NoPunctuation) {
1553-
emitCall(rewriter, printOp->getLoc(), [&] {
1554-
switch (punct) {
1555-
case PrintPunctuation::Close:
1556-
return LLVM::lookupOrCreatePrintCloseFn(parent);
1557-
case PrintPunctuation::Open:
1558-
return LLVM::lookupOrCreatePrintOpenFn(parent);
1559-
case PrintPunctuation::Comma:
1560-
return LLVM::lookupOrCreatePrintCommaFn(parent);
1561-
case PrintPunctuation::NewLine:
1562-
return LLVM::lookupOrCreatePrintNewlineFn(parent);
1563-
default:
1564-
llvm_unreachable("unexpected punctuation");
1565-
}
1566-
}());
1556+
if (auto op = [&] -> FailureOr<LLVM::LLVMFuncOp> {
1557+
switch (punct) {
1558+
case PrintPunctuation::Close:
1559+
return LLVM::lookupOrCreatePrintCloseFn(parent);
1560+
case PrintPunctuation::Open:
1561+
return LLVM::lookupOrCreatePrintOpenFn(parent);
1562+
case PrintPunctuation::Comma:
1563+
return LLVM::lookupOrCreatePrintCommaFn(parent);
1564+
case PrintPunctuation::NewLine:
1565+
return LLVM::lookupOrCreatePrintNewlineFn(parent);
1566+
default:
1567+
llvm_unreachable("unexpected punctuation");
1568+
}
1569+
}();
1570+
succeeded(op))
1571+
emitCall(rewriter, printOp->getLoc(), op.value());
1572+
else {
1573+
return failure();
1574+
}
15671575
}
15681576

15691577
rewriter.eraseOp(printOp);
@@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15881596

15891597
// Make sure element type has runtime support.
15901598
PrintConversion conversion = PrintConversion::None;
1591-
Operation *printer;
1599+
FailureOr<Operation *> printer;
15921600
if (printType.isF32()) {
15931601
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
15941602
} else if (printType.isF64()) {
@@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16311639
} else {
16321640
return failure();
16331641
}
1642+
if (failed(printer))
1643+
return failure();
16341644

16351645
switch (conversion) {
16361646
case PrintConversion::ZeroExt64:
@@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16481658
case PrintConversion::None:
16491659
break;
16501660
}
1651-
emitCall(rewriter, loc, printer, value);
1661+
emitCall(rewriter, loc, printer.value(), value);
16521662
return success();
16531663
}
16541664

0 commit comments

Comments
 (0)