Skip to content

Commit e64c766

Browse files
committed
[mlir] recursively convert builtin types to LLVM when possible
Given that LLVM dialect types may now optionally contain types from other dialects, which itself is motivated by dialect interoperability and progressive lowering, the conversion should no longer assume that the outermost LLVM dialect type can be left as is. Instead, it should inspect the types it contains and attempt to convert them to the LLVM dialect. Introduce this capability for LLVM array, pointer and structure types. Only literal structures are currently supported as handling identified structures requires the converison infrastructure to have a mechanism for avoiding infite recursion in case of recursive types. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D112550
1 parent d96656c commit e64c766

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,53 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
3838
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
3939
addConversion([&](VectorType type) { return convertVectorType(type); });
4040

41-
// LLVM-compatible types are legal, so add a pass-through conversion.
41+
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
42+
// before the conversions below since conversions are attempted in reverse
43+
// order and those should take priority.
4244
addConversion([](Type type) {
4345
return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
4446
: llvm::None;
4547
});
4648

49+
// LLVM container types may (recursively) contain other types that must be
50+
// converted even when the outer type is compatible.
51+
addConversion([&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
52+
if (auto pointee = convertType(type.getElementType()))
53+
return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
54+
return llvm::None;
55+
});
56+
addConversion([&](LLVM::LLVMStructType type) -> llvm::Optional<Type> {
57+
// TODO: handle conversion of identified structs, which may be recursive.
58+
if (type.isIdentified())
59+
return type;
60+
61+
SmallVector<Type> convertedSubtypes;
62+
convertedSubtypes.reserve(type.getBody().size());
63+
if (failed(convertTypes(type.getBody(), convertedSubtypes)))
64+
return llvm::None;
65+
66+
return LLVM::LLVMStructType::getLiteral(type.getContext(),
67+
convertedSubtypes, type.isPacked());
68+
});
69+
addConversion([&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
70+
if (auto element = convertType(type.getElementType()))
71+
return LLVM::LLVMArrayType::get(element, type.getNumElements());
72+
return llvm::None;
73+
});
74+
addConversion([&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
75+
Type convertedResType = convertType(type.getReturnType());
76+
if (!convertedResType)
77+
return llvm::None;
78+
79+
SmallVector<Type> convertedArgTypes;
80+
convertedArgTypes.reserve(type.getNumParams());
81+
if (failed(convertTypes(type.getParams(), convertedArgTypes)))
82+
return llvm::None;
83+
84+
return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
85+
type.isVarArg());
86+
});
87+
4788
// Materialization for memrefs creates descriptor structs from individual
4889
// values constituting them, when descriptors are used, i.e. more than one
4990
// value represents a memref.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: mlir-opt -test-convert-call-op %s | FileCheck %s
2+
3+
// CHECK-LABEL: @ptr
4+
// CHECK: !llvm.ptr<i42>
5+
func private @ptr() -> !llvm.ptr<!test.smpla>
6+
7+
// CHECK-LABEL: @ptr_ptr()
8+
// CHECK: !llvm.ptr<ptr<i42>>
9+
func private @ptr_ptr() -> !llvm.ptr<!llvm.ptr<!test.smpla>>
10+
11+
// CHECK-LABEL: @struct_ptr()
12+
// CHECK: !llvm.struct<(ptr<i42>)>
13+
func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
14+
15+
// CHECK-LABEL: @named_struct_ptr()
16+
// CHECK: !llvm.struct<"named", (ptr<!test.smpla>)>
17+
func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
18+
19+
// CHECK-LABEL: @array_ptr()
20+
// CHECK: !llvm.array<10 x ptr<i42>>
21+
func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>
22+
23+
// CHECK-LABEL: @func()
24+
// CHECK: !llvm.ptr<func<i42 (i42)>>
25+
func private @func() -> !llvm.ptr<!llvm.func<!test.smpla (!test.smpla)>>
26+
27+
// TODO: support conversion of recursive types in the conversion infra.
28+
// CHECK-LABEL: @named_recursive()
29+
// CHECK: !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)>
30+
func private @named_recursive() -> !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)>
31+

mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class TestConvertCallOp
5252
typeConverter.addConversion([&](test::TestType type) {
5353
return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
5454
});
55+
typeConverter.addConversion([&](test::SimpleAType type) {
56+
return IntegerType::get(type.getContext(), 42);
57+
});
5558

5659
// Populate patterns.
5760
RewritePatternSet patterns(m.getContext());

0 commit comments

Comments
 (0)