Skip to content

[mlir] Adopt ConvertToLLVMPatternInterface GpuToLLVMConversionPass to align with convert-to-llvm #73761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ namespace mlir {
/// implementing `ConvertToLLVMPatternInterface`.
std::unique_ptr<Pass> createConvertToLLVMPass();

/// Register the extension that will load dependent dialects for LLVM
/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry);

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
5 changes: 5 additions & 0 deletions mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class ConvertToLLVMPass

} // namespace

void mlir::registerConvertToLLVMDependentDialectLoading(
DialectRegistry &registry) {
registry.addExtensions<LoadDependentDialectExtension>();
}

std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
return std::make_unique<ConvertToLLVMPass>();
}
36 changes: 25 additions & 11 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
Expand All @@ -38,6 +40,8 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
Expand All @@ -48,12 +52,14 @@ using namespace mlir;
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const final {
Base::getDependentDialects(registry);
registerConvertToLLVMDependentDialectLoading(registry);
}
// Run the dialect converter on the module.
void runOnOperation() override;
};
Expand Down Expand Up @@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
MLIRContext *context = &getContext();
SymbolTable symbolTable = SymbolTable(getOperation());
LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
LLVMTypeConverter converter(context, options);

// Populate all patterns from all dialects that implement the
// `ConvertToLLVMPatternInterface` interface.
for (Dialect *dialect : context->getLoadedDialects()) {
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
}

LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
LLVMConversionTarget target(getContext());

SymbolTable symbolTable = SymbolTable(getOperation());
// Preserve GPU modules if they have target attributes.
target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
[](gpu::GPUModuleOp module) -> bool {
Expand All @@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
!module.getTargetsAttr().empty());
});

mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// These aren't covered by the ConvertToLLVMPatternInterface right now.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity - why these passes are deleted and not covered?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't follow what you mean by deleted?

The populateXXXXToLLVMConversionPatterns that I didn't remove are the one for which we haven't implemented the ConvertToLLVMPatternInterface.
The one I removed (populateControlFlowToLLVMConversionPatterns and populateFuncToLLVMConversionPatterns) will be handled line 604 above (and only when these dialects are loaded: we only pay what we may need :) )

populateVectorToLLVMConversionPatterns(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
Expand Down