Skip to content

Commit 4759890

Browse files
authored
[mlir][tensor] Fix bug in insert_slice canonical. with tensor encoding (llvm#81045)
Previously, `InsertSliceOpSourceCastInserter` was incorrectly applied to a case when tensor types have an encoding attribute attached to them. The type `newSrcType` was missing that attribute from the old `srcType`, which made the expression `srcType == newSrcType` false, since `tensor<2x2xf32, "foo">` is not equal to `tensor<2x2xf32>`. That lead to an endless back and forth between `InsertSliceOpSourceCastInserter` that would introduce a cast and `InsertSliceOpCastFolder` that would remove it right after.
1 parent fbf43b0 commit 4759890

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,8 +2663,8 @@ struct InsertSliceOpSourceCastInserter final
26632663
if (!hasValidSizesOffsets(newSrcShape))
26642664
return failure();
26652665

2666-
RankedTensorType newSrcType =
2667-
RankedTensorType::get(newSrcShape, srcType.getElementType());
2666+
RankedTensorType newSrcType = RankedTensorType::get(
2667+
newSrcShape, srcType.getElementType(), srcType.getEncoding());
26682668
if (srcType == newSrcType ||
26692669
!preservesStaticInformation(srcType, newSrcType) ||
26702670
!tensor::CastOp::areCastCompatible(srcType, newSrcType))

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,24 @@ func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
555555

556556
// -----
557557

558+
// Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
559+
func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
560+
%arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
561+
{
562+
%0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
563+
return %0 : tensor<4x4xf32, "foo">
564+
}
565+
// CHECK-LABEL: func @insert_slice_canonicalize_encoding
566+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
567+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
568+
// CHECK-NOT: tensor.cast
569+
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
570+
// CHECK-SAME: [0, 0] [2, 2] [1, 1]
571+
// CHECK-SAME: : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
572+
// CHECK: return %[[RESULT]]
573+
574+
// -----
575+
558576
func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
559577
%arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
560578
{

0 commit comments

Comments
 (0)