Skip to content

[mlir] Add a ValueSemantics trait. #99493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ struct IntegerTypeStorage;
struct TupleTypeStorage;
} // namespace detail

/// Type trait indicating that the type has value semantics.
template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 13 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
let typeName = "builtin." # typeMnemonic;
}

//===----------------------------------------------------------------------===//
// Traits
//===----------------------------------------------------------------------===//

/// Type trait indicating that the type has value semantics.
def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
let cppNamespace = "::mlir";
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -745,7 +754,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
//===----------------------------------------------------------------------===//

def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
ShapedTypeInterface
ShapedTypeInterface, ValueSemantics
], "TensorType"> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Expand Down Expand Up @@ -1001,7 +1010,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
//===----------------------------------------------------------------------===//

def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
ShapedTypeInterface
ShapedTypeInterface, ValueSemantics
], "TensorType"> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Expand Down Expand Up @@ -1049,7 +1058,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//

def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
let description = [{
Syntax:
Expand Down
26 changes: 21 additions & 5 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def HasStaticShapePred :
// Whether a type is a TupleType.
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;

// Whether a type has a ValueSemantics trait.
def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;

//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -403,6 +406,11 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
]>;

// Container with value semantics.
class ValueSemanticsContainerOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
"container with value semantics">;

// Vector types.

class VectorOf<list<Type> allowedTypes> :
Expand Down Expand Up @@ -842,10 +850,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T, a vector of Ts, or a tensor of Ts
// they're either a T, a vector of Ts, or a tensor of Ts.
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
allowedType.predicate, VectorOf<[allowedType]>.predicate,
TensorOf<[allowedType]>.predicate]>,
allowedType.predicate,
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
name>;

// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T or a mapable container of Ts.
class TypeOrValueSemanticsContainer<Type allowedType, string name>
: TypeConstraint<Or<[
allowedType.predicate,
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
name>;

// Temporary constraint to allow gradual transition to supporting 0-D vectors.
Expand All @@ -864,8 +880,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;

// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
"signless-integer-like">;
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
AnySignlessIntegerOrIndex, "signless-integer-like">;

def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
AnySignlessIntegerOrIndex,
Expand Down
Loading