Skip to content

Commit d009f6e

Browse files
committed
[mlir] Convert ConstShapeOp to a static tensor type.
ConstShapeOp knows its shape, so it should also have a static tensor type. Differential Revision: https://reviews.llvm.org/D111127
1 parent 9ce4f37 commit d009f6e

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
191191
Type indexTy = rewriter.getIndexType();
192192
Value tensor =
193193
rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
194-
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
194+
Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy);
195195
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
196196
return success();
197197
}

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,29 +89,29 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
8989

9090
// Lower `const_shape` to `tensor.from_elements`.
9191
// CHECK-LABEL: @const_shape
92-
// CHECK-SAME: () -> tensor<?xindex>
93-
func @const_shape() -> tensor<?xindex> {
92+
// CHECK-SAME: () -> tensor<3xindex>
93+
func @const_shape() -> tensor<3xindex> {
9494
// CHECK: %[[C1:.*]] = constant 1 : index
9595
// CHECK: %[[C2:.*]] = constant 2 : index
9696
// CHECK: %[[C3:.*]] = constant 3 : index
9797
// CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
98-
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
99-
// CHECK: return %[[RESULT]] : tensor<?xindex>
100-
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
101-
return %shape : tensor<?xindex>
98+
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
99+
// CHECK: return %[[RESULT]] : tensor<3xindex>
100+
%shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
101+
return %shape : tensor<3xindex>
102102
}
103103

104104
// -----
105105

106106
// Lower `const_shape` in the case of rank 0.
107107
// CHECK-LABEL: func @const_shape_zero_elements
108-
// CHECK-SAME: () -> tensor<?xindex>
109-
func @const_shape_zero_elements() -> tensor<?xindex> {
108+
// CHECK-SAME: () -> tensor<0xindex>
109+
func @const_shape_zero_elements() -> tensor<0xindex> {
110110
// CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
111-
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
112-
// CHECK: return %[[RESULT]] : tensor<?xindex>
113-
%shape = shape.const_shape [] : tensor<?xindex>
114-
return %shape : tensor<?xindex>
111+
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
112+
// CHECK: return %[[RESULT]] : tensor<0xindex>
113+
%shape = shape.const_shape [] : tensor<0xindex>
114+
return %shape : tensor<0xindex>
115115
}
116116

117117
// -----

0 commit comments

Comments
 (0)