Skip to content

Commit b214ca8

Browse files
authored
[mlir][vector] Rename vector type TD definitions (nfc) (#117150)
Currently, the Vector dialect TD file includes the following "vector" type definitions: ```mlir def AnyVector : VectorOf<[AnyType]>; def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; def AnyFixedVector : FixedVectorOf<[AnyType]>; def AnyScalableVector : ScalableVectorOf<[AnyType]>; ``` In short: * `AnyVector` _excludes_ 0-D vectors. * `AnyVectorOfAnyRank`, `AnyFixedVector`, and `AnyScalableVector` _include_ 0-D vectors. The naming for "groups" that include 0-D vectors is inconsistent and can be misleading, and `AnyVector` implies that 0-D vectors are included, which is not the case. This patch renames these definitions for clarity: ```mlir def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>; def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>; def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>; ``` Rationale: * The updated names are more explicit about 0-D vector support. * It becomes clearer that scalable vectors currently allow 0-D vectors - this might warrant a revisit. * The renaming paves the way for adding a new group for "fixed-width vectors excluding 0-D vectors" (e.g., AnyFixedVector), which I plan to introduce in a follow-up patch.
1 parent db6f627 commit b214ca8

File tree

9 files changed

+121
-118
lines changed

9 files changed

+121
-118
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
964964
(see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)).
965965
}];
966966

967-
let results = (outs AnyVector:$result);
967+
let results = (outs AnyVectorOfNonZeroRank:$result);
968968

969969
let builders = [
970970
/// Builds an affine vector load op with the specified map and operands.
@@ -1031,7 +1031,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
10311031
(see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)).
10321032
}];
10331033

1034-
let arguments = (ins AnyVector:$value,
1034+
let arguments = (ins AnyVectorOfNonZeroRank:$value,
10351035
Arg<AnyMemRef, "the reference to store to",
10361036
[MemWrite]>:$memref,
10371037
Variadic<Index>:$indices,

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
3535
//===----------------------------------------------------------------------===//
3636

3737
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
38-
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
38+
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
3939
"a vector with length " # length,
4040
"::mlir::VectorType">;
4141

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
371371
let arguments = (ins
372372
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
373373
Variadic<Index>:$indices,
374-
Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
374+
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
375375
ArmSME_TileSliceLayoutAttr:$layout
376376
);
377377
let results = (outs SMETile:$result);
@@ -444,7 +444,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
444444
}];
445445
let arguments = (ins SMETile:$valueToStore,
446446
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
447-
Variadic<Index>:$indices, Optional<AnyVector>:$mask,
447+
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
448448
ArmSME_TileSliceLayoutAttr:$layout
449449
);
450450
let extraClassDeclaration = [{
@@ -799,9 +799,9 @@ class OuterProductWideningBase<string mnemonic,
799799
]> {
800800

801801
let arguments = (ins
802-
AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
803-
Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
804-
Optional<AnyVector>:$acc);
802+
AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVectorOfNonZeroRank:$rhs,
803+
Optional<AnyVectorOfNonZeroRank>:$lhsMask, Optional<AnyVectorOfNonZeroRank>:$rhsMask,
804+
Optional<AnyVectorOfNonZeroRank>:$acc);
805805
let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
806806

807807
let assemblyFormat = [{

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
100100
op_description # [{ on active lanes. Inactive lanes will keep the value of
101101
the first operand.}];
102102
let arguments = (ins
103-
ScalableVectorOf<[I1]>:$mask,
104-
ScalableVectorOf<[AnyFloat]>:$src1,
105-
ScalableVectorOf<[AnyFloat]>:$src2
103+
ScalableVectorOfAnyRank<[I1]>:$mask,
104+
ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
105+
ScalableVectorOfAnyRank<[AnyFloat]>:$src2
106106
);
107-
let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
107+
let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
108108
let assemblyFormat =
109109
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
110110
}
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
123123
op_description # [{ on active lanes. Inactive lanes will keep the value of
124124
the first operand.}];
125125
let arguments = (ins
126-
ScalableVectorOf<[I1]>:$mask,
127-
ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
128-
ScalableVectorOf<[I8, I16, I32, I64]>:$src2
126+
ScalableVectorOfAnyRank<[I1]>:$mask,
127+
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
128+
ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
129129
);
130-
let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
130+
let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
131131
let assemblyFormat =
132132
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
133133
}
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
511511

512512
def UmmlaIntrOp :
513513
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
514-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
514+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
515515

516516
def SmmlaIntrOp :
517517
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
518-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
518+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
519519

520520
def SdotIntrOp :
521521
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
522-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
522+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
523523

524524
def UdotIntrOp :
525525
ArmSVE_IntrBinaryOverloadedOp<"udot">,
526-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
526+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
527527

528528
def ScalableMaskedAddIIntrOp :
529529
ArmSVE_IntrBinaryOverloadedOp<"add">,
530-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
530+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
531531

532532
def ScalableMaskedAddFIntrOp :
533533
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
534-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
534+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
535535

536536
def ScalableMaskedMulIIntrOp :
537537
ArmSVE_IntrBinaryOverloadedOp<"mul">,
538-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
538+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
539539

540540
def ScalableMaskedMulFIntrOp :
541541
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
542-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
542+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
543543

544544
def ScalableMaskedSubIIntrOp :
545545
ArmSVE_IntrBinaryOverloadedOp<"sub">,
546-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
546+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
547547

548548
def ScalableMaskedSubFIntrOp :
549549
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
550-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
550+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
551551

552552
def ScalableMaskedSDivIIntrOp :
553553
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
554-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
554+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
555555

556556
def ScalableMaskedUDivIIntrOp :
557557
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
558-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
558+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
559559

560560
def ScalableMaskedDivFIntrOp :
561561
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
562-
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
562+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
563563

564564
def ConvertFromSvboolIntrOp :
565565
ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,19 +581,19 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
581581
/*overloadedOperands=*/[0],
582582
/*overloadedResults=*/[],
583583
/*numResults=*/2>,
584-
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
585-
Arg<AnyScalableVector, "v2">:$v2)>;
584+
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
585+
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
586586

587587
// Note: This multi-vector intrinsic requires SME2.
588588
def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
589589
/*traits=*/[],
590590
/*overloadedOperands=*/[0],
591591
/*overloadedResults=*/[],
592592
/*numResults=*/4>,
593-
Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
594-
Arg<AnyScalableVector, "v2">:$v2,
595-
Arg<AnyScalableVector, "v3">:$v3,
596-
Arg<AnyScalableVector, "v3">:$v4)>;
593+
Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
594+
Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
595+
Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
596+
Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;
597597

598598
// Note: This intrinsic requires SME or SVE2.1.
599599
def PselIntrOp : ArmSVE_IntrOp<"psel",

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
255255
let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$srcMemref,
256256
Variadic<Index>:$indices, BoolAttr:$transpose,
257257
I32Attr:$numTiles);
258-
let results = (outs AnyVector:$res);
258+
let results = (outs AnyVectorOfNonZeroRank:$res);
259259
let assemblyFormat = [{
260260
$srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
261261
}];
@@ -301,13 +301,13 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
301301
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
302302
```
303303
}];
304-
let arguments = (ins AnyVector:$matrixA,
305-
AnyVector:$matrixB,
306-
AnyVector:$matrixC,
304+
let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
305+
AnyVectorOfNonZeroRank:$matrixB,
306+
AnyVectorOfNonZeroRank:$matrixC,
307307
I64ArrayAttr:$mmaShape,
308308
OptionalAttr<UnitAttr>:$tf32Enabled);
309309

310-
let results = (outs AnyVector:$res);
310+
let results = (outs AnyVectorOfNonZeroRank:$res);
311311

312312
let builders = [
313313
OpBuilder<(ins "Value":$matrixA,
@@ -357,16 +357,16 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
357357
```
358358
}];
359359

360-
let arguments = (ins AnyVector:$matrixA,
361-
AnyVector:$matrixB,
362-
AnyVector:$matrixC,
360+
let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
361+
AnyVectorOfNonZeroRank:$matrixB,
362+
AnyVectorOfNonZeroRank:$matrixC,
363363
NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
364364
I64ArrayAttr:$mmaShape,
365365
DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
366366
OptionalAttr<UnitAttr>:$tf32Enabled
367367
);
368368

369-
let results = (outs AnyVector:$res);
369+
let results = (outs AnyVectorOfNonZeroRank:$res);
370370

371371
let builders = [
372372
OpBuilder<(ins "Value":$matrixA,
@@ -825,10 +825,10 @@ def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
825825

826826
The input and output must be of the same vector type and shape.
827827
}];
828-
let arguments = (ins VectorOf<[F32]>:$in,
828+
let arguments = (ins VectorOfNonZeroRankOf<[F32]>:$in,
829829
DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
830830
UnitAttr:$ftz);
831-
let results = (outs VectorOf<[F32]>:$out);
831+
let results = (outs VectorOfNonZeroRankOf<[F32]>:$out);
832832
let assemblyFormat = [{
833833
$in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
834834
attr-dict `:` type($out)

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
166166

167167
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
168168
AnyTypeOf<types>.predicate,
169-
VectorOf<types>.predicate,
169+
VectorOfNonZeroRankOf<types>.predicate,
170170
TosaTensorOf<types>.predicate]>,
171171
description>;
172172

0 commit comments

Comments
 (0)