Skip to content

Commit fa5a607

Browse files
committed
[mlir][tosa] Fix tosa.slice shape inference for ShapedType:kDynamicShape
Change for kDynamicShape means the size needs to be updated to a new value for slice operation shape inference. Landing fix. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D138314
1 parent f65e8c3 commit fa5a607

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,12 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
590590
return success();
591591
}
592592

593+
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
594+
return to_vector(llvm::map_range(shape, [](int64_t dim) {
595+
return dim == -1 ? ShapedType::kDynamicSize : dim;
596+
}));
597+
}
598+
593599
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
594600
MLIRContext *context, ::llvm::Optional<Location> location,
595601
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -601,7 +607,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
601607
outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
602608
}
603609

604-
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
610+
inferredReturnShapes.push_back(ShapedTypeComponents(
611+
convertToMlirShape(outputShape)));
605612
return success();
606613
}
607614

@@ -655,11 +662,6 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
655662
return success();
656663
}
657664

658-
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
659-
return to_vector(llvm::map_range(shape, [](int64_t dim) {
660-
return dim == -1 ? ShapedType::kDynamicSize : dim;
661-
}));
662-
}
663665

664666
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
665667
MLIRContext *context, ::llvm::Optional<Location> location,

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,15 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
539539

540540
// -----
541541

542+
// CHECK-LABEL: @test_slice_dynamic
543+
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
544+
// CHECK: "tosa.slice"(%arg0) {size = [7, -1, 1], start = [1, 0, 0]} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
545+
%0 = "tosa.slice"(%arg0) {size = [7, -1, 1], start = [1, 0, 0]} : (tensor<10x?x2xf32>) -> tensor<?x?x?xf32>
546+
return
547+
}
548+
549+
// -----
550+
542551
// CHECK-LABEL: @test_tile
543552
func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
544553
// CHECK: "tosa.tile"(%arg0) {multiples = [2, 1, 5]} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>

0 commit comments

Comments
 (0)