Skip to content

Commit 8c67c48

Browse files
authored
[MLIR] Refactor FuncOpToLLVM to operate on FunctionOpInterface instead of FuncOp (NFC) (#68665)
* refactor `convertFuncOpToLLVMFuncOp` to accept a `FunctionOpInterface` instead of func::FuncOp * move `convertFuncOpToLLVMFuncOp` to corresponding public header, making it available for downstream project.
1 parent 2d854dd commit 8c67c48

File tree

2 files changed

+165
-134
lines changed

2 files changed

+165
-134
lines changed

mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,28 @@
1414
#ifndef MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
1515
#define MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
1616

17+
#include "mlir/Interfaces/FunctionInterfaces.h"
18+
#include "mlir/Support/LogicalResult.h"
19+
1720
namespace mlir {
1821

22+
namespace LLVM {
23+
class LLVMFuncOp;
24+
} // namespace LLVM
25+
26+
class ConversionPatternRewriter;
1927
class DialectRegistry;
2028
class LLVMTypeConverter;
2129
class RewritePatternSet;
2230
class SymbolTable;
2331

32+
/// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
33+
/// provided LLVMTypeConverter. Return failure if failed to so.
34+
FailureOr<LLVM::LLVMFuncOp>
35+
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
36+
ConversionPatternRewriter &rewriter,
37+
const LLVMTypeConverter &converter);
38+
2439
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
2540
/// `emitCWrappers` is set, the pattern will also produce functions
2641
/// that pass memref descriptors by pointer-to-structure in addition to the

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 150 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ static bool shouldUseBarePtrCallConv(Operation *op,
7474

7575
/// Only retain those attributes that are not constructed by
7676
/// `LLVMFuncOp::build`.
77-
static void filterFuncAttributes(func::FuncOp func,
77+
static void filterFuncAttributes(FunctionOpInterface func,
7878
SmallVectorImpl<NamedAttribute> &result) {
7979
for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
8080
if (attr.getName() == linkageAttrName ||
@@ -87,26 +87,26 @@ static void filterFuncAttributes(func::FuncOp func,
8787

8888
/// Propagate argument/results attributes.
8989
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
90-
func::FuncOp funcOp,
90+
FunctionOpInterface funcOp,
9191
LLVM::LLVMFuncOp wrapperFuncOp) {
92-
auto argAttrs = funcOp.getArgAttrs();
92+
auto argAttrs = funcOp.getAllArgAttrs();
9393
if (!resultStructType) {
9494
if (auto resAttrs = funcOp.getAllResultAttrs())
9595
wrapperFuncOp.setAllResultAttrs(resAttrs);
9696
if (argAttrs)
97-
wrapperFuncOp.setAllArgAttrs(*argAttrs);
97+
wrapperFuncOp.setAllArgAttrs(argAttrs);
9898
} else {
9999
SmallVector<Attribute> argAttributes;
100100
// Only modify the argument and result attributes when the result is now
101101
// an argument.
102102
if (argAttrs) {
103103
argAttributes.push_back(builder.getDictionaryAttr({}));
104-
argAttributes.append(argAttrs->begin(), argAttrs->end());
104+
argAttributes.append(argAttrs.begin(), argAttrs.end());
105105
wrapperFuncOp.setAllArgAttrs(argAttributes);
106106
}
107107
}
108-
if (funcOp.getSymVisibilityAttr())
109-
wrapperFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
108+
cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
109+
.setVisibility(funcOp.getVisibility());
110110
}
111111

112112
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
@@ -119,9 +119,9 @@ static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
119119
/// the extra arguments.
120120
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
121121
const LLVMTypeConverter &typeConverter,
122-
func::FuncOp funcOp,
122+
FunctionOpInterface funcOp,
123123
LLVM::LLVMFuncOp newFuncOp) {
124-
auto type = funcOp.getFunctionType();
124+
auto type = cast<FunctionType>(funcOp.getFunctionType());
125125
auto [wrapperFuncType, resultStructType] =
126126
typeConverter.convertFunctionTypeCWrapper(type);
127127

@@ -179,12 +179,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
179179
/// corresponding to a memref descriptor.
180180
static void wrapExternalFunction(OpBuilder &builder, Location loc,
181181
const LLVMTypeConverter &typeConverter,
182-
func::FuncOp funcOp,
182+
FunctionOpInterface funcOp,
183183
LLVM::LLVMFuncOp newFuncOp) {
184184
OpBuilder::InsertionGuard guard(builder);
185185

186186
auto [wrapperType, resultStructType] =
187-
typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
187+
typeConverter.convertFunctionTypeCWrapper(
188+
cast<FunctionType>(funcOp.getFunctionType()));
188189
// This conversion can only fail if it could not convert one of the argument
189190
// types. But since it has been applied to a non-wrapper function before, it
190191
// should have failed earlier and not reach this point at all.
@@ -205,7 +206,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
205206
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
206207

207208
// Get a ValueRange containing arguments.
208-
FunctionType type = funcOp.getFunctionType();
209+
FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
209210
SmallVector<Value, 8> args;
210211
args.reserve(type.getNumInputs());
211212
ValueRange wrapperArgsRange(newFuncOp.getArguments());
@@ -317,6 +318,140 @@ static void modifyFuncOpToUseBarePtrCallingConv(
317318
}
318319
}
319320

321+
FailureOr<LLVM::LLVMFuncOp>
322+
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
323+
ConversionPatternRewriter &rewriter,
324+
const LLVMTypeConverter &converter) {
325+
// Check the funcOp has `FunctionType`.
326+
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
327+
if (!funcTy)
328+
return rewriter.notifyMatchFailure(
329+
funcOp, "Only support FunctionOpInterface with FunctionType");
330+
331+
// Convert the original function arguments. They are converted using the
332+
// LLVMTypeConverter provided to this legalization pattern.
333+
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
334+
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
335+
auto llvmType = converter.convertFunctionSignature(
336+
funcTy, varargsAttr && varargsAttr.getValue(),
337+
shouldUseBarePtrCallConv(funcOp, &converter), result);
338+
if (!llvmType)
339+
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
340+
341+
// Create an LLVM function, use external linkage by default until MLIR
342+
// functions have linkage.
343+
LLVM::Linkage linkage = LLVM::Linkage::External;
344+
if (funcOp->hasAttr(linkageAttrName)) {
345+
auto attr =
346+
dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
347+
if (!attr) {
348+
funcOp->emitError() << "Contains " << linkageAttrName
349+
<< " attribute not of type LLVM::LinkageAttr";
350+
return rewriter.notifyMatchFailure(
351+
funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
352+
}
353+
linkage = attr.getLinkage();
354+
}
355+
356+
SmallVector<NamedAttribute, 4> attributes;
357+
filterFuncAttributes(funcOp, attributes);
358+
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
359+
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
360+
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
361+
attributes);
362+
cast<FunctionOpInterface>(newFuncOp.getOperation())
363+
.setVisibility(funcOp.getVisibility());
364+
365+
// Create a memory effect attribute corresponding to readnone.
366+
StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
367+
if (funcOp->hasAttr(readnoneAttrName)) {
368+
auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
369+
if (!attr) {
370+
funcOp->emitError() << "Contains " << readnoneAttrName
371+
<< " attribute not of type UnitAttr";
372+
return rewriter.notifyMatchFailure(
373+
funcOp, "Contains readnone attribute not of type UnitAttr");
374+
}
375+
auto memoryAttr = LLVM::MemoryEffectsAttr::get(
376+
rewriter.getContext(),
377+
{LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
378+
LLVM::ModRefInfo::NoModRef});
379+
newFuncOp.setMemoryAttr(memoryAttr);
380+
}
381+
382+
// Propagate argument/result attributes to all converted arguments/result
383+
// obtained after converting a given original argument/result.
384+
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
385+
assert(!resAttrDicts.empty() && "expected array to be non-empty");
386+
if (funcOp.getNumResults() == 1)
387+
newFuncOp.setAllResultAttrs(resAttrDicts);
388+
}
389+
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
390+
SmallVector<Attribute> newArgAttrs(
391+
cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
392+
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
393+
// Some LLVM IR attribute have a type attached to them. During FuncOp ->
394+
// LLVMFuncOp conversion these types may have changed. Account for that
395+
// change by converting attributes' types as well.
396+
SmallVector<NamedAttribute, 4> convertedAttrs;
397+
auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
398+
convertedAttrs.reserve(attrsDict.size());
399+
for (const NamedAttribute &attr : attrsDict) {
400+
const auto convert = [&](const NamedAttribute &attr) {
401+
return TypeAttr::get(converter.convertType(
402+
cast<TypeAttr>(attr.getValue()).getValue()));
403+
};
404+
if (attr.getName().getValue() ==
405+
LLVM::LLVMDialect::getByValAttrName()) {
406+
convertedAttrs.push_back(rewriter.getNamedAttr(
407+
LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
408+
} else if (attr.getName().getValue() ==
409+
LLVM::LLVMDialect::getByRefAttrName()) {
410+
convertedAttrs.push_back(rewriter.getNamedAttr(
411+
LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
412+
} else if (attr.getName().getValue() ==
413+
LLVM::LLVMDialect::getStructRetAttrName()) {
414+
convertedAttrs.push_back(rewriter.getNamedAttr(
415+
LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
416+
} else if (attr.getName().getValue() ==
417+
LLVM::LLVMDialect::getInAllocaAttrName()) {
418+
convertedAttrs.push_back(rewriter.getNamedAttr(
419+
LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
420+
} else {
421+
convertedAttrs.push_back(attr);
422+
}
423+
}
424+
auto mapping = result.getInputMapping(i);
425+
assert(mapping && "unexpected deletion of function argument");
426+
// Only attach the new argument attributes if there is a one-to-one
427+
// mapping from old to new types. Otherwise, attributes might be
428+
// attached to types that they do not support.
429+
if (mapping->size == 1) {
430+
newArgAttrs[mapping->inputNo] =
431+
DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
432+
continue;
433+
}
434+
// TODO: Implement custom handling for types that expand to multiple
435+
// function arguments.
436+
for (size_t j = 0; j < mapping->size; ++j)
437+
newArgAttrs[mapping->inputNo + j] =
438+
DictionaryAttr::get(rewriter.getContext(), {});
439+
}
440+
if (!newArgAttrs.empty())
441+
newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
442+
}
443+
444+
rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
445+
newFuncOp.end());
446+
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
447+
&result))) {
448+
return rewriter.notifyMatchFailure(funcOp,
449+
"region types conversion failed");
450+
}
451+
452+
return newFuncOp;
453+
}
454+
320455
namespace {
321456

322457
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -328,128 +463,9 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
328463
FailureOr<LLVM::LLVMFuncOp>
329464
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
330465
ConversionPatternRewriter &rewriter) const {
331-
// Convert the original function arguments. They are converted using the
332-
// LLVMTypeConverter provided to this legalization pattern.
333-
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
334-
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
335-
auto llvmType = getTypeConverter()->convertFunctionSignature(
336-
funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
337-
shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
338-
if (!llvmType)
339-
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
340-
341-
// Create an LLVM function, use external linkage by default until MLIR
342-
// functions have linkage.
343-
LLVM::Linkage linkage = LLVM::Linkage::External;
344-
if (funcOp->hasAttr(linkageAttrName)) {
345-
auto attr =
346-
dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
347-
if (!attr) {
348-
funcOp->emitError() << "Contains " << linkageAttrName
349-
<< " attribute not of type LLVM::LinkageAttr";
350-
return rewriter.notifyMatchFailure(
351-
funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
352-
}
353-
linkage = attr.getLinkage();
354-
}
355-
356-
SmallVector<NamedAttribute, 4> attributes;
357-
filterFuncAttributes(funcOp, attributes);
358-
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
359-
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
360-
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
361-
attributes);
362-
if (funcOp.getSymVisibilityAttr())
363-
newFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
364-
365-
// Create a memory effect attribute corresponding to readnone.
366-
StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
367-
if (funcOp->hasAttr(readnoneAttrName)) {
368-
auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
369-
if (!attr) {
370-
funcOp->emitError() << "Contains " << readnoneAttrName
371-
<< " attribute not of type UnitAttr";
372-
return rewriter.notifyMatchFailure(
373-
funcOp, "Contains readnone attribute not of type UnitAttr");
374-
}
375-
auto memoryAttr = LLVM::MemoryEffectsAttr::get(
376-
rewriter.getContext(),
377-
{LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
378-
LLVM::ModRefInfo::NoModRef});
379-
newFuncOp.setMemoryAttr(memoryAttr);
380-
}
381-
382-
// Propagate argument/result attributes to all converted arguments/result
383-
// obtained after converting a given original argument/result.
384-
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
385-
assert(!resAttrDicts.empty() && "expected array to be non-empty");
386-
if (funcOp.getNumResults() == 1)
387-
newFuncOp.setAllResultAttrs(resAttrDicts);
388-
}
389-
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
390-
SmallVector<Attribute> newArgAttrs(
391-
cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
392-
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
393-
// Some LLVM IR attribute have a type attached to them. During FuncOp ->
394-
// LLVMFuncOp conversion these types may have changed. Account for that
395-
// change by converting attributes' types as well.
396-
SmallVector<NamedAttribute, 4> convertedAttrs;
397-
auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
398-
convertedAttrs.reserve(attrsDict.size());
399-
for (const NamedAttribute &attr : attrsDict) {
400-
const auto convert = [&](const NamedAttribute &attr) {
401-
return TypeAttr::get(getTypeConverter()->convertType(
402-
cast<TypeAttr>(attr.getValue()).getValue()));
403-
};
404-
if (attr.getName().getValue() ==
405-
LLVM::LLVMDialect::getByValAttrName()) {
406-
convertedAttrs.push_back(rewriter.getNamedAttr(
407-
LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
408-
} else if (attr.getName().getValue() ==
409-
LLVM::LLVMDialect::getByRefAttrName()) {
410-
convertedAttrs.push_back(rewriter.getNamedAttr(
411-
LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
412-
} else if (attr.getName().getValue() ==
413-
LLVM::LLVMDialect::getStructRetAttrName()) {
414-
convertedAttrs.push_back(rewriter.getNamedAttr(
415-
LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
416-
} else if (attr.getName().getValue() ==
417-
LLVM::LLVMDialect::getInAllocaAttrName()) {
418-
convertedAttrs.push_back(rewriter.getNamedAttr(
419-
LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
420-
} else {
421-
convertedAttrs.push_back(attr);
422-
}
423-
}
424-
auto mapping = result.getInputMapping(i);
425-
assert(mapping && "unexpected deletion of function argument");
426-
// Only attach the new argument attributes if there is a one-to-one
427-
// mapping from old to new types. Otherwise, attributes might be
428-
// attached to types that they do not support.
429-
if (mapping->size == 1) {
430-
newArgAttrs[mapping->inputNo] =
431-
DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
432-
continue;
433-
}
434-
// TODO: Implement custom handling for types that expand to multiple
435-
// function arguments.
436-
for (size_t j = 0; j < mapping->size; ++j)
437-
newArgAttrs[mapping->inputNo + j] =
438-
DictionaryAttr::get(rewriter.getContext(), {});
439-
}
440-
if (!newArgAttrs.empty())
441-
newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
442-
}
443-
444-
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
445-
newFuncOp.end());
446-
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
447-
&result))) {
448-
return rewriter.notifyMatchFailure(funcOp,
449-
"region types conversion failed");
450-
}
451-
452-
return newFuncOp;
466+
return mlir::convertFuncOpToLLVMFuncOp(
467+
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
468+
*getTypeConverter());
453469
}
454470
};
455471

0 commit comments

Comments
 (0)