Skip to content

[mlir] Add ScalableVectorType and FixedVectorType #87986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Apr 8, 2024

This PR adds two small convenience Vector types:

  • ScalableVectorType and FixedVectorType.

The goal of these new types is two-fold:

  • Enable idiomatic checks like isa<ScalableVectorType>(...).
  • Make the split into "Scalable" and "Fixed-wdith" vectors a bit more
    explicit and more visible in the code-base.

The new types are added in mlir/include/mlir/IR (instead of e.g.
mlir/include/mlir/Dialect/Vector) so that the new types can be used
without requiring any new dependency (e.g. on the Vector dialect).

@llvmbot
Copy link
Member

llvmbot commented Apr 8, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-arith

Author: Andrzej Warzyński (banach-space)

Changes

This PR adds two small convenience Vector types:

  • ScalableVectorType and FixedWidthVectorType.

The goal of these new types is two-fold:

  • enable idiomatic checks like isa&lt;ScalableVectorType&gt;(...),
  • make the split into "Scalable" and "Fixed-wdith" vectors a bit more
    explicit and more visible in the code-base.

Full diff: https://github.com/llvm/llvm-project/pull/87986.diff

4 Files Affected:

  • (added) mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h (+35)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/IR/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/IR/VectorTypes.cpp (+27)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h b/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h
new file mode 100644
index 00000000000000..384969779d8f6d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h
@@ -0,0 +1,35 @@
+//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+#define MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace vector {
+
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type);
+};
+
+class FixedWidthVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type);
+};
+
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index efc4bfe622d53a..b66337cb07bacf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -214,8 +215,8 @@ LogicalResult arith::ConstantOp::verify() {
         "value must be an integer, float, or elements attribute");
   }
 
-  auto vecType = dyn_cast<VectorType>(type);
-  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+  if (isa<vector::ScalableVectorType>(type) &&
+      !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 204462ffd047c6..6638feae1e140f 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorDialect
   VectorOps.cpp
+  VectorTypes.cpp
   ValueBoundsOpInterfaceImpl.cpp
   ScalableValueBoundsConstraintSet.cpp
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp b/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp
new file mode 100644
index 00000000000000..439040e73938d8
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp
@@ -0,0 +1,27 @@
+//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+bool ScalableVectorType::classof(Type type) {
+  auto vecTy = llvm::dyn_cast<VectorType>(type);
+  if (!vecTy)
+    return false;
+  return vecTy.isScalable();
+}
+
+bool FixedWidthVectorType::classof(Type type) {
+  auto vecTy = llvm::dyn_cast<VectorType>(type);
+  if (!vecTy)
+    return false;
+  return !vecTy.isScalable();
+}

@banach-space
Copy link
Contributor Author

@dcaballe mindful that you are working on VectorType refactor, I'm only sharing this as a data point in the discussion. if people like it then I'm happy to land it, but I'd rather wait for you to share your ideas before taking this further.

I do like how non-intrusive this is (thanks for the suggestion @kuhar !) - it could help us evaluate whether we need to extend VectorType or whether it's a matter of fine-tuning the APIs and the underlying implementation.

@banach-space banach-space force-pushed the andrzej/add_vec_convenience_types branch from 4f29ec7 to 796ebb3 Compare April 8, 2024 15:11
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Apr 8, 2024
Copy link

github-actions bot commented Apr 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I guess that we should also be able to use this in ODS?

using VectorType::VectorType;

static bool classof(Type type);
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need something akin to resolveTypeID? I saw that @joker-eph added those for ops in #87170 but I'm not sure if the same thing applies to types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to only be required for operations for now I believe. Worst case they could probably be easy to make available using using VectorType::resolveTypeID;

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great with the review comments applied, thank you! Love these kind of convenient strong types.

using VectorType::VectorType;

static bool classof(Type type);
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to only be required for operations for now I believe. Worst case they could probably be easy to make available using using VectorType::resolveTypeID;

@banach-space banach-space force-pushed the andrzej/add_vec_convenience_types branch from c248b37 to cc17e63 Compare April 16, 2024 15:52
@@ -16,17 +16,29 @@
namespace mlir {
namespace vector {

/// A vector type containing at least one scalable dimension
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ubernit: missing . at the end

This PR adds two small convenience Vector types:

  * `ScalableVectorType` and `FixedWidthVectorType`.

The goal of these new types is two-fold:
  * enable idiomatic checks like `isa<ScalableVectorType>(...)`,
  * make the split into "Scalable" and "Fixed-wdith" vectors a bit more
    explicit and more visible in the code-base.

Depends on llvm#87999
@banach-space
Copy link
Contributor Author

Apologies for the delay with this one - got distracted with the discussion on how to model "scalability" and the life happened 😅

I will land it in the next few days, unless there are new comments.

…e vectors

FixedWidthVectorType -> FixedVectorType for consistency with LLVM
Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit comments.

…scalable vectors

Fix typos, add comments, remove include
@banach-space banach-space changed the title [mlir][vector] Add convenience types for scalable vectors [mlir] Add ScalableVectorType and FixedVectorType Dec 2, 2024
@banach-space
Copy link
Contributor Author

Good catches @kiranchandramohan , thank you 🙏🏻

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

@banach-space banach-space merged commit fed3a9b into llvm:main Dec 3, 2024
8 checks passed
@dcaballe
Copy link
Contributor

dcaballe commented Dec 3, 2024

We want this to be consistently available to everyone using VectorType so, should we move this to #include "mlir/IR/BuiltinTypes.h"?

banach-space added a commit to banach-space/llvm-project that referenced this pull request Dec 4, 2024
As a follow-on for llvm#87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".
@banach-space
Copy link
Contributor Author

We want this to be consistently available to everyone using VectorType so, should we move this to #include "mlir/IR/BuiltinTypes.h"?

Guess what has blocked me from using these wrappers in *.td files ;-)

banach-space added a commit that referenced this pull request Dec 4, 2024
As a follow-on for #87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) to BuiltinTypes.h. This
allows us to use the new wrappers in "CommonTypeConstraints.td".
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants