Skip to content

[mlir] Use new VectorType wrappers CommonTypeConstraints.td #118645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 4, 2024

Conversation

banach-space
Copy link
Contributor

As a follow-on for #87986, moves the VectorType convenience wrappers
(FixedVectorType and ScalableVectorType) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".

As a follow-on for llvm#87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".
@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2024

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Andrzej Warzyński (banach-space)

Changes

As a follow-on for #87986, moves the VectorType convenience wrappers
(FixedVectorType and ScalableVectorType) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".


Full diff: https://github.com/llvm/llvm-project/pull/118645.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+31)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-7)
  • (modified) mlir/include/mlir/IR/VectorTypes.h (-39)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+1-3)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
 SliceVerificationResult isRankReducedType(ShapedType originalType,
                                           ShapedType candidateReducedType);
 
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+//  * isa<vector::ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return vecTy.isScalable();
+  }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return !vecTy.isScalable();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                          CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
-                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
-                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                                  !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType.
 def IsVectorTypeWithAnyDimScalablePred
-        : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                  ::llvm::cast<VectorType>($_self).isScalable()}]>;
+        : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType, with a single trailing scalable dimension.
 // Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
 //  * isa<vector::ScalableVectorType>(type)
 //
 //===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return vecTy.isScalable();
-  }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return !vecTy.isScalable();
-  }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
 #include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
   // However, this would most likely require updating the lowerings to LLVM.
-  if (isa<vector::ScalableVectorType>(type) &&
-      !isa<SplatElementsAttr>(getValue()))
+  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2024

@llvm/pr-subscribers-mlir-core

Author: Andrzej Warzyński (banach-space)

Changes

As a follow-on for #87986, moves the VectorType convenience wrappers
(FixedVectorType and ScalableVectorType) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".


Full diff: https://github.com/llvm/llvm-project/pull/118645.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+31)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-7)
  • (modified) mlir/include/mlir/IR/VectorTypes.h (-39)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+1-3)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
 SliceVerificationResult isRankReducedType(ShapedType originalType,
                                           ShapedType candidateReducedType);
 
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+//  * isa<vector::ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return vecTy.isScalable();
+  }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return !vecTy.isScalable();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                          CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
-                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
-                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                                  !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType.
 def IsVectorTypeWithAnyDimScalablePred
-        : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                  ::llvm::cast<VectorType>($_self).isScalable()}]>;
+        : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType, with a single trailing scalable dimension.
 // Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
 //  * isa<vector::ScalableVectorType>(type)
 //
 //===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return vecTy.isScalable();
-  }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return !vecTy.isScalable();
-  }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
 #include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
   // However, this would most likely require updating the lowerings to LLVM.
-  if (isa<vector::ScalableVectorType>(type) &&
-      !isa<SplatElementsAttr>(getValue()))
+  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool

@banach-space banach-space merged commit e84c918 into llvm:main Dec 4, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants