Skip to content

Commit d64b3e4

Browse files
committed
[mlir] Avoid needlessly converting LLVM named structs with compatible elements
Conversion of LLVM named structs leads to them being renamed since we cannot modify the body of the struct type once it is set. Previously, this applied to all named struct types, even if their element types were not affected by the conversion. Make this behvaior only applicable when element types are changed. This requires making the LLVM dialect type-compatibility check recursively look at the element types (arguably, it should have been doing than since the moment the LLVM dialect type system stopped being closed). In addition, have a more lax check for outer types only to avoid repeated check when necessary (e.g., parser, verifiers that are going to also look at the inner type). Reviewed By: wsmoses Differential Revision: https://reviews.llvm.org/D115037
1 parent 34a43f2 commit d64b3e4

File tree

5 files changed

+91
-17
lines changed

5 files changed

+91
-17
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,10 @@ void printType(Type type, AsmPrinter &printer);
429429
/// Returns `true` if the given type is compatible with the LLVM dialect.
430430
bool isCompatibleType(Type type);
431431

432+
/// Returns `true` if the given outer type is compatible with the LLVM dialect
433+
/// without checking its potential nested types such as struct elements.
434+
bool isCompatibleOuterType(Type type);
435+
432436
/// Returns `true` if the given type is a floating-point type compatible with
433437
/// the LLVM dialect.
434438
bool isCompatibleFloatingPointType(Type type);

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
5555
});
5656
addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
5757
ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
58+
// Fastpath for types that won't be converted by this callback anyway.
59+
if (LLVM::isCompatibleType(type)) {
60+
results.push_back(type);
61+
return success();
62+
}
63+
5864
if (type.isIdentified()) {
5965
auto convertedType = LLVM::LLVMStructType::getIdentified(
6066
type.getContext(), ("_Converted_" + type.getName()).str());

mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
468468
Type type = dispatchParse(parser, /*allowAny=*/false);
469469
if (!type)
470470
return type;
471-
if (!isCompatibleType(type)) {
471+
if (!isCompatibleOuterType(type)) {
472472
parser.emitError(loc) << "unexpected type, expected keyword";
473473
return nullptr;
474474
}

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
21
//
32
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43
// See https://llvm.org/LICENSE.txt for license information.
@@ -19,6 +18,7 @@
1918
#include "mlir/IR/DialectImplementation.h"
2019
#include "mlir/IR/TypeSupport.h"
2120

21+
#include "llvm/ADT/ScopeExit.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/TypeSize.h"
2424

@@ -120,9 +120,10 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
120120
//===----------------------------------------------------------------------===//
121121

122122
bool LLVMPointerType::isValidElementType(Type type) {
123-
return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
124-
LLVMMetadataType, LLVMLabelType>()
125-
: type.isa<PointerElementTypeInterface>();
123+
return isCompatibleOuterType(type)
124+
? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
125+
LLVMLabelType>()
126+
: type.isa<PointerElementTypeInterface>();
126127
}
127128

128129
LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -483,17 +484,9 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
483484
// Utility functions.
484485
//===----------------------------------------------------------------------===//
485486

486-
bool mlir::LLVM::isCompatibleType(Type type) {
487-
// Only signless integers are compatible.
488-
if (auto intType = type.dyn_cast<IntegerType>())
489-
return intType.isSignless();
490-
491-
// 1D vector types are compatible if their element types are.
492-
if (auto vecType = type.dyn_cast<VectorType>())
493-
return vecType.getRank() == 1 && isCompatibleType(vecType.getElementType());
494-
487+
bool mlir::LLVM::isCompatibleOuterType(Type type) {
495488
// clang-format off
496-
return type.isa<
489+
if (type.isa<
497490
BFloat16Type,
498491
Float16Type,
499492
Float32Type,
@@ -512,8 +505,75 @@ bool mlir::LLVM::isCompatibleType(Type type) {
512505
LLVMScalableVectorType,
513506
LLVMVoidType,
514507
LLVMX86MMXType
515-
>();
516-
// clang-format on
508+
>()) {
509+
// clang-format on
510+
return true;
511+
}
512+
513+
// Only signless integers are compatible.
514+
if (auto intType = type.dyn_cast<IntegerType>())
515+
return intType.isSignless();
516+
517+
// 1D vector types are compatible.
518+
if (auto vecType = type.dyn_cast<VectorType>())
519+
return vecType.getRank() == 1;
520+
521+
return false;
522+
}
523+
524+
static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
525+
if (callstack.contains(type))
526+
return true;
527+
528+
callstack.insert(type);
529+
auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
530+
531+
auto isCompatible = [&](Type type) {
532+
return isCompatibleImpl(type, callstack);
533+
};
534+
535+
return llvm::TypeSwitch<Type, bool>(type)
536+
.Case<LLVMStructType>([&](auto structType) {
537+
return llvm::all_of(structType.getBody(), isCompatible);
538+
})
539+
.Case<LLVMFunctionType>([&](auto funcType) {
540+
return isCompatible(funcType.getReturnType()) &&
541+
llvm::all_of(funcType.getParams(), isCompatible);
542+
})
543+
.Case<IntegerType>([](auto intType) { return intType.isSignless(); })
544+
.Case<VectorType>([&](auto vecType) {
545+
return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
546+
})
547+
// clang-format off
548+
.Case<
549+
LLVMPointerType,
550+
LLVMFixedVectorType,
551+
LLVMScalableVectorType,
552+
LLVMArrayType
553+
>([&](auto containerType) {
554+
return isCompatible(containerType.getElementType());
555+
})
556+
.Case<
557+
BFloat16Type,
558+
Float16Type,
559+
Float32Type,
560+
Float64Type,
561+
Float80Type,
562+
Float128Type,
563+
LLVMLabelType,
564+
LLVMMetadataType,
565+
LLVMPPCFP128Type,
566+
LLVMTokenType,
567+
LLVMVoidType,
568+
LLVMX86MMXType
569+
>([](Type) { return true; })
570+
// clang-format on
571+
.Default([](Type) { return false; });
572+
}
573+
574+
bool mlir::LLVM::isCompatibleType(Type type) {
575+
SetVector<Type> callstack;
576+
return isCompatibleImpl(type, callstack);
517577
}
518578

519579
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {

mlir/test/Conversion/StandardToLLVM/convert-types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
1616
// CHECK: !llvm.struct<"_Converted_named", (ptr<i42>)>
1717
func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
1818

19+
// CHECK-LABEL: @named_no_convert
20+
// CHECK: !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
21+
func private @named_no_convert() -> !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
22+
1923
// CHECK-LABEL: @array_ptr()
2024
// CHECK: !llvm.array<10 x ptr<i42>>
2125
func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>

0 commit comments

Comments
 (0)