Skip to content

Commit 3eb2e43

Browse files
committed
[mlir] Use new VectorType wrappers CommonTypeConstraints.td
As a follow-on for llvm#87986, moves the VectorType convenience wrappers (`FixedVectorType` and `ScalableVectorType`) to BuiltinTypes.h. This allows us to use the new wrappers in "CommonTypeConstraints.td".
1 parent 52b9d0b commit 3eb2e43

File tree

4 files changed

+36
-49
lines changed

4 files changed

+36
-49
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
401401
SliceVerificationResult isRankReducedType(ShapedType originalType,
402402
ShapedType candidateReducedType);
403403

404+
//===----------------------------------------------------------------------===//
405+
// Convenience wrappers for VectorType
406+
//
407+
// These are provided to allow idiomatic code like:
408+
// * isa<vector::ScalableVectorType>(type)
409+
//===----------------------------------------------------------------------===//
410+
/// A vector type containing at least one scalable dimension.
411+
class ScalableVectorType : public VectorType {
412+
public:
413+
using VectorType::VectorType;
414+
415+
static bool classof(Type type) {
416+
auto vecTy = llvm::dyn_cast<VectorType>(type);
417+
if (!vecTy)
418+
return false;
419+
return vecTy.isScalable();
420+
}
421+
};
422+
423+
/// A vector type with no scalable dimensions.
424+
class FixedVectorType : public VectorType {
425+
public:
426+
using VectorType::VectorType;
427+
static bool classof(Type type) {
428+
auto vecTy = llvm::dyn_cast<VectorType>(type);
429+
if (!vecTy)
430+
return false;
431+
return !vecTy.isScalable();
432+
}
433+
};
434+
404435
//===----------------------------------------------------------------------===//
405436
// Deferred Method Definitions
406437
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
2424
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
2525
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
2626
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
27-
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
28-
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
29-
CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
27+
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
28+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
3029

3130
// Temporary vector type clone that allows gradual transition to 0-D vectors.
3231
// TODO: Remove this when all ops support 0-D vectors.
3332
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
3433

3534
// Whether a type is a fixed-length VectorType.
36-
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
37-
!::llvm::cast<VectorType>($_self).isScalable()}]>;
35+
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
3836

3937
// Whether a type is a scalable VectorType.
4038
def IsVectorTypeWithAnyDimScalablePred
41-
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
42-
::llvm::cast<VectorType>($_self).isScalable()}]>;
39+
: CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
4340

4441
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
4542
// Examples:

mlir/include/mlir/IR/VectorTypes.h

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,42 +10,3 @@
1010
// * isa<vector::ScalableVectorType>(type)
1111
//
1212
//===----------------------------------------------------------------------===//
13-
14-
#ifndef MLIR_IR_VECTORTYPES_H
15-
#define MLIR_IR_VECTORTYPES_H
16-
17-
#include "mlir/IR/BuiltinTypes.h"
18-
#include "mlir/IR/Types.h"
19-
20-
namespace mlir {
21-
namespace vector {
22-
23-
/// A vector type containing at least one scalable dimension.
24-
class ScalableVectorType : public VectorType {
25-
public:
26-
using VectorType::VectorType;
27-
28-
static bool classof(Type type) {
29-
auto vecTy = llvm::dyn_cast<VectorType>(type);
30-
if (!vecTy)
31-
return false;
32-
return vecTy.isScalable();
33-
}
34-
};
35-
36-
/// A vector type with no scalable dimensions.
37-
class FixedVectorType : public VectorType {
38-
public:
39-
using VectorType::VectorType;
40-
static bool classof(Type type) {
41-
auto vecTy = llvm::dyn_cast<VectorType>(type);
42-
if (!vecTy)
43-
return false;
44-
return !vecTy.isScalable();
45-
}
46-
};
47-
48-
} // namespace vector
49-
} // namespace mlir
50-
51-
#endif // MLIR_IR_VECTORTYPES_H

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/IR/TypeUtilities.h"
24-
#include "mlir/IR/VectorTypes.h"
2524
#include "mlir/Support/LogicalResult.h"
2625

2726
#include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
226225
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
227226
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
228227
// However, this would most likely require updating the lowerings to LLVM.
229-
if (isa<vector::ScalableVectorType>(type) &&
230-
!isa<SplatElementsAttr>(getValue()))
228+
if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
231229
return emitOpError(
232230
"intializing scalable vectors with elements attribute is not supported"
233231
" unless it's a vector splat");

0 commit comments

Comments
 (0)