Skip to content

Commit 07fc852

Browse files
Vladislav Vinogradovjoker-eph
authored andcommitted
[mlir][ODS] Small fixes for ODS classes
* Introduce separate `RankedTensorOf` class. Use it as base class for `AnyRankedTensor`. * Add C++ class specification (`::mlir::MemRefType`) to `MemRefRankOf` and `StaticShapeMemRefOf`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D95936
1 parent d84e5fd commit 07fc852

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

mlir/include/mlir/IR/OpBase.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,11 +604,13 @@ class TensorOf<list<Type> allowedTypes> :
604604
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
605605
"::mlir::TensorType">;
606606

607+
class RankedTensorOf<list<Type> allowedTypes> :
608+
ShapedContainerType<allowedTypes, And<[IsTensorTypePred, HasRankPred]>,
609+
"ranked tensor", "::mlir::TensorType">;
610+
607611
def AnyTensor : TensorOf<[AnyType]>;
608612

609-
def AnyRankedTensor :
610-
ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>,
611-
"ranked tensor", "::mlir::TensorType">;
613+
def AnyRankedTensor : RankedTensorOf<[AnyType]>;
612614

613615
// TODO: Have an easy way to add another constraint to a type.
614616
class StaticShapeTensorOf<list<Type> allowedTypes>
@@ -675,11 +677,13 @@ def F64MemRef : MemRefOf<[F64]>;
675677
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
676678
Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
677679
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
678-
MemRefOf<allowedTypes>.summary>;
680+
MemRefOf<allowedTypes>.summary,
681+
"::mlir::MemRefType">;
679682

680683
class StaticShapeMemRefOf<list<Type> allowedTypes>
681684
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
682-
"statically shaped " # MemRefOf<allowedTypes>.summary>;
685+
"statically shaped " # MemRefOf<allowedTypes>.summary,
686+
"::mlir::MemRefType">;
683687

684688
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
685689

0 commit comments

Comments
 (0)