Skip to content

Commit 55b8bfb

Browse files
committed
[mlir][vector] Rename vector type TD definitions (nfc)
Currently, the Vector dialect TD file includes the following "vector" type definitions: ```mlir def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; def AnyFixedVector : FixedVectorOf<[AnyType]>; def AnyScalableVector : ScalableVectorOf<[AnyType]>; ``` In short: * `AnyVector` _excludes_ 0-D vectors. * `AnyVectorOfAnyRank`, `AnyFixedVector`, and `AnyScalableVector` _include_ 0-D vectors. The naming for "groups" that include 0-D vectors is inconsistent and can be misleading. This patch renames the definitions as follows: ```mlir def AnyVector : VectorOf<[AnyType]>; def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>; def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>; ``` Rationale: * The updated names are more explicit about 0-D vector support. * It becomes clearer that scalable vectors currently allow 0-D vectors - this might warrant a revisit. * The renaming paves the way for adding a new group for "fixed-width vectors excluding 0-D vectors" (e.g., AnyFixedVector), which I plan to introduce in a follow-up patch.
1 parent 4086ead commit 55b8bfb

File tree

4 files changed

+58
-55
lines changed

4 files changed

+58
-55
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
3535
//===----------------------------------------------------------------------===//
3636

3737
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
38-
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
38+
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
3939
"a vector with length " # length,
4040
"::mlir::VectorType">;
4141

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
100100
op_description # [{ on active lanes. Inactive lanes will keep the value of
101101
the first operand.}];
102102
let arguments = (ins
103-
ScalableVectorOf<[I1]>:$mask,
104-
ScalableVectorOf<[AnyFloat]>:$src1,
105-
ScalableVectorOf<[AnyFloat]>:$src2
103+
ScalableVectorOfAnyRank<[I1]>:$mask,
104+
ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
105+
ScalableVectorOfAnyRank<[AnyFloat]>:$src2
106106
);
107-
let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
107+
let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
108108
let assemblyFormat =
109109
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
110110
}
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
123123
op_description # [{ on active lanes. Inactive lanes will keep the value of
124124
the first operand.}];
125125
let arguments = (ins
126-
ScalableVectorOf<[I1]>:$mask,
127-
ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
128-
ScalableVectorOf<[I8, I16, I32, I64]>:$src2
126+
ScalableVectorOfAnyRank<[I1]>:$mask,
127+
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
128+
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
129129
);
130-
let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
130+
let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
131131
let assemblyFormat =
132132
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
133133
}
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
511511

512512
def UmmlaIntrOp :
513513
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
514-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
514+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
515515

516516
def SmmlaIntrOp :
517517
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
518-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
518+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
519519

520520
def SdotIntrOp :
521521
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
522-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
522+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
523523

524524
def UdotIntrOp :
525525
ArmSVE_IntrBinaryOverloadedOp<"udot">,
526-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
526+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
527527

528528
def ScalableMaskedAddIIntrOp :
529529
ArmSVE_IntrBinaryOverloadedOp<"add">,
530-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
530+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
531531

532532
def ScalableMaskedAddFIntrOp :
533533
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
534-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
534+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
535535

536536
def ScalableMaskedMulIIntrOp :
537537
ArmSVE_IntrBinaryOverloadedOp<"mul">,
538-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
538+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
539539

540540
def ScalableMaskedMulFIntrOp :
541541
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
542-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
542+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
543543

544544
def ScalableMaskedSubIIntrOp :
545545
ArmSVE_IntrBinaryOverloadedOp<"sub">,
546-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
546+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
547547

548548
def ScalableMaskedSubFIntrOp :
549549
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
550-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
550+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
551551

552552
def ScalableMaskedSDivIIntrOp :
553553
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
554-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
554+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
555555

556556
def ScalableMaskedUDivIIntrOp :
557557
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
558-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
558+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
559559

560560
def ScalableMaskedDivFIntrOp :
561561
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
562-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
562+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
563563

564564
def ConvertFromSvboolIntrOp :
565565
ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,19 +581,19 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
581581
/*overloadedOperands=*/[0],
582582
/*overloadedResults=*/[],
583583
/*numResults=*/2>,
584-
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
585-
Arg<AnyScalableVector, "v2">:$v2)>;
584+
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
585+
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
586586

587587
// Note: This multi-vector intrinsic requires SME2.
588588
def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
589589
/*traits=*/[],
590590
/*overloadedOperands=*/[0],
591591
/*overloadedResults=*/[],
592592
/*numResults=*/4>,
593-
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
594-
Arg<AnyScalableVector, "v2">:$v2,
595-
Arg<AnyScalableVector, "v3">:$v3,
596-
Arg<AnyScalableVector, "v3">:$v4)>;
593+
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
594+
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
595+
Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
596+
Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;
597597

598598
// Note: This intrinsic requires SME or SVE2.1.
599599
def PselIntrOp : ArmSVE_IntrOp<"psel",

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
417417
let hasVerifier = 1;
418418
}
419419

420-
def Vector_ShuffleOp :
421-
Vector_Op<"shuffle", [Pure,
422-
PredOpTrait<"first operand v1 and result have same element type",
423-
TCresVTEtIsSameAsOpBase<0, 0>>,
424-
PredOpTrait<"second operand v2 and result have same element type",
425-
TCresVTEtIsSameAsOpBase<0, 1>>,
426-
InferTypeOpAdaptor]>,
427-
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
428-
DenseI64ArrayAttr:$mask)>,
429-
Results<(outs AnyVector:$vector)> {
420+
def Vector_ShuffleOp
421+
: Vector_Op<
422+
"shuffle",
423+
[Pure,
424+
PredOpTrait<"first operand v1 and result have same element type",
425+
TCresVTEtIsSameAsOpBase<0, 0>>,
426+
PredOpTrait<"second operand v2 and result have same element type",
427+
TCresVTEtIsSameAsOpBase<0, 1>>,
428+
InferTypeOpAdaptor]>,
429+
Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
430+
DenseI64ArrayAttr:$mask)>,
431+
Results<(outs AnyVector:$vector)> {
430432
let summary = "shuffle operation";
431433
let description = [{
432434
The shuffle operation constructs a permutation (or duplication) of elements

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
3030
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
3131

3232
// Whether a type is a fixed-length VectorType.
33-
def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
33+
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3434
!::llvm::cast<VectorType>($_self).isScalable()}]>;
3535

3636
// Whether a type is a scalable VectorType.
@@ -438,11 +438,11 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
438438
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
439439
"::mlir::VectorType">;
440440

441-
class FixedVectorOf<list<Type> allowedTypes> :
442-
ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
441+
class FixedVectorOfAnyRank<list<Type> allowedTypes> :
442+
ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
443443
"fixed-length vector", "::mlir::VectorType">;
444444

445-
class ScalableVectorOf<list<Type> allowedTypes> :
445+
class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
446446
ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
447447
"scalable vector", "::mlir::VectorType">;
448448

@@ -467,7 +467,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
467467
// Whether the number of elements of a fixed-length vector is from the given
468468
// `allowedRanks` list
469469
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
470-
And<[IsFixedVectorTypePred,
470+
And<[IsFixedVectorOfAnyRankTypePred,
471471
Or<!foreach(allowedlength, allowedRanks,
472472
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
473473
== }]
@@ -509,8 +509,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
509509
// the type is from the given `allowedTypes` list
510510
class FixedVectorOfRankAndType<list<int> allowedRanks,
511511
list<Type> allowedTypes> : AllOfType<
512-
[FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
513-
FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
512+
[FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
513+
FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
514514
"::mlir::VectorType">;
515515

516516
// Whether the number of elements of a vector is from the given
@@ -525,7 +525,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
525525
// Whether the number of elements of a fixed-length vector is from the given
526526
// `allowedLengths` list
527527
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
528-
And<[IsFixedVectorTypePred,
528+
And<[IsFixedVectorOfAnyRankTypePred,
529529
Or<!foreach(allowedlength, allowedLengths,
530530
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
531531
== }]
@@ -612,17 +612,17 @@ class VectorOfLengthAndType<list<int> allowedLengths,
612612
// `allowedLengths` list and the type is from the given `allowedTypes` list
613613
class FixedVectorOfLengthAndType<list<int> allowedLengths,
614614
list<Type> allowedTypes> : AllOfType<
615-
[FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
616-
FixedVectorOf<allowedTypes>.summary #
615+
[FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
616+
FixedVectorOfAnyRank<allowedTypes>.summary #
617617
FixedVectorOfLength<allowedLengths>.summary,
618618
"::mlir::VectorType">;
619619

620620
// Any scalable vector where the number of elements is from the given
621621
// `allowedLengths` list and the type is from the given `allowedTypes` list
622622
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
623623
list<Type> allowedTypes> : AllOfType<
624-
[ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
625-
ScalableVectorOf<allowedTypes>.summary #
624+
[ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
625+
ScalableVectorOfAnyRank<allowedTypes>.summary #
626626
ScalableVectorOfLength<allowedLengths>.summary,
627627
"::mlir::VectorType">;
628628

@@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
632632
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
633633
list<int> allowedLengths,
634634
list<Type> allowedTypes> : AllOfType<
635-
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
635+
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
636636
ScalableVectorOfLength<allowedLengths>],
637637
ScalableVectorOfRank<allowedRanks>.summary #
638-
ScalableVectorOf<allowedTypes>.summary #
638+
ScalableVectorOfAnyRank<allowedTypes>.summary #
639639
ScalableVectorOfLength<allowedLengths>.summary,
640640
"::mlir::VectorType">;
641641

@@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
657657
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
658658
"::mlir::VectorType">;
659659

660+
// Unlike the following definitions, this one excludes 0-D vectors
660661
def AnyVector : VectorOf<[AnyType]>;
661-
// Temporary vector type clone that allows gradual transition to 0-D vectors.
662+
662663
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
663664

664-
def AnyFixedVector : FixedVectorOf<[AnyType]>;
665+
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
665666

666-
def AnyScalableVector : ScalableVectorOf<[AnyType]>;
667+
def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
667668

668669
// Shaped types.
669670

0 commit comments

Comments
 (0)