@@ -34,8 +34,9 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
34
34
!::llvm::cast<VectorType>($_self).isScalable()}]>;
35
35
36
36
// Whether a type is a scalable VectorType.
37
- def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
38
- ::llvm::cast<VectorType>($_self).isScalable()}]>;
37
+ def IsVectorTypeWithAnyDimScalablePred
38
+ : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
39
+ ::llvm::cast<VectorType>($_self).isScalable()}]>;
39
40
40
41
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
41
42
// Examples:
@@ -51,7 +52,7 @@ def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
51
52
]>;
52
53
53
54
// Whether a type is a VectorType and all dimensions are scalable.
54
- def allDimsScalableVectorTypePred : And<[
55
+ def IsVectorTypeWithAllDimsScalablePred : And<[
55
56
IsVectorTypePred,
56
57
CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
57
58
]>;
@@ -414,7 +415,7 @@ class FixedVectorOf<list<Type> allowedTypes> :
414
415
"fixed-length vector", "::mlir::VectorType">;
415
416
416
417
class ScalableVectorOf<list<Type> allowedTypes> :
417
- ShapedContainerType<allowedTypes, IsScalableVectorTypePred ,
418
+ ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred ,
418
419
"scalable vector", "::mlir::VectorType">;
419
420
420
421
// Any vector with a single trailing scalable dimension, with an element type in
@@ -447,7 +448,7 @@ class IsFixedVectorOfRankPred<list<int> allowedRanks> :
447
448
// Whether the number of elements of a scalable vector is from the given
448
449
// `allowedRanks` list
449
450
class IsScalableVectorOfRankPred<list<int> allowedRanks> :
450
- And<[IsScalableVectorTypePred ,
451
+ And<[IsVectorTypeWithAnyDimScalablePred ,
451
452
Or<!foreach(allowedlength, allowedRanks,
452
453
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
453
454
== }]
@@ -497,7 +498,7 @@ class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
497
498
// Whether the number of elements of a scalable vector is from the given
498
499
// `allowedLengths` list
499
500
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
500
- And<[IsScalableVectorTypePred ,
501
+ And<[IsVectorTypeWithAnyDimScalablePred ,
501
502
Or<!foreach(allowedlength, allowedLengths,
502
503
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
503
504
== }]
0 commit comments