-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Use new VectorType wrappers CommonTypeConstraints.td #118645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
As a follow-on for llvm#87986, moves the VectorType convenience wrappers (`FixedVectorType` and `ScalableVectorType`) to BuiltinTypes.h. This allows us to use the new wrappers in "CommonTypeConstraints.td".
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-arith Author: Andrzej Warzyński (banach-space) ChangesAs a follow-on for #87986, moves the VectorType convenience wrappers Full diff: https://github.com/llvm/llvm-project/pull/118645.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+// * isa<vector::ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return vecTy.isScalable();
+ }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return !vecTy.isScalable();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
- CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
- : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- ::llvm::cast<VectorType>($_self).isScalable()}]>;
+ : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
- using VectorType::VectorType;
-
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return vecTy.isScalable();
- }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
- using VectorType::VectorType;
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return !vecTy.isScalable();
- }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
- if (isa<vector::ScalableVectorType>(type) &&
- !isa<SplatElementsAttr>(getValue()))
+ if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
|
@llvm/pr-subscribers-mlir-core Author: Andrzej Warzyński (banach-space) ChangesAs a follow-on for #87986, moves the VectorType convenience wrappers Full diff: https://github.com/llvm/llvm-project/pull/118645.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+// * isa<vector::ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return vecTy.isScalable();
+ }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return !vecTy.isScalable();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
- CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
- : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- ::llvm::cast<VectorType>($_self).isScalable()}]>;
+ : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
- using VectorType::VectorType;
-
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return vecTy.isScalable();
- }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
- using VectorType::VectorType;
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return !vecTy.isScalable();
- }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
- if (isa<vector::ScalableVectorType>(type) &&
- !isa<SplatElementsAttr>(getValue()))
+ if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
|
Add missing empty line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool
As a follow-on for #87986, moves the VectorType convenience wrappers
(
FixedVectorType
andScalableVectorType
) to BuiltinTypes.h. Thisallows us to use the new wrappers in "CommonTypeConstraints.td".