-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add unit test for RankedTensorType wrapper class example. #99789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Jacques Pienaar (jpienaar) ChangesAdd example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly. For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used. Full diff: https://github.com/llvm/llvm-project/pull/99789.diff 1 Files Affected:
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..66ee416d7636e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
@@ -226,4 +227,63 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+/// Simple wrapper class to enable "isa querying" and simple accessing of
+/// encoding.
+class TensorWithString : public RankedTensorType {
+public:
+ using RankedTensorType::RankedTensorType;
+
+ static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+ StringRef name) {
+ return mlir::cast<TensorWithString>(RankedTensorType::get(
+ shape, elementType, StringAttr::get(elementType.getContext(), name)));
+ }
+
+ StringRef getName() const {
+ if (Attribute enc = getEncoding())
+ return mlir::cast<StringAttr>(enc).getValue();
+ return {};
+ }
+
+ static bool classof(Type type);
+};
+
+bool TensorWithString::classof(Type type) {
+ if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
+ return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+ return false;
+}
+
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+ MLIRContext context;
+ Type f32 = FloatType::getF32(&context);
+
+ Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
+
+ UnitAttr unitAttr = UnitAttr::get(&context);
+ Type unitEncodingRankedTensorType =
+ RankedTensorType::get({10, 20}, f32, unitAttr);
+
+ StringAttr stringAttr = StringAttr::get(&context, "app");
+ Type stringEncodingRankedTensorType =
+ RankedTensorType::get({10, 20}, f32, stringAttr);
+
+ EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
+ EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
+ ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
+
+ // Cast to TensorWithString view.
+ auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
+ ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+ EXPECT_EQ(view.getName(), "app");
+ // Verify one could cast view type back to base type.
+ ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+
+ Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
+ ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
+ ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
+ view = mlir::cast<TensorWithString>(viewCreated);
+ EXPECT_EQ(view.getName(), "bob");
+}
+
} // namespace
|
mlir/unittests/IR/ShapedTypeTest.cpp
Outdated
bool TensorWithString::classof(Type type) { | ||
if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type)) | ||
return mlir::isa_and_present<StringAttr>(rt.getEncoding()); | ||
return false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this the only method defined out-of-line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made inline for simplicity (no good reason here).
if (Attribute enc = getEncoding()) | ||
return mlir::cast<StringAttr>(enc).getValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're going to be adding these encoding wrappers, should we consider a templatized get encoding that does the cast/dyn_cast to the wrapped type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good idea yes. Like "getEncodingAs" ?
Add example as unit test for creating a "RankedTensorType with encoding" view. This view provides a more typed API to the encoding while it allows one to avoid repeated dyn_cast queries and accessing the encoding directly. For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.
Summary: Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly. For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251186
Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly.
For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.