Skip to content

Commit 28d7671

Browse files
authored
[mlir] Add two clone methods about encoding to RankedTensorType. (#127709)
There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType: - dropEncoding(): Return a clone of this type without the encoding. - cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type. Signed-off-by: hanhanW <[email protected]>
1 parent fb191ef commit 28d7671

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
10351035
RankedTensorType clone(::mlir::Type elementType) {
10361036
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
10371037
}
1038+
1039+
/// Return a clone of this type without the encoding.
1040+
RankedTensorType dropEncoding() {
1041+
return RankedTensorType::get(getShape(), getElementType());
1042+
}
1043+
1044+
/// Return a clone of this type with the given new encoding and the same
1045+
/// shape and element type as this type.
1046+
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
1047+
return RankedTensorType::get(getShape(), getElementType(), encoding);
1048+
}
10381049
}];
10391050
let skipDefaultBuilders = 1;
10401051
let genVerifyDecl = 1;

mlir/unittests/IR/ShapedTypeTest.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
282282
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
283283
view = mlir::cast<TensorWithString>(viewCreated);
284284
EXPECT_EQ(view.getName(), "bob");
285+
286+
// Verify encoding clone methods.
287+
EXPECT_EQ(unitEncodingRankedTensorType,
288+
cast<RankedTensorType>(noEncodingRankedTensorType)
289+
.cloneWithEncoding(unitAttr));
290+
EXPECT_EQ(stringEncodingRankedTensorType,
291+
cast<RankedTensorType>(noEncodingRankedTensorType)
292+
.cloneWithEncoding(stringAttr));
293+
EXPECT_EQ(
294+
noEncodingRankedTensorType,
295+
cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
296+
EXPECT_EQ(
297+
noEncodingRankedTensorType,
298+
cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
285299
}
286300

287301
} // namespace

0 commit comments

Comments
 (0)