Skip to content

Commit a2590e0

Browse files
[mlir][Transforms] Make 1:N function conversion pattern interface-based (#92395)
This commit turns the 1:N dialect conversion pattern for function signatures into a pattern for `FunctionOpInterface`. This is similar to the interface-based pattern that is provided with the 1:1 dialect conversion (`populateFunctionOpInterfaceTypeConversionPattern`). No change in functionality apart from supporting all `FunctionOpInterface` ops and not just `func::FuncOp`.
1 parent d1cff36 commit a2590e0

File tree

3 files changed

+75
-45
lines changed

3 files changed

+75
-45
lines changed

mlir/include/mlir/Transforms/OneToNTypeConversion.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,20 @@ LogicalResult
297297
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
298298
const FrozenRewritePatternSet &patterns);
299299

300+
/// Add a pattern to the given pattern list to convert the signature of a
301+
/// FunctionOpInterface op with the given type converter. This only supports
302+
/// ops which use FunctionType to represent their type. This is intended to be
303+
/// used with the 1:N dialect conversion.
304+
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
305+
StringRef functionLikeOpName, TypeConverter &converter,
306+
RewritePatternSet &patterns);
307+
template <typename FuncOpT>
308+
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
309+
TypeConverter &converter, RewritePatternSet &patterns) {
310+
populateOneToNFunctionOpInterfaceTypeConversionPattern(
311+
FuncOpT::getOperationName(), converter, patterns);
312+
}
313+
300314
} // namespace mlir
301315

302316
#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H

mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -49,50 +49,6 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
4949
}
5050
};
5151

52-
class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
53-
public:
54-
using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
55-
56-
LogicalResult
57-
matchAndRewrite(FuncOp op, OpAdaptor adaptor,
58-
OneToNPatternRewriter &rewriter) const override {
59-
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
60-
61-
// Construct mapping for function arguments.
62-
OneToNTypeMapping argumentMapping(op.getArgumentTypes());
63-
if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
64-
argumentMapping)))
65-
return failure();
66-
67-
// Construct mapping for function results.
68-
OneToNTypeMapping funcResultMapping(op.getResultTypes());
69-
if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
70-
funcResultMapping)))
71-
return failure();
72-
73-
// Nothing to do if the op doesn't have any non-identity conversions for its
74-
// operands or results.
75-
if (!argumentMapping.hasNonIdentityConversion() &&
76-
!funcResultMapping.hasNonIdentityConversion())
77-
return failure();
78-
79-
// Update the function signature in-place.
80-
auto newType = FunctionType::get(rewriter.getContext(),
81-
argumentMapping.getConvertedTypes(),
82-
funcResultMapping.getConvertedTypes());
83-
rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
84-
85-
// Update block signatures.
86-
if (!op.isExternal()) {
87-
Region *region = &op.getBody();
88-
Block *block = &region->front();
89-
rewriter.applySignatureConversion(block, argumentMapping);
90-
}
91-
92-
return success();
93-
}
94-
};
95-
9652
class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
9753
public:
9854
using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
@@ -121,10 +77,11 @@ void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
12177
patterns.add<
12278
// clang-format off
12379
ConvertTypesInFuncCallOp,
124-
ConvertTypesInFuncFuncOp,
12580
ConvertTypesInFuncReturnOp
12681
// clang-format on
12782
>(typeConverter, patterns.getContext());
83+
populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
84+
typeConverter, patterns);
12885
}
12986

13087
} // namespace mlir

mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Transforms/OneToNTypeConversion.h"
1010

11+
#include "mlir/Interfaces/FunctionInterfaces.h"
1112
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1213
#include "llvm/ADT/SmallSet.h"
1314

@@ -412,4 +413,62 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
412413
return success();
413414
}
414415

416+
namespace {
417+
class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
418+
public:
419+
FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
420+
MLIRContext *ctx,
421+
TypeConverter &converter)
422+
: OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1,
423+
ctx) {}
424+
425+
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
426+
const OneToNTypeMapping &operandMapping,
427+
const OneToNTypeMapping &resultMapping,
428+
ValueRange convertedOperands) const override {
429+
auto funcOp = cast<FunctionOpInterface>(op);
430+
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
431+
432+
// Construct mapping for function arguments.
433+
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
434+
if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
435+
argumentMapping)))
436+
return failure();
437+
438+
// Construct mapping for function results.
439+
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
440+
if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
441+
funcResultMapping)))
442+
return failure();
443+
444+
// Nothing to do if the op doesn't have any non-identity conversions for its
445+
// operands or results.
446+
if (!argumentMapping.hasNonIdentityConversion() &&
447+
!funcResultMapping.hasNonIdentityConversion())
448+
return failure();
449+
450+
// Update the function signature in-place.
451+
auto newType = FunctionType::get(rewriter.getContext(),
452+
argumentMapping.getConvertedTypes(),
453+
funcResultMapping.getConvertedTypes());
454+
rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); });
455+
456+
// Update block signatures.
457+
if (!funcOp.isExternal()) {
458+
Region *region = &funcOp.getFunctionBody();
459+
Block *block = &region->front();
460+
rewriter.applySignatureConversion(block, argumentMapping);
461+
}
462+
463+
return success();
464+
}
465+
};
466+
} // namespace
467+
468+
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
469+
StringRef functionLikeOpName, TypeConverter &converter,
470+
RewritePatternSet &patterns) {
471+
patterns.add<FunctionOpInterfaceSignatureConversion>(
472+
functionLikeOpName, patterns.getContext(), converter);
473+
}
415474
} // namespace mlir

0 commit comments

Comments
 (0)