Skip to content

Commit b2ab375

Browse files
committed
[mlir] use the new stateful LLVM type translator by default
Previous type model in the LLVM dialect did not support identified structure types properly and therefore could use stateless translations implemented as free functions. The new model supports identified structs and must keep track of the identified structure types present in the target context (LLVMContext or MLIRContext) to avoid creating duplicate structs due to LLVM's type auto-renaming. Expose the stateful type translation classes and use them during translation, storing the state as part of ModuleTranslation. Drop the test type translation mechanism that is no longer necessary and update the tests to exercise type translation as part of the main translation flow. Update the code in vector-to-LLVM dialect conversion that relied on stateless translation to use the new class in a stateless manner. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D85297
1 parent e1de85f commit b2ab375

File tree

13 files changed

+277
-414
lines changed

13 files changed

+277
-414
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@ struct LLVMTypeStorage;
4848
struct LLVMDialectImpl;
4949
} // namespace detail
5050

51-
/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function
52-
/// exists exclusively for the purpose of gradual transition to the first-party
53-
/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM
54-
/// translation.
55-
llvm::Type *convertLLVMType(LLVMType type);
56-
5751
///// Ops /////
5852
#define GET_OP_CLASSES
5953
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/Block.h"
2020
#include "mlir/IR/Module.h"
2121
#include "mlir/IR/Value.h"
22+
#include "mlir/Target/LLVMIR/TypeTranslation.h"
2223

2324
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
2425
#include "llvm/IR/BasicBlock.h"
@@ -127,6 +128,9 @@ class ModuleTranslation {
127128
/// Mappings between llvm.mlir.global definitions and corresponding globals.
128129
DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
129130

131+
/// A stateful object used to translate types.
132+
TypeToLLVMIRTranslator typeTranslator;
133+
130134
protected:
131135
/// Mappings between original and translated values, used for lookups.
132136
llvm::StringMap<llvm::Function *> functionMapping;

mlir/include/mlir/Target/LLVMIR/TypeTranslation.h

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
1515
#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
1616

17+
#include <memory>
18+
1719
namespace llvm {
20+
class DataLayout;
1821
class LLVMContext;
1922
class Type;
2023
} // namespace llvm
@@ -27,8 +30,49 @@ namespace LLVM {
2730

2831
class LLVMType;
2932

30-
llvm::Type *translateTypeToLLVMIR(LLVMType type, llvm::LLVMContext &context);
31-
LLVMType translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
33+
namespace detail {
34+
class TypeToLLVMIRTranslatorImpl;
35+
class TypeFromLLVMIRTranslatorImpl;
36+
} // namespace detail
37+
38+
/// Utility class to translate MLIR LLVM dialect types to LLVM IR. Stores the
39+
/// translation state, in particular any identified structure types that can be
40+
/// reused in further translation.
41+
class TypeToLLVMIRTranslator {
42+
public:
43+
TypeToLLVMIRTranslator(llvm::LLVMContext &context);
44+
~TypeToLLVMIRTranslator();
45+
46+
/// Returns the perferred alignment for the type given the data layout. Note
47+
/// that this will perform type conversion and store its results for future
48+
/// uses.
49+
// TODO: this should be removed when MLIR has proper data layout.
50+
unsigned getPreferredAlignment(LLVM::LLVMType type,
51+
const llvm::DataLayout &layout);
52+
53+
/// Translates the given MLIR LLVM dialect type to LLVM IR.
54+
llvm::Type *translateType(LLVM::LLVMType type);
55+
56+
private:
57+
/// Private implementation.
58+
std::unique_ptr<detail::TypeToLLVMIRTranslatorImpl> impl;
59+
};
60+
61+
/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
62+
/// the translation state, in particular any identified structure types that are
63+
/// reused across translations.
64+
class TypeFromLLVMIRTranslator {
65+
public:
66+
TypeFromLLVMIRTranslator(MLIRContext &context);
67+
~TypeFromLLVMIRTranslator();
68+
69+
/// Translates the given LLVM IR type to the MLIR LLVM dialect.
70+
LLVM::LLVMType translateType(llvm::Type *type);
71+
72+
private:
73+
/// Private implementation.
74+
std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
75+
};
3276

3377
} // namespace LLVM
3478
} // namespace mlir

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,12 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
126126
if (!elementTy)
127127
return failure();
128128

129-
auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
130-
// TODO: this should be abstracted away to avoid depending on translation.
131-
align = dataLayout.getPrefTypeAlignment(LLVM::translateTypeToLLVMIR(
132-
elementTy.cast<LLVM::LLVMType>(),
133-
typeConverter.getDialect()->getLLVMContext()));
129+
// TODO: this should use the MLIR data layout when it becomes available and
130+
// stop depending on translation.
131+
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
132+
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
133+
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
134+
dialect->getLLVMModule().getDataLayout());
134135
return success();
135136
}
136137

mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp

Lines changed: 14 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/Module.h"
1717
#include "mlir/IR/StandardTypes.h"
1818
#include "mlir/Target/LLVMIR.h"
19+
#include "mlir/Target/LLVMIR/TypeTranslation.h"
1920
#include "mlir/Translation.h"
2021

2122
#include "llvm/IR/Attributes.h"
@@ -48,7 +49,8 @@ class Importer {
4849
public:
4950
Importer(MLIRContext *context, ModuleOp module)
5051
: b(context), context(context), module(module),
51-
unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
52+
unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)),
53+
typeTranslator(*context) {
5254
b.setInsertionPointToStart(module.getBody());
5355
dialect = context->getRegisteredDialect<LLVMDialect>();
5456
}
@@ -129,6 +131,8 @@ class Importer {
129131
Location unknownLoc;
130132
/// Cached dialect.
131133
LLVMDialect *dialect;
134+
/// The stateful type translator (contains named structs).
135+
LLVM::TypeFromLLVMIRTranslator typeTranslator;
132136
};
133137
} // namespace
134138

@@ -149,79 +153,16 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
149153
}
150154

151155
LLVMType Importer::processType(llvm::Type *type) {
152-
switch (type->getTypeID()) {
153-
case llvm::Type::FloatTyID:
154-
return LLVMType::getFloatTy(dialect);
155-
case llvm::Type::DoubleTyID:
156-
return LLVMType::getDoubleTy(dialect);
157-
case llvm::Type::IntegerTyID:
158-
return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
159-
case llvm::Type::PointerTyID: {
160-
LLVMType elementType = processType(type->getPointerElementType());
161-
if (!elementType)
162-
return nullptr;
163-
return elementType.getPointerTo(type->getPointerAddressSpace());
164-
}
165-
case llvm::Type::ArrayTyID: {
166-
LLVMType elementType = processType(type->getArrayElementType());
167-
if (!elementType)
168-
return nullptr;
169-
return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
170-
}
171-
case llvm::Type::ScalableVectorTyID: {
172-
emitError(unknownLoc) << "scalable vector types not supported";
173-
return nullptr;
174-
}
175-
case llvm::Type::FixedVectorTyID: {
176-
auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type);
177-
LLVMType elementType = processType(typeVTy->getElementType());
178-
if (!elementType)
179-
return nullptr;
180-
return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
181-
}
182-
case llvm::Type::VoidTyID:
183-
return LLVMType::getVoidTy(dialect);
184-
case llvm::Type::FP128TyID:
185-
return LLVMType::getFP128Ty(dialect);
186-
case llvm::Type::X86_FP80TyID:
187-
return LLVMType::getX86_FP80Ty(dialect);
188-
case llvm::Type::StructTyID: {
189-
SmallVector<LLVMType, 4> elementTypes;
190-
elementTypes.reserve(type->getStructNumElements());
191-
for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
192-
LLVMType ty = processType(type->getStructElementType(i));
193-
if (!ty)
194-
return nullptr;
195-
elementTypes.push_back(ty);
196-
}
197-
return LLVMType::getStructTy(dialect, elementTypes,
198-
cast<llvm::StructType>(type)->isPacked());
199-
}
200-
case llvm::Type::FunctionTyID: {
201-
llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
202-
SmallVector<LLVMType, 4> paramTypes;
203-
for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
204-
LLVMType ty = processType(fty->getParamType(i));
205-
if (!ty)
206-
return nullptr;
207-
paramTypes.push_back(ty);
208-
}
209-
LLVMType result = processType(fty->getReturnType());
210-
if (!result)
211-
return nullptr;
156+
if (LLVMType result = typeTranslator.translateType(type))
157+
return result;
212158

213-
return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
214-
}
215-
default: {
216-
// FIXME: Diagnostic should be able to natively handle types that have
217-
// operator<<(raw_ostream&) defined.
218-
std::string s;
219-
llvm::raw_string_ostream os(s);
220-
os << *type;
221-
emitError(unknownLoc) << "unhandled type: " << os.str();
222-
return nullptr;
223-
}
224-
}
159+
// FIXME: Diagnostic should be able to natively handle types that have
160+
// operator<<(raw_ostream&) defined.
161+
std::string s;
162+
llvm::raw_string_ostream os(s);
163+
os << *type;
164+
emitError(unknownLoc) << "unhandled type: " << os.str();
165+
return nullptr;
225166
}
226167

227168
// We only need integers, floats, doubles, and vectors and tensors thereof for

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
304304
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
305305
ompDialect(
306306
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
307-
llvmDialect(module->getContext()->getRegisteredDialect<LLVMDialect>()) {
307+
llvmDialect(module->getContext()->getRegisteredDialect<LLVMDialect>()),
308+
typeTranslator(this->llvmModule->getContext()) {
308309
assert(satisfiesLLVMModule(mlirModule) &&
309310
"mlirModule should honor LLVM's module semantics.");
310311
}
@@ -935,7 +936,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
935936
llvm::Type *ModuleTranslation::convertType(LLVMType type) {
936937
// Lock the LLVM context as we create types in it.
937938
llvm::sys::SmartScopedLock<true> lock(llvmDialect->getLLVMContextMutex());
938-
return LLVM::translateTypeToLLVMIR(type, llvmDialect->getLLVMContext());
939+
return typeTranslator.translateType(type);
939940
}
940941

941942
/// A helper to look up remapped operands in the value remapping table.`

mlir/lib/Target/LLVMIR/TypeTranslation.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@
1111
#include "mlir/IR/MLIRContext.h"
1212

1313
#include "llvm/ADT/TypeSwitch.h"
14+
#include "llvm/IR/DataLayout.h"
1415
#include "llvm/IR/DerivedTypes.h"
1516
#include "llvm/IR/Type.h"
1617

1718
using namespace mlir;
1819

19-
namespace {
20+
namespace mlir {
21+
namespace LLVM {
22+
namespace detail {
2023
/// Support for translating MLIR LLVM dialect types to LLVM IR.
21-
class TypeToLLVMIRTranslator {
24+
class TypeToLLVMIRTranslatorImpl {
2225
public:
2326
/// Constructs a class creating types in the given LLVM context.
24-
TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {}
27+
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
2528

2629
/// Translates a single type.
2730
llvm::Type *translateType(LLVM::LLVMType type) {
@@ -160,22 +163,32 @@ class TypeToLLVMIRTranslator {
160163
/// type instead of creating a new type.
161164
llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
162165
};
163-
} // end namespace
164-
165-
/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain
166-
/// the mapping for identified structs so new structs will be created with
167-
/// auto-renaming on each call. This is intended exclusively for testing.
168-
llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMType type,
169-
llvm::LLVMContext &context) {
170-
return TypeToLLVMIRTranslator(context).translateType(type);
166+
} // end namespace detail
167+
} // end namespace LLVM
168+
} // end namespace mlir
169+
170+
LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
171+
: impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {}
172+
173+
LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}
174+
175+
llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
176+
return impl->translateType(type);
171177
}
172178

173-
namespace {
179+
unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
180+
LLVM::LLVMType type, const llvm::DataLayout &layout) {
181+
return layout.getPrefTypeAlignment(translateType(type));
182+
}
183+
184+
namespace mlir {
185+
namespace LLVM {
186+
namespace detail {
174187
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
175-
class TypeFromLLVMIRTranslator {
188+
class TypeFromLLVMIRTranslatorImpl {
176189
public:
177190
/// Constructs a class creating types in the given MLIR context.
178-
TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
191+
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
179192

180193
/// Translates the given type.
181194
LLVM::LLVMType translateType(llvm::Type *type) {
@@ -299,11 +312,15 @@ class TypeFromLLVMIRTranslator {
299312
/// The context in which MLIR types are created.
300313
MLIRContext &context;
301314
};
302-
} // end namespace
315+
} // end namespace detail
316+
} // end namespace LLVM
317+
} // end namespace mlir
318+
319+
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
320+
: impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
321+
322+
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
303323

304-
/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended
305-
/// exclusively for testing.
306-
LLVM::LLVMType mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
307-
MLIRContext &context) {
308-
return TypeFromLLVMIRTranslator(context).translateType(type);
324+
LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
325+
return impl->translateType(type);
309326
}

mlir/test/Target/import.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
%struct.t = type {}
44
%struct.s = type { %struct.t, i64 }
55

6-
; CHECK: llvm.mlir.global external @g1() : !llvm.struct<(struct<()>, i64)>
6+
; CHECK: llvm.mlir.global external @g1() : !llvm.struct<"struct.s", (struct<"struct.t", ()>, i64)>
77
@g1 = external global %struct.s, align 8
88
; CHECK: llvm.mlir.global external @g2() : !llvm.double
99
@g2 = external global double, align 8

0 commit comments

Comments
 (0)