18
18
#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
19
19
#include " mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
20
20
#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
21
+ #include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22
+ #include " mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
21
23
#include " mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
22
24
#include " mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
23
25
#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
38
40
#include " llvm/Support/Error.h"
39
41
#include " llvm/Support/FormatVariadic.h"
40
42
43
+ #define DEBUG_TYPE " gpu-to-llvm"
44
+
41
45
namespace mlir {
42
46
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
43
47
#include " mlir/Conversion/Passes.h.inc"
@@ -48,12 +52,14 @@ using namespace mlir;
48
52
static constexpr const char *kGpuBinaryStorageSuffix = " _gpubin_cst" ;
49
53
50
54
namespace {
51
-
52
55
class GpuToLLVMConversionPass
53
56
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
54
57
public:
55
58
using Base::Base;
56
-
59
+ void getDependentDialects (DialectRegistry ®istry) const final {
60
+ Base::getDependentDialects (registry);
61
+ registerConvertToLLVMDependentDialectLoading (registry);
62
+ }
57
63
// Run the dialect converter on the module.
58
64
void runOnOperation () override ;
59
65
};
@@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
580
586
} // namespace
581
587
582
588
void GpuToLLVMConversionPass::runOnOperation () {
583
- LowerToLLVMOptions options (&getContext ());
589
+ MLIRContext *context = &getContext ();
590
+ SymbolTable symbolTable = SymbolTable (getOperation ());
591
+ LowerToLLVMOptions options (context);
584
592
options.useBarePtrCallConv = hostBarePtrCallConv;
593
+ RewritePatternSet patterns (context);
594
+ ConversionTarget target (*context);
595
+ target.addLegalDialect <LLVM::LLVMDialect>();
596
+ LLVMTypeConverter converter (context, options);
597
+
598
+ // Populate all patterns from all dialects that implement the
599
+ // `ConvertToLLVMPatternInterface` interface.
600
+ for (Dialect *dialect : context->getLoadedDialects ()) {
601
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
602
+ if (!iface)
603
+ continue ;
604
+ iface->populateConvertToLLVMConversionPatterns (target, converter, patterns);
605
+ }
585
606
586
- LLVMTypeConverter converter (&getContext (), options);
587
- RewritePatternSet patterns (&getContext ());
588
- LLVMConversionTarget target (getContext ());
589
-
590
- SymbolTable symbolTable = SymbolTable (getOperation ());
591
607
// Preserve GPU modules if they have target attributes.
592
608
target.addDynamicallyLegalOp <gpu::GPUModuleOp>(
593
609
[](gpu::GPUModuleOp module ) -> bool {
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
605
621
!module .getTargetsAttr ().empty ());
606
622
});
607
623
608
- mlir::arith::populateArithToLLVMConversionPatterns (converter, patterns);
609
- mlir::cf::populateControlFlowToLLVMConversionPatterns (converter, patterns);
624
+ // These aren't covered by the ConvertToLLVMPatternInterface right now.
610
625
populateVectorToLLVMConversionPatterns (converter, patterns);
611
626
populateFinalizeMemRefToLLVMConversionPatterns (converter, patterns);
612
- populateFuncToLLVMConversionPatterns (converter, patterns);
613
627
populateAsyncStructuralTypeConversionsAndLegality (converter, patterns,
614
628
target);
615
629
populateGpuToLLVMConversionPatterns (converter, patterns, gpuBinaryAnnotation,
0 commit comments