Skip to content

Commit 99e2780

Browse files
committed
Reapply "[mlir] allow function type cloning to fail (llvm#137130)"
This reapplies commit 7318074. This reverts commit d504628.
1 parent c8336df commit 99e2780

File tree

11 files changed

+111
-46
lines changed

11 files changed

+111
-46
lines changed

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,12 @@ class AbstractResultOpt
387387
mlir::OpBuilder rewriter(context);
388388
auto resultType = funcTy.getResult(0);
389389
auto argTy = getResultArgumentType(resultType, shouldBoxResult);
390-
func.insertArgument(0u, argTy, {}, loc);
391-
func.eraseResult(0u);
390+
llvm::LogicalResult res = func.insertArgument(0u, argTy, {}, loc);
391+
(void)res;
392+
assert(llvm::succeeded(res) && "failed to insert function argument");
393+
res = func.eraseResult(0u);
394+
(void)res;
395+
assert(llvm::succeeded(res) && "failed to erase function result");
392396
mlir::Value newArg = func.getArgument(0u);
393397
if (mustEmboxResult(resultType, shouldBoxResult)) {
394398
auto bufferType = fir::ReferenceType::get(resultType);

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> {
104104
bool isVarArg() const { return getVarArg(); }
105105

106106
/// Returns a clone of this function type with the given argument
107-
/// and result types.
107+
/// and result types. Returns null if the resulting function type would
108+
/// not verify.
108109
LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
109110

110111
/// Returns the result type of the function as an ArrayRef, enabling better

mlir/include/mlir/Interfaces/FunctionInterfaces.td

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -255,79 +255,105 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
255255
BlockArgListType getArguments() { return getFunctionBody().getArguments(); }
256256

257257
/// Insert a single argument of type `argType` with attributes `argAttrs` and
258-
/// location `argLoc` at `argIndex`.
259-
void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs,
260-
::mlir::Location argLoc) {
261-
insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
258+
/// location `argLoc` at `argIndex`. Returns failure if the function cannot be
259+
/// updated to have the new signature.
260+
::llvm::LogicalResult insertArgument(
261+
unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs,
262+
::mlir::Location argLoc) {
263+
return insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
262264
}
263265

264266
/// Inserts arguments with the listed types, attributes, and locations at the
265267
/// listed indices. `argIndices` must be sorted. Arguments are inserted in the
266268
/// order they are listed, such that arguments with identical index will
267-
/// appear in the same order that they were listed here.
268-
void insertArguments(::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
269-
::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs,
270-
::llvm::ArrayRef<::mlir::Location> argLocs) {
269+
/// appear in the same order that they were listed here. Returns failure if
270+
/// the function cannot be updated to have the new signature.
271+
::llvm::LogicalResult insertArguments(
272+
::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
273+
::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs,
274+
::llvm::ArrayRef<::mlir::Location> argLocs) {
271275
unsigned originalNumArgs = $_op.getNumArguments();
272276
::mlir::Type newType = $_op.getTypeWithArgsAndResults(
273277
argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
278+
if (!newType)
279+
return ::llvm::failure();
274280
::mlir::function_interface_impl::insertFunctionArguments(
275281
$_op, argIndices, argTypes, argAttrs, argLocs,
276282
originalNumArgs, newType);
283+
return ::llvm::success();
277284
}
278285

279-
/// Insert a single result of type `resultType` at `resultIndex`.
280-
void insertResult(unsigned resultIndex, ::mlir::Type resultType,
281-
::mlir::DictionaryAttr resultAttrs) {
282-
insertResults({resultIndex}, {resultType}, {resultAttrs});
286+
/// Insert a single result of type `resultType` at `resultIndex`.Returns
287+
/// failure if the function cannot be updated to have the new signature.
288+
::llvm::LogicalResult insertResult(
289+
unsigned resultIndex, ::mlir::Type resultType,
290+
::mlir::DictionaryAttr resultAttrs) {
291+
return insertResults({resultIndex}, {resultType}, {resultAttrs});
283292
}
284293

285294
/// Inserts results with the listed types at the listed indices.
286295
/// `resultIndices` must be sorted. Results are inserted in the order they are
287296
/// listed, such that results with identical index will appear in the same
288-
/// order that they were listed here.
289-
void insertResults(::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes,
290-
::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) {
297+
/// order that they were listed here. Returns failure if the function
298+
/// cannot be updated to have the new signature.
299+
::llvm::LogicalResult insertResults(
300+
::llvm::ArrayRef<unsigned> resultIndices,
301+
::mlir::TypeRange resultTypes,
302+
::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) {
291303
unsigned originalNumResults = $_op.getNumResults();
292304
::mlir::Type newType = $_op.getTypeWithArgsAndResults(
293305
/*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
306+
if (!newType)
307+
return ::llvm::failure();
294308
::mlir::function_interface_impl::insertFunctionResults(
295309
$_op, resultIndices, resultTypes, resultAttrs,
296310
originalNumResults, newType);
311+
return ::llvm::success();
297312
}
298313

299-
/// Erase a single argument at `argIndex`.
300-
void eraseArgument(unsigned argIndex) {
314+
/// Erase a single argument at `argIndex`. Returns failure if the function
315+
/// cannot be updated to have the new signature.
316+
::llvm::LogicalResult eraseArgument(unsigned argIndex) {
301317
::llvm::BitVector argsToErase($_op.getNumArguments());
302318
argsToErase.set(argIndex);
303-
eraseArguments(argsToErase);
319+
return eraseArguments(argsToErase);
304320
}
305321

306-
/// Erases the arguments listed in `argIndices`.
307-
void eraseArguments(const ::llvm::BitVector &argIndices) {
322+
/// Erases the arguments listed in `argIndices`. Returns failure if the
323+
/// function cannot be updated to have the new signature.
324+
::llvm::LogicalResult eraseArguments(const ::llvm::BitVector &argIndices) {
308325
::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices);
326+
if (!newType)
327+
return ::llvm::failure();
309328
::mlir::function_interface_impl::eraseFunctionArguments(
310329
$_op, argIndices, newType);
330+
return ::llvm::success();
311331
}
312332

313-
/// Erase a single result at `resultIndex`.
314-
void eraseResult(unsigned resultIndex) {
333+
/// Erase a single result at `resultIndex`. Returns failure if the function
334+
/// cannot be updated to have the new signature.
335+
LogicalResult eraseResult(unsigned resultIndex) {
315336
::llvm::BitVector resultsToErase($_op.getNumResults());
316337
resultsToErase.set(resultIndex);
317-
eraseResults(resultsToErase);
338+
return eraseResults(resultsToErase);
318339
}
319340

320-
/// Erases the results listed in `resultIndices`.
321-
void eraseResults(const ::llvm::BitVector &resultIndices) {
341+
/// Erases the results listed in `resultIndices`. Returns failure if the
342+
/// function cannot be updated to have the new signature.
343+
::llvm::LogicalResult eraseResults(const ::llvm::BitVector &resultIndices) {
322344
::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices);
345+
if (!newType)
346+
return ::llvm::failure();
323347
::mlir::function_interface_impl::eraseFunctionResults(
324348
$_op, resultIndices, newType);
349+
return ::llvm::success();
325350
}
326351

327352
/// Return the type of this function with the specified arguments and
328353
/// results inserted. This is used to update the function's signature in
329354
/// the `insertArguments` and `insertResults` methods. The arrays must be
330-
/// sorted by increasing index.
355+
/// sorted by increasing index. Return nullptr if the updated type would
356+
/// not be valid.
331357
::mlir::Type getTypeWithArgsAndResults(
332358
::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
333359
::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes) {
@@ -341,7 +367,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
341367

342368
/// Return the type of this function without the specified arguments and
343369
/// results. This is used to update the function's signature in the
344-
/// `eraseArguments` and `eraseResults` methods.
370+
/// `eraseArguments` and `eraseResults` methods. Return nullptr if the
371+
/// updated type would not be valid.
345372
::mlir::Type getTypeWithoutArgsAndResults(
346373
const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) {
347374
::llvm::SmallVector<::mlir::Type> argStorage, resultStorage;

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
125125
// Perform signature modification
126126
rewriter.modifyOpInPlace(
127127
gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
128-
static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
129-
argIndices, argTypes, argAttrs, argLocs);
128+
LogicalResult inserted =
129+
static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
130+
argIndices, argTypes, argAttrs, argLocs);
131+
(void)inserted;
132+
assert(succeeded(inserted) &&
133+
"expected GPU funcs to support inserting any argument");
130134
});
131135
} else {
132136
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ updateFuncOp(func::FuncOp func,
9292
}
9393

9494
// Erase the results.
95-
func.eraseResults(erasedResultIndices);
95+
if (failed(func.eraseResults(erasedResultIndices)))
96+
return failure();
9697

9798
// Add the new arguments to the entry block if the function is not external.
9899
if (func.isExternal())

mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
113113
}
114114

115115
// Update function.
116-
funcOp.eraseResults(erasedResultIndices);
116+
if (failed(funcOp.eraseResults(erasedResultIndices)))
117+
return failure();
117118
returnOp.getOperandsMutable().assign(newReturnValues);
118119

119120
// Update function calls.

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,10 @@ LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
232232

233233
LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234234
TypeRange results) const {
235-
assert(results.size() == 1 && "expected a single result type");
235+
if (results.size() != 1 || !isValidResultType(results[0]))
236+
return {};
237+
if (!llvm::all_of(inputs, isValidArgumentType))
238+
return {};
236239
return get(results[0], llvm::to_vector(inputs), isVarArg());
237240
}
238241

mlir/lib/Query/Query.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
8888
// Remove unused function arguments
8989
size_t currentIndex = 0;
9090
while (currentIndex < funcOp.getNumArguments()) {
91+
// Erase if possible.
9192
if (funcOp.getArgument(currentIndex).use_empty())
92-
funcOp.eraseArgument(currentIndex);
93-
else
94-
++currentIndex;
93+
if (succeeded(funcOp.eraseArgument(currentIndex)))
94+
continue;
95+
++currentIndex;
9596
}
9697

9798
return funcOp;

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,11 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
698698

699699
// 3. Functions
700700
for (auto &f : list.functions) {
701-
f.funcOp.eraseArguments(f.nonLiveArgs);
702-
f.funcOp.eraseResults(f.nonLiveRets);
701+
// Some functions may not allow erasing arguments or results. These calls
702+
// return failure in such cases without modifying the function, so it's okay
703+
// to proceed.
704+
(void)f.funcOp.eraseArguments(f.nonLiveArgs);
705+
(void)f.funcOp.eraseResults(f.nonLiveRets);
703706
}
704707

705708
// 4. Operands

mlir/test/IR/test-func-erase-result.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-func-erase-result -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-func-erase-result -split-input-file -verify-diagnostics | FileCheck %s
22

33
// CHECK: func private @f(){{$}}
44
// CHECK-NOT: attributes{{.*}}result
@@ -66,3 +66,8 @@ func.func private @f() -> (
6666
f32 {test.erase_this_result},
6767
tensor<3xf32>
6868
)
69+
70+
// -----
71+
72+
// expected-error @below {{failed to erase results}}
73+
llvm.func @llvm_func(!llvm.ptr, i64)

mlir/test/lib/IR/TestFunc.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ struct TestFuncInsertArg
4545
: unknownLoc);
4646
}
4747
func->removeAttr("test.insert_args");
48-
func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
49-
locsToInsert);
48+
if (succeeded(func.insertArguments(indicesToInsert, typesToInsert,
49+
attrsToInsert, locsToInsert)))
50+
continue;
51+
52+
emitError(func->getLoc()) << "failed to insert arguments";
53+
return signalPassFailure();
5054
}
5155
}
5256
};
@@ -79,7 +83,12 @@ struct TestFuncInsertResult
7983
: DictionaryAttr::get(&getContext()));
8084
}
8185
func->removeAttr("test.insert_results");
82-
func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
86+
if (succeeded(func.insertResults(indicesToInsert, typesToInsert,
87+
attrsToInsert)))
88+
continue;
89+
90+
emitError(func->getLoc()) << "failed to insert results";
91+
return signalPassFailure();
8392
}
8493
}
8594
};
@@ -100,7 +109,10 @@ struct TestFuncEraseArg
100109
for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101110
if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102111
indicesToErase.set(argIndex);
103-
func.eraseArguments(indicesToErase);
112+
if (succeeded(func.eraseArguments(indicesToErase)))
113+
continue;
114+
emitError(func->getLoc()) << "failed to erase arguments";
115+
return signalPassFailure();
104116
}
105117
}
106118
};
@@ -122,7 +134,10 @@ struct TestFuncEraseResult
122134
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123135
if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124136
indicesToErase.set(resultIndex);
125-
func.eraseResults(indicesToErase);
137+
if (succeeded(func.eraseResults(indicesToErase)))
138+
continue;
139+
emitError(func->getLoc()) << "failed to erase results";
140+
return signalPassFailure();
126141
}
127142
}
128143
};

0 commit comments

Comments
 (0)