Skip to content

Commit d2f42c7

Browse files
authored
[mlir] Add unit test for RankedTensorType wrapper example. (#99789)
Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while it allows one to avoid repeated casting queries and accessing the encoding directly. For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.
1 parent 8972979 commit d2f42c7

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

mlir/unittests/IR/ShapedTypeTest.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/IR/BuiltinTypes.h"
1212
#include "mlir/IR/Dialect.h"
1313
#include "mlir/IR/DialectInterface.h"
14+
#include "mlir/Support/LLVM.h"
1415
#include "llvm/ADT/SmallVector.h"
1516
#include "gtest/gtest.h"
1617
#include <cstdint>
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
226227
}
227228
}
228229

230+
/// Simple wrapper class to enable "isa querying" and simple accessing of
231+
/// encoding.
232+
class TensorWithString : public RankedTensorType {
233+
public:
234+
using RankedTensorType::RankedTensorType;
235+
236+
static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
237+
StringRef name) {
238+
return mlir::cast<TensorWithString>(RankedTensorType::get(
239+
shape, elementType, StringAttr::get(elementType.getContext(), name)));
240+
}
241+
242+
StringRef getName() const {
243+
if (Attribute enc = getEncoding())
244+
return mlir::cast<StringAttr>(enc).getValue();
245+
return {};
246+
}
247+
248+
static bool classof(Type type) {
249+
if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
250+
return mlir::isa_and_present<StringAttr>(rt.getEncoding());
251+
return false;
252+
}
253+
};
254+
255+
TEST(ShapedTypeTest, RankedTensorTypeView) {
256+
MLIRContext context;
257+
Type f32 = FloatType::getF32(&context);
258+
259+
Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
260+
261+
UnitAttr unitAttr = UnitAttr::get(&context);
262+
Type unitEncodingRankedTensorType =
263+
RankedTensorType::get({10, 20}, f32, unitAttr);
264+
265+
StringAttr stringAttr = StringAttr::get(&context, "app");
266+
Type stringEncodingRankedTensorType =
267+
RankedTensorType::get({10, 20}, f32, stringAttr);
268+
269+
EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
270+
EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
271+
ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
272+
273+
// Cast to TensorWithString view.
274+
auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
275+
ASSERT_TRUE(mlir::isa<TensorWithString>(view));
276+
EXPECT_EQ(view.getName(), "app");
277+
// Verify one could cast view type back to base type.
278+
ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
279+
280+
Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
281+
ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
282+
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
283+
view = mlir::cast<TensorWithString>(viewCreated);
284+
EXPECT_EQ(view.getName(), "bob");
285+
}
286+
229287
} // namespace

0 commit comments

Comments
 (0)