Skip to content

Commit a485a7d

Browse files
committed
[mlir] Add unit test for RankedTensorType wrapper class example.
Add example as unit test for creating a "RankedTensorType with encoding" view. This view provides a more typed API to the encoding while it allows one to avoid repeated dyn_cast queries and accessing the encoding directly. For users with more advance encodings, the expectation would be a separate attribute type, but here just StringAttr is used.
1 parent ae2012d commit a485a7d

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

mlir/unittests/IR/ShapedTypeTest.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/IR/BuiltinTypes.h"
1212
#include "mlir/IR/Dialect.h"
1313
#include "mlir/IR/DialectInterface.h"
14+
#include "mlir/Support/LLVM.h"
1415
#include "llvm/ADT/SmallVector.h"
1516
#include "gtest/gtest.h"
1617
#include <cstdint>
@@ -226,4 +227,63 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
226227
}
227228
}
228229

230+
/// Simple wrapper class to enable "isa querying" and simple accessing of
231+
/// encoding.
232+
class TensorWithString : public RankedTensorType {
233+
public:
234+
using RankedTensorType::RankedTensorType;
235+
236+
static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
237+
StringRef name) {
238+
return mlir::cast<TensorWithString>(RankedTensorType::get(
239+
shape, elementType, StringAttr::get(elementType.getContext(), name)));
240+
}
241+
242+
StringRef getName() const {
243+
if (Attribute enc = getEncoding())
244+
return mlir::cast<StringAttr>(enc).getValue();
245+
return {};
246+
}
247+
248+
static bool classof(Type type);
249+
};
250+
251+
bool TensorWithString::classof(Type type) {
252+
if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
253+
return mlir::isa_and_present<StringAttr>(rt.getEncoding());
254+
return false;
255+
}
256+
257+
TEST(ShapedTypeTest, RankedTensorTypeView) {
258+
MLIRContext context;
259+
Type f32 = FloatType::getF32(&context);
260+
261+
Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
262+
263+
UnitAttr unitAttr = UnitAttr::get(&context);
264+
Type unitEncodingRankedTensorType =
265+
RankedTensorType::get({10, 20}, f32, unitAttr);
266+
267+
StringAttr stringAttr = StringAttr::get(&context, "app");
268+
Type stringEncodingRankedTensorType =
269+
RankedTensorType::get({10, 20}, f32, stringAttr);
270+
271+
EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
272+
EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
273+
ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
274+
275+
// Cast to TensorWithString view.
276+
auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
277+
ASSERT_TRUE(mlir::isa<TensorWithString>(view));
278+
EXPECT_EQ(view.getName(), "app");
279+
// Verify one could cast view type back to base type.
280+
ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
281+
282+
Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
283+
ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
284+
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
285+
view = mlir::cast<TensorWithString>(viewCreated);
286+
EXPECT_EQ(view.getName(), "bob");
287+
}
288+
229289
} // namespace

0 commit comments

Comments
 (0)