Skip to content

[mlir] Add unit test for RankedTensorType wrapper class example. #99789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2024

Conversation

jpienaar
Copy link
Member

Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly.

For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.

@jpienaar jpienaar requested review from joker-eph and River707 July 20, 2024 23:18
@llvmbot llvmbot added the mlir label Jul 20, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 20, 2024

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly.

For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.


Full diff: https://github.com/llvm/llvm-project/pull/99789.diff

1 Files Affected:

  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+60)
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..66ee416d7636e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 #include <cstdint>
@@ -226,4 +227,63 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+/// Simple wrapper class to enable "isa querying" and simple accessing of
+/// encoding.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    return mlir::cast<TensorWithString>(RankedTensorType::get(
+        shape, elementType, StringAttr::get(elementType.getContext(), name)));
+  }
+
+  StringRef getName() const {
+    if (Attribute enc = getEncoding())
+      return mlir::cast<StringAttr>(enc).getValue();
+    return {};
+  }
+
+  static bool classof(Type type);
+};
+
+bool TensorWithString::classof(Type type) {
+  if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
+    return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+  return false;
+}
+
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
+
+  UnitAttr unitAttr = UnitAttr::get(&context);
+  Type unitEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, unitAttr);
+
+  StringAttr stringAttr = StringAttr::get(&context, "app");
+  Type stringEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, stringAttr);
+
+  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
+
+  // Cast to TensorWithString view.
+  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  // Verify one could cast view type back to base type.
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+
+  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
+  view = mlir::cast<TensorWithString>(viewCreated);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
 } // namespace

Comment on lines 251 to 255
bool TensorWithString::classof(Type type) {
if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
return mlir::isa_and_present<StringAttr>(rt.getEncoding());
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this the only method defined out-of-line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made inline for simplicity (no good reason here).

Comment on lines +243 to +244
if (Attribute enc = getEncoding())
return mlir::cast<StringAttr>(enc).getValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're going to be adding these encoding wrappers, should we consider a templatized get encoding that does the cast/dyn_cast to the wrapped type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea yes. Like "getEncodingAs" ?

Add example as unit test for creating a "RankedTensorType with encoding"
view. This view provides a more typed API to the encoding while it
allows one to avoid repeated dyn_cast queries and accessing the encoding
directly.

For users with more advance encodings, the expectation would be a
separate attribute type, but here just StringAttr is used.
@jpienaar jpienaar merged commit d2f42c7 into llvm:main Jul 22, 2024
5 of 6 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
Add example as unit test for creating a wrapper type/view for
RankedTensorType with encoding. This view provides a more restricted &
typed API while it allows one to avoid repeated casting queries and
accessing the encoding directly.

For users with more advance encodings, the expectation would be a
separate attribute type, but here just StringAttr is used.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251186
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants