Skip to content

Commit 876a480

Browse files
[mlir][Conversion] Add type converter parameter to ConvertToLLVMPatternInterface
Most `*-to-llvm` conversion patterns require a type converter. This revision adds a type converter to the `populateConvertToLLVMConversionPatterns` function and implements the interface for the MemRef dialect. Differential Revision: https://reviews.llvm.org/D157387
1 parent 13bb748 commit 876a480

File tree

9 files changed

+70
-8
lines changed

9 files changed

+70
-8
lines changed

mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ class ConvertToLLVMPatternInterface
4040
/// Hook for derived dialect interface to provide conversion patterns
4141
/// and mark dialect legal for the conversion target.
4242
virtual void populateConvertToLLVMConversionPatterns(
43-
ConversionTarget &target, RewritePatternSet &patterns) const = 0;
43+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
44+
RewritePatternSet &patterns) const = 0;
4445
};
4546

4647
/// Recursively walk the IR and collect all dialects implementing the interface,
4748
/// and populate the conversion patterns.
4849
void populateConversionTargetFromOperation(Operation *op,
4950
ConversionTarget &target,
51+
LLVMTypeConverter &typeConverter,
5052
RewritePatternSet &patterns);
5153

5254
} // namespace mlir

mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <memory>
1313

1414
namespace mlir {
15+
class DialectRegistry;
1516
class Pass;
1617
class LLVMTypeConverter;
1718
class RewritePatternSet;
@@ -23,6 +24,9 @@ class RewritePatternSet;
2324
/// MemRef dialect to the LLVM dialect.
2425
void populateFinalizeMemRefToLLVMConversionPatterns(
2526
LLVMTypeConverter &converter, RewritePatternSet &patterns);
27+
28+
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
29+
2630
} // namespace mlir
2731

2832
#endif // MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_INITALLEXTENSIONS_H_
1515
#define MLIR_INITALLEXTENSIONS_H_
1616

17+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1718
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
1819
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
1920
#include "mlir/Target/LLVM/NVVM/Target.h"
@@ -29,6 +30,7 @@ namespace mlir {
2930
/// pipelines and transformations you are using.
3031
inline void registerAllExtensions(DialectRegistry &registry) {
3132
func::registerAllExtensions(registry);
33+
registerConvertMemRefToLLVMInterface(registry);
3234
registerConvertNVVMToLLVMInterface(registry);
3335
registerNVVMTarget(registry);
3436
}

mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass
2222
LINK_LIBS PUBLIC
2323
MLIRConvertToLLVMInterface
2424
MLIRIR
25+
MLIRLLVMCommonConversion
2526
MLIRLLVMDialect
2627
MLIRPass
2728
MLIRRewrite

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1010
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
1111
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1213
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/IR/PatternMatch.h"
1415
#include "mlir/Pass/Pass.h"
@@ -62,6 +63,7 @@ class ConvertToLLVMPass
6263
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
6364
std::shared_ptr<const FrozenRewritePatternSet> patterns;
6465
std::shared_ptr<const ConversionTarget> target;
66+
std::shared_ptr<const LLVMTypeConverter> typeConverter;
6567

6668
public:
6769
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -72,23 +74,26 @@ class ConvertToLLVMPass
7274

7375
ConvertToLLVMPass(const ConvertToLLVMPass &other)
7476
: ConvertToLLVMPassBase(other), patterns(other.patterns),
75-
target(other.target) {}
77+
target(other.target), typeConverter(other.typeConverter) {}
7678

7779
LogicalResult initialize(MLIRContext *context) final {
7880
RewritePatternSet tempPatterns(context);
7981
auto target = std::make_shared<ConversionTarget>(*context);
8082
target->addLegalDialect<LLVM::LLVMDialect>();
83+
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
8184
for (Dialect *dialect : context->getLoadedDialects()) {
8285
// First time we encounter this dialect: if it implements the interface,
8386
// let's populate patterns !
8487
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
8588
if (!iface)
8689
continue;
87-
iface->populateConvertToLLVMConversionPatterns(*target, tempPatterns);
90+
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
91+
tempPatterns);
8892
}
8993
patterns =
9094
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
9195
this->target = target;
96+
this->typeConverter = typeConverter;
9297
return success();
9398
}
9499

mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
using namespace mlir;
1515

16-
void mlir::populateConversionTargetFromOperation(Operation *root,
17-
ConversionTarget &target,
18-
RewritePatternSet &patterns) {
16+
void mlir::populateConversionTargetFromOperation(
17+
Operation *root, ConversionTarget &target, LLVMTypeConverter &typeConverter,
18+
RewritePatternSet &patterns) {
1919
DenseSet<Dialect *> dialects;
2020
root->walk([&](Operation *op) {
2121
Dialect *dialect = op->getDialect();
@@ -26,6 +26,7 @@ void mlir::populateConversionTargetFromOperation(Operation *root,
2626
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
2727
if (!iface)
2828
return;
29-
iface->populateConvertToLLVMConversionPatterns(target, patterns);
29+
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
30+
patterns);
3031
});
3132
}

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1010

1111
#include "mlir/Analysis/DataLayoutAnalysis.h"
12+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1213
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1314
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1415
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -1935,4 +1936,27 @@ struct FinalizeMemRefToLLVMConversionPass
19351936
signalPassFailure();
19361937
}
19371938
};
1939+
1940+
/// Implement the interface to convert MemRef to LLVM.
1941+
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1942+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1943+
void loadDependentDialects(MLIRContext *context) const final {
1944+
context->loadDialect<LLVM::LLVMDialect>();
1945+
}
1946+
1947+
/// Hook for derived dialect interface to provide conversion patterns
1948+
/// and mark dialect legal for the conversion target.
1949+
void populateConvertToLLVMConversionPatterns(
1950+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
1951+
RewritePatternSet &patterns) const final {
1952+
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1953+
}
1954+
};
1955+
19381956
} // namespace
1957+
1958+
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1959+
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1960+
dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1961+
});
1962+
}

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
201201
/// Hook for derived dialect interface to provide conversion patterns
202202
/// and mark dialect legal for the conversion target.
203203
void populateConvertToLLVMConversionPatterns(
204-
ConversionTarget &target, RewritePatternSet &patterns) const final {
204+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
205+
RewritePatternSet &patterns) const final {
205206
populateNVVMToLLVMConversionPatterns(patterns);
206207
}
207208
};

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
// RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s -split-input-file | FileCheck %s
22
// RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
33

4+
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
5+
// and the generic `convert-to-llvm` pass. This produces slightly different IR
6+
// because the conversion target is set up differently. Only one test case is
7+
// checked.
8+
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s
9+
410
// CHECK-LABEL: func @view(
511
// CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
612
func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
@@ -88,6 +94,10 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
8894
// CHECK-LABEL: func @view_empty_memref(
8995
// CHECK: %[[ARG0:.*]]: index,
9096
// CHECK: %[[ARG1:.*]]: memref<0xi8>)
97+
98+
// CHECK-INTERFACE-LABEL: func @view_empty_memref(
99+
// CHECK-INTERFACE: %[[ARG0:.*]]: index,
100+
// CHECK-INTERFACE: %[[ARG1:.*]]: memref<0xi8>)
91101
func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
92102

93103
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -101,6 +111,18 @@ func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
101111
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
102112
// CHECK: llvm.mlir.constant(4 : index) : i64
103113
// CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
114+
115+
// CHECK-INTERFACE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
116+
// CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
117+
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
118+
// CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
119+
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
120+
// CHECK-INTERFACE: llvm.mlir.constant(1 : index) : i64
121+
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
122+
// CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
123+
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
124+
// CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
125+
// CHECK-INTERFACE: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
104126
%0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
105127

106128
return

0 commit comments

Comments
 (0)