Skip to content

Commit 9e7b6f4

Browse files
authored
[mlir] Adopt ConvertToLLVMPatternInterface GpuToLLVMConversionPass to align with convert-to-llvm (#73761)
This is a follow-up to the introduction of `convert-to-llvm`: it is supposed to be a unifying pass through the `ConvertToLLVMPatternInterface`, but some specific conversion (like the GPU target) aren't vanilla LLVM target. Instead they need extra customizations that are specific to LLVM-on-GPUs and our custom runtime wrappers. This change make the GpuToLLVMConversionPass just as pluggable as the `convert-to-llvm` by using the same mechanism.
1 parent 14028ec commit 9e7b6f4

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ namespace mlir {
2222
/// implementing `ConvertToLLVMPatternInterface`.
2323
std::unique_ptr<Pass> createConvertToLLVMPass();
2424

25+
/// Register the extension that will load dependent dialects for LLVM
26+
/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
27+
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry);
28+
2529
} // namespace mlir
2630

2731
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ class ConvertToLLVMPass
124124

125125
} // namespace
126126

127+
void mlir::registerConvertToLLVMDependentDialectLoading(
128+
DialectRegistry &registry) {
129+
registry.addExtensions<LoadDependentDialectExtension>();
130+
}
131+
127132
std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
128133
return std::make_unique<ConvertToLLVMPass>();
129134
}

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1919
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
2020
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
21+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2123
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
2224
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
2325
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -38,6 +40,8 @@
3840
#include "llvm/Support/Error.h"
3941
#include "llvm/Support/FormatVariadic.h"
4042

43+
#define DEBUG_TYPE "gpu-to-llvm"
44+
4145
namespace mlir {
4246
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
4347
#include "mlir/Conversion/Passes.h.inc"
@@ -48,12 +52,14 @@ using namespace mlir;
4852
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
4953

5054
namespace {
51-
5255
class GpuToLLVMConversionPass
5356
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
5457
public:
5558
using Base::Base;
56-
59+
void getDependentDialects(DialectRegistry &registry) const final {
60+
Base::getDependentDialects(registry);
61+
registerConvertToLLVMDependentDialectLoading(registry);
62+
}
5763
// Run the dialect converter on the module.
5864
void runOnOperation() override;
5965
};
@@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
580586
} // namespace
581587

582588
void GpuToLLVMConversionPass::runOnOperation() {
583-
LowerToLLVMOptions options(&getContext());
589+
MLIRContext *context = &getContext();
590+
SymbolTable symbolTable = SymbolTable(getOperation());
591+
LowerToLLVMOptions options(context);
584592
options.useBarePtrCallConv = hostBarePtrCallConv;
593+
RewritePatternSet patterns(context);
594+
ConversionTarget target(*context);
595+
target.addLegalDialect<LLVM::LLVMDialect>();
596+
LLVMTypeConverter converter(context, options);
597+
598+
// Populate all patterns from all dialects that implement the
599+
// `ConvertToLLVMPatternInterface` interface.
600+
for (Dialect *dialect : context->getLoadedDialects()) {
601+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
602+
if (!iface)
603+
continue;
604+
iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
605+
}
585606

586-
LLVMTypeConverter converter(&getContext(), options);
587-
RewritePatternSet patterns(&getContext());
588-
LLVMConversionTarget target(getContext());
589-
590-
SymbolTable symbolTable = SymbolTable(getOperation());
591607
// Preserve GPU modules if they have target attributes.
592608
target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
593609
[](gpu::GPUModuleOp module) -> bool {
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
605621
!module.getTargetsAttr().empty());
606622
});
607623

608-
mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
609-
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
624+
// These aren't covered by the ConvertToLLVMPatternInterface right now.
610625
populateVectorToLLVMConversionPatterns(converter, patterns);
611626
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
612-
populateFuncToLLVMConversionPatterns(converter, patterns);
613627
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
614628
target);
615629
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,

0 commit comments

Comments
 (0)