Skip to content

Commit 4529797

Browse files
committed
Add a generic "convert-to-llvm" pass delegating to an interface
The multiple -convert-XXX-to-llvm passes are really nice testing tools for individual dialects, but the expectation is that a proper conversion should assemble the conversion patterns using `populateXXXToLLVMConversionPatterns() APIs. However most customers just chain the conversion passes by convenience. This pass makes it composable more transparently to assemble the required patterns for conversion to LLVM dialect by using an interface. The Pass will scan the input and collect all the dialect present, and for those who implement the `ConvertToLLVMPatternInterface` it will use it to populate the conversion pattern, and possible the conversion target. Since these conversions can involve intermediate dialects, or target other dialects than LLVM (for example AVX or NVVM), this pass can't statically declare the required `getDependentDialects()` before the pass pipeline begins. This is worked around by using an extension in the dialectRegistry that will be invoked for every new loaded dialects in the context. This allows to lookup the interface ahead of time and use it to query the dependent dialects. Differential Revision: https://reviews.llvm.org/D157183
1 parent 370a6f0 commit 4529797

File tree

15 files changed

+300
-7
lines changed

15 files changed

+300
-7
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- ToLLVMInterface.h - Conversion to LLVM iface ---*- C++ -*-=============//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
10+
#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
11+
12+
#include "mlir/IR/DialectInterface.h"
13+
#include "mlir/IR/MLIRContext.h"
14+
#include "mlir/Support/LogicalResult.h"
15+
16+
namespace mlir {
17+
class ConversionTarget;
18+
class LLVMTypeConverter;
19+
class MLIRContext;
20+
class Operation;
21+
class RewritePatternSet;
22+
23+
/// Base class for dialect interfaces providing translation to LLVM IR.
24+
/// Dialects that can be translated should provide an implementation of this
25+
/// interface for the supported operations. The interface may be implemented in
26+
/// a separate library to avoid the "main" dialect library depending on LLVM IR.
27+
/// The interface can be attached using the delayed registration mechanism
28+
/// available in DialectRegistry.
29+
class ConvertToLLVMPatternInterface
30+
: public DialectInterface::Base<ConvertToLLVMPatternInterface> {
31+
public:
32+
ConvertToLLVMPatternInterface(Dialect *dialect) : Base(dialect) {}
33+
34+
/// Hook for derived dialect interface to load the dialects they
35+
/// target. The LLVMDialect is implicitly already loaded, but this
36+
/// method allows to load other intermediate dialects used in the
37+
/// conversion, or target dialects like NVVM for example.
38+
virtual void loadDependentDialects(MLIRContext *context) const {}
39+
40+
/// Hook for derived dialect interface to provide conversion patterns
41+
/// and mark dialect legal for the conversion target.
42+
virtual void populateConvertToLLVMConversionPatterns(
43+
ConversionTarget &target, RewritePatternSet &patterns) const = 0;
44+
};
45+
46+
/// Recursively walk the IR and collect all dialects implementing the interface,
47+
/// and populate the conversion patterns.
48+
void populateConversionTargetFromOperation(Operation *op,
49+
ConversionTarget &target,
50+
RewritePatternSet &patterns);
51+
52+
} // namespace mlir
53+
54+
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- ToLLVMPass.h - Conversion to LLVM pass ---*- C++ -*-===================//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
10+
#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
/// Create a pass that performs dialect conversion to LLVM for all dialects
18+
/// implementing `ConvertToLLVMPatternInterface`.
19+
std::unique_ptr<Pass> createConvertToLLVMPass();
20+
21+
} // namespace mlir
22+
23+
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H

mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include <memory>
1212

1313
namespace mlir {
14-
14+
class DialectRegistry;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
1717
class Pass;
@@ -21,6 +21,8 @@ class Pass;
2121

2222
void populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns);
2323

24+
void registerConvertNVVMToLLVMInterface(DialectRegistry &registry);
25+
2426
} // namespace mlir
2527

2628
#endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
2525
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
2626
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
27+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
2728
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
2829
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
2930
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,20 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14+
15+
//===----------------------------------------------------------------------===//
16+
// ToLLVM
17+
//===----------------------------------------------------------------------===//
18+
19+
def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
20+
let summary = "Convert to LLVM via dialect interfaces found in the input IR";
21+
let description = [{
22+
This is a generic pass to convert to LLVM, it uses the
23+
`ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
24+
the injection of conversion patterns.
25+
}];
26+
}
27+
1428
//===----------------------------------------------------------------------===//
1529
// AffineToStandard
1630
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class DialectExtensionBase {
4444
virtual ~DialectExtensionBase();
4545

4646
/// Return the dialects that our required by this extension to be loaded
47-
/// before applying.
47+
/// before applying. If empty then the extension is invoked for every loaded
48+
/// dialect indepently.
4849
ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
4950

5051
/// Apply this extension to the given context and the required dialects.
@@ -55,12 +56,11 @@ class DialectExtensionBase {
5556
virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
5657

5758
protected:
58-
/// Initialize the extension with a set of required dialects. Note that there
59-
/// should always be at least one affected dialect.
59+
/// Initialize the extension with a set of required dialects.
60+
/// If the list is empty, the extension is invoked for every loaded dialect
61+
/// independently.
6062
DialectExtensionBase(ArrayRef<StringRef> dialectNames)
61-
: dialectNames(dialectNames.begin(), dialectNames.end()) {
62-
assert(!dialectNames.empty() && "expected at least one affected dialect");
63-
}
63+
: dialectNames(dialectNames.begin(), dialectNames.end()) {}
6464

6565
private:
6666
/// The names of the dialects affected by this extension.

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_INITALLEXTENSIONS_H_
1515
#define MLIR_INITALLEXTENSIONS_H_
1616

17+
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
1718
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
1819

1920
#include <cstdlib>
@@ -27,6 +28,7 @@ namespace mlir {
2728
/// pipelines and transformations you are using.
2829
inline void registerAllExtensions(DialectRegistry &registry) {
2930
func::registerAllExtensions(registry);
31+
registerConvertNVVMToLLVMInterface(registry);
3032
}
3133

3234
} // namespace mlir

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(ComplexToSPIRV)
1313
add_subdirectory(ComplexToStandard)
1414
add_subdirectory(ControlFlowToLLVM)
1515
add_subdirectory(ControlFlowToSPIRV)
16+
add_subdirectory(ConvertToLLVM)
1617
add_subdirectory(FuncToLLVM)
1718
add_subdirectory(FuncToSPIRV)
1819
add_subdirectory(GPUCommon)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
set(LLVM_OPTIONAL_SOURCES
2+
ConvertToLLVMPass.cpp
3+
ToLLVMInterface.cpp
4+
)
5+
6+
add_mlir_conversion_library(MLIRConvertToLLVMInterface
7+
ToLLVMInterface.cpp
8+
9+
DEPENDS
10+
11+
LINK_LIBS PUBLIC
12+
MLIRIR
13+
MLIRSupport
14+
)
15+
16+
add_mlir_conversion_library(MLIRConvertToLLVMPass
17+
ConvertToLLVMPass.cpp
18+
19+
DEPENDS
20+
MLIRConversionPassIncGen
21+
22+
LINK_LIBS PUBLIC
23+
MLIRConvertToLLVMInterface
24+
MLIRPass
25+
MLIRIR
26+
MLIRSupport
27+
)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
10+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
11+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
16+
#include "mlir/Transforms/DialectConversion.h"
17+
#include <memory>
18+
19+
#define DEBUG_TYPE "convert-to-llvm"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_CONVERTTOLLVMPASS
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
30+
/// This DialectExtension can be attached to the context, which will invoke the
31+
/// `apply()` method for every loaded dialect. If a dialect implements the
32+
/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
33+
/// through the interface. This extension is loaded in the context before
34+
/// starting a pass pipeline that involves dialect conversion to LLVM.
35+
class LoadDependentDialectExtension : public DialectExtensionBase {
36+
public:
37+
LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
38+
39+
void apply(MLIRContext *context,
40+
MutableArrayRef<Dialect *> dialects) const final {
41+
LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
42+
for (Dialect *dialect : dialects) {
43+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
44+
if (!iface)
45+
continue;
46+
LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
47+
<< dialect->getNamespace() << "\n");
48+
iface->loadDependentDialects(context);
49+
}
50+
}
51+
52+
/// Return a copy of this extension.
53+
virtual std::unique_ptr<DialectExtensionBase> clone() const final {
54+
return std::make_unique<LoadDependentDialectExtension>(*this);
55+
}
56+
};
57+
58+
/// This is a generic pass to convert to LLVM, it uses the
59+
/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
60+
/// the injection of conversion patterns.
61+
class ConvertToLLVMPass
62+
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
63+
std::shared_ptr<const FrozenRewritePatternSet> patterns;
64+
std::shared_ptr<const ConversionTarget> target;
65+
66+
public:
67+
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
68+
void getDependentDialects(DialectRegistry &registry) const final {
69+
registry.insert<LLVM::LLVMDialect>();
70+
registry.addExtensions<LoadDependentDialectExtension>();
71+
}
72+
73+
ConvertToLLVMPass(const ConvertToLLVMPass &other)
74+
: ConvertToLLVMPassBase(other), patterns(other.patterns),
75+
target(other.target) {}
76+
77+
LogicalResult initialize(MLIRContext *context) final {
78+
RewritePatternSet tempPatterns(context);
79+
auto target = std::make_shared<ConversionTarget>(*context);
80+
target->addLegalDialect<LLVM::LLVMDialect>();
81+
for (Dialect *dialect : context->getLoadedDialects()) {
82+
// First time we encounter this dialect: if it implements the interface,
83+
// let's populate patterns !
84+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
85+
if (!iface)
86+
continue;
87+
iface->populateConvertToLLVMConversionPatterns(*target, tempPatterns);
88+
}
89+
patterns =
90+
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
91+
this->target = target;
92+
return success();
93+
}
94+
95+
void runOnOperation() final {
96+
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
97+
signalPassFailure();
98+
}
99+
};
100+
101+
} // namespace
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- ToLLVMInterface.cpp - MLIR LLVM Conversion -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
10+
#include "mlir/IR/Dialect.h"
11+
#include "mlir/IR/Operation.h"
12+
#include "llvm/ADT/DenseSet.h"
13+
14+
using namespace mlir;
15+
16+
void mlir::populateConversionTargetFromOperation(Operation *root,
17+
ConversionTarget &target,
18+
RewritePatternSet &patterns) {
19+
DenseSet<Dialect *> dialects;
20+
root->walk([&](Operation *op) {
21+
Dialect *dialect = op->getDialect();
22+
if (!dialects.insert(dialect).second)
23+
return;
24+
// First time we encounter this dialect: if it implements the interface,
25+
// let's populate patterns !
26+
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
27+
if (!iface)
28+
return;
29+
iface->populateConvertToLLVMConversionPatterns(target, patterns);
30+
});
31+
}

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
1515

16+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1617
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1718
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -190,8 +191,29 @@ struct ConvertNVVMToLLVMPass
190191
}
191192
};
192193

194+
/// Implement the interface to convert NNVM to LLVM.
195+
struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
196+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
197+
void loadDependentDialects(MLIRContext *context) const final {
198+
context->loadDialect<NVVMDialect>();
199+
}
200+
201+
/// Hook for derived dialect interface to provide conversion patterns
202+
/// and mark dialect legal for the conversion target.
203+
void populateConvertToLLVMConversionPatterns(
204+
ConversionTarget &target, RewritePatternSet &patterns) const final {
205+
populateNVVMToLLVMConversionPatterns(patterns);
206+
}
207+
};
208+
193209
} // namespace
194210

195211
void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
196212
patterns.add<PtxLowering>(patterns.getContext());
197213
}
214+
215+
void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry &registry) {
216+
registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
217+
dialect->addInterfaces<NVVMToLLVMDialectInterface>();
218+
});
219+
}

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1818

19+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1920
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/IR/BuiltinAttributes.h"
@@ -721,6 +722,7 @@ void NVVMDialect::initialize() {
721722
// Support unknown operations because not all NVVM operations are
722723
// registered.
723724
allowUnknownOperations();
725+
declarePromisedInterface<ConvertToLLVMPatternInterface>();
724726
}
725727

726728
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,

0 commit comments

Comments
 (0)