|
11 | 11 | #include "mlir/IR/BuiltinTypes.h"
|
12 | 12 | #include "mlir/IR/Dialect.h"
|
13 | 13 | #include "mlir/IR/DialectInterface.h"
|
| 14 | +#include "mlir/Support/LLVM.h" |
14 | 15 | #include "llvm/ADT/SmallVector.h"
|
15 | 16 | #include "gtest/gtest.h"
|
16 | 17 | #include <cstdint>
|
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
|
226 | 227 | }
|
227 | 228 | }
|
228 | 229 |
|
| 230 | +/// Simple wrapper class to enable "isa querying" and simple accessing of |
| 231 | +/// encoding. |
| 232 | +class TensorWithString : public RankedTensorType { |
| 233 | +public: |
| 234 | + using RankedTensorType::RankedTensorType; |
| 235 | + |
| 236 | + static TensorWithString get(ArrayRef<int64_t> shape, Type elementType, |
| 237 | + StringRef name) { |
| 238 | + return mlir::cast<TensorWithString>(RankedTensorType::get( |
| 239 | + shape, elementType, StringAttr::get(elementType.getContext(), name))); |
| 240 | + } |
| 241 | + |
| 242 | + StringRef getName() const { |
| 243 | + if (Attribute enc = getEncoding()) |
| 244 | + return mlir::cast<StringAttr>(enc).getValue(); |
| 245 | + return {}; |
| 246 | + } |
| 247 | + |
| 248 | + static bool classof(Type type) { |
| 249 | + if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type)) |
| 250 | + return mlir::isa_and_present<StringAttr>(rt.getEncoding()); |
| 251 | + return false; |
| 252 | + } |
| 253 | +}; |
| 254 | + |
| 255 | +TEST(ShapedTypeTest, RankedTensorTypeView) { |
| 256 | + MLIRContext context; |
| 257 | + Type f32 = FloatType::getF32(&context); |
| 258 | + |
| 259 | + Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32); |
| 260 | + |
| 261 | + UnitAttr unitAttr = UnitAttr::get(&context); |
| 262 | + Type unitEncodingRankedTensorType = |
| 263 | + RankedTensorType::get({10, 20}, f32, unitAttr); |
| 264 | + |
| 265 | + StringAttr stringAttr = StringAttr::get(&context, "app"); |
| 266 | + Type stringEncodingRankedTensorType = |
| 267 | + RankedTensorType::get({10, 20}, f32, stringAttr); |
| 268 | + |
| 269 | + EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType)); |
| 270 | + EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType)); |
| 271 | + ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType)); |
| 272 | + |
| 273 | + // Cast to TensorWithString view. |
| 274 | + auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType); |
| 275 | + ASSERT_TRUE(mlir::isa<TensorWithString>(view)); |
| 276 | + EXPECT_EQ(view.getName(), "app"); |
| 277 | + // Verify one could cast view type back to base type. |
| 278 | + ASSERT_TRUE(mlir::isa<RankedTensorType>(view)); |
| 279 | + |
| 280 | + Type viewCreated = TensorWithString::get({10, 20}, f32, "bob"); |
| 281 | + ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated)); |
| 282 | + ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated)); |
| 283 | + view = mlir::cast<TensorWithString>(viewCreated); |
| 284 | + EXPECT_EQ(view.getName(), "bob"); |
| 285 | +} |
| 286 | + |
229 | 287 | } // namespace
|
0 commit comments