Skip to content

Commit aec9e20

Browse files
committed
[mlir] introduce type constraints for operands of LLVM dialect operations
Historically, the operations in the MLIR's LLVM dialect only checked that the operand are of LLVM dialect type without more detailed constraints. This was due to LLVM dialect types wrapping LLVM IR types and having clunky verification methods. With the new first-class modeling, it is possible to define type constraints similarly to other dialects and use them to enforce some correctness rules in verifiers instead of having LLVM assert during translation to LLVM IR. This hardening discovered several issues where MLIR was producing LLVM dialect operations that cannot exist in LLVM IR. Depends On D85900 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D85901
1 parent bdc4c0b commit aec9e20

File tree

6 files changed

+215
-94
lines changed

6 files changed

+215
-94
lines changed

mlir/include/mlir/Dialect/GPU/GPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2121
// Type constraint accepting standard integers, indices and wrapped LLVM integer
2222
// types.
2323
def IntLikeOrLLVMInt : TypeConstraint<
24-
Or<[AnySignlessInteger.predicate, Index.predicate, LLVMInt.predicate]>,
24+
Or<[AnySignlessInteger.predicate, Index.predicate,
25+
LLVM_AnyInteger.predicate]>,
2526
"integer, index or LLVM dialect equivalent">;
2627

2728
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 104 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
include "mlir/IR/OpBase.td"
1818
include "mlir/Interfaces/SideEffectInterfaces.td"
1919

20+
//===----------------------------------------------------------------------===//
21+
// LLVM Dialect.
22+
//===----------------------------------------------------------------------===//
23+
2024
def LLVM_Dialect : Dialect {
2125
let name = "llvm";
2226
let cppNamespace = "LLVM";
@@ -38,34 +42,108 @@ def LLVM_Dialect : Dialect {
3842
}];
3943
}
4044

41-
// LLVM IR type wrapped in MLIR.
45+
//===----------------------------------------------------------------------===//
46+
// LLVM dialect type constraints.
47+
//===----------------------------------------------------------------------===//
48+
49+
// LLVM dialect type.
4250
def LLVM_Type : DialectType<LLVM_Dialect,
4351
CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
4452
"LLVM dialect type">;
4553

46-
// Type constraint accepting only wrapped LLVM integer types.
47-
def LLVMInt : TypeConstraint<
48-
And<[LLVM_Type.predicate,
49-
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
50-
"LLVM dialect integer">;
54+
// Type constraint accepting LLVM integer types.
55+
def LLVM_AnyInteger : Type<
56+
CPred<"$_self.isa<::mlir::LLVM::LLVMIntegerType>()">,
57+
"LLVM integer type">;
58+
59+
// Type constraints accepting LLVM integer type of a specific width.
60+
class LLVM_IntBase<int width> :
61+
Type<And<[
62+
LLVM_AnyInteger.predicate,
63+
CPred<"$_self.cast<::mlir::LLVM::LLVMIntegerType>().getBitWidth() == "
64+
# width>]>,
65+
"LLVM " # width # "-bit integer type">,
66+
BuildableType<
67+
"::mlir::LLVM::LLVMIntegerType::get($_builder.getContext(), "
68+
# width # ")">;
69+
70+
def LLVM_i1 : LLVM_IntBase<1>;
71+
def LLVM_i8 : LLVM_IntBase<8>;
72+
def LLVM_i32 : LLVM_IntBase<32>;
5173

52-
def LLVMIntBase : TypeConstraint<
74+
// Type constraint accepting LLVM primitive types, i.e. all types except void
75+
// and function.
76+
def LLVM_PrimitiveType : Type<
5377
And<[LLVM_Type.predicate,
54-
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
55-
"LLVM dialect integer">;
56-
57-
// Integer type of a specific width.
58-
class LLVMI<int width>
59-
: Type<And<[
60-
LLVM_Type.predicate,
61-
CPred<
62-
"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
63-
"LLVM dialect " # width # "-bit integer">,
64-
BuildableType<
65-
"::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
66-
# width # ")">;
67-
68-
def LLVMI1 : LLVMI<1>;
78+
CPred<"!$_self.isa<::mlir::LLVM::LLVMVoidType, "
79+
"::mlir::LLVM::LLVMFunctionType>()">]>,
80+
"primitive LLVM type">;
81+
82+
// Type constraint accepting any LLVM floating point type.
83+
def LLVM_AnyFloat : Type<
84+
CPred<"$_self.isa<::mlir::LLVM::LLVMBFloatType, "
85+
"::mlir::LLVM::LLVMHalfType, "
86+
"::mlir::LLVM::LLVMFloatType, "
87+
"::mlir::LLVM::LLVMDoubleType>()">,
88+
"floating point LLVM type">;
89+
90+
// Type constraint accepting any LLVM pointer type.
91+
def LLVM_AnyPointer : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMPointerType>()">,
92+
"LLVM pointer type">;
93+
94+
// Type constraint accepting LLVM pointer type with an additional constraint
95+
// on the element type.
96+
class LLVM_PointerTo<Type pointee> : Type<
97+
And<[LLVM_AnyPointer.predicate,
98+
SubstLeaves<
99+
"$_self",
100+
"$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()",
101+
pointee.predicate>]>,
102+
"LLVM pointer to " # pointee.description>;
103+
104+
// Type constraint accepting any LLVM structure type.
105+
def LLVM_AnyStruct : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMStructType>()">,
106+
"LLVM structure type">;
107+
108+
// Type constraint accepting opaque LLVM structure type.
109+
def LLVM_OpaqueStruct : Type<
110+
And<[LLVM_AnyStruct.predicate,
111+
CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>;
112+
113+
// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a
114+
// type that has size (not void, function or opaque struct type).
115+
def LLVM_LoadableType : Type<
116+
And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
117+
"LLVM type with size">;
118+
119+
// Type constraint accepting any LLVM aggregate type, i.e. structure or array.
120+
def LLVM_AnyAggregate : Type<
121+
CPred<"$_self.isa<::mlir::LLVM::LLVMStructType, "
122+
"::mlir::LLVM::LLVMArrayType>()">,
123+
"LLVM aggregate type">;
124+
125+
// Type constraint accepting any LLVM non-aggregate type, i.e. not structure or
126+
// array.
127+
def LLVM_AnyNonAggregate : Type<Neg<LLVM_AnyAggregate.predicate>,
128+
"LLVM non-aggregate type">;
129+
130+
// Type constraint accepting any LLVM vector type.
131+
def LLVM_AnyVector : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMVectorType>()">,
132+
"LLVM vector type">;
133+
134+
// Type constraint accepting an LLVM vector type with an additional constraint
135+
// on the vector element type.
136+
class LLVM_VectorOf<Type element> : Type<
137+
And<[LLVM_AnyVector.predicate,
138+
SubstLeaves<
139+
"$_self",
140+
"$_self.cast<::mlir::LLVM::LLVMVectorType>().getElementType()",
141+
element.predicate>]>,
142+
"LLVM vector of " # element.description>;
143+
144+
// Type constraint accepting a constrained type, or a vector of such types.
145+
class LLVM_ScalarOrVectorOf<Type element> :
146+
AnyTypeOf<[element, LLVM_VectorOf<element>]>;
69147

70148
// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
71149
// used to translate to LLVM IR proper.
@@ -85,6 +163,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :
85163
string llvmBuilder = "";
86164
}
87165

166+
//===----------------------------------------------------------------------===//
167+
// Base classes for LLVM dialect operations.
168+
//===----------------------------------------------------------------------===//
169+
88170
// Base class for LLVM operations. All operations get an "llvm." prefix in
89171
// their name automatically. LLVM operations have either zero or one result,
90172
// this class is specialized below for both cases and should not be used

0 commit comments

Comments
 (0)