Skip to content

Commit a200bdb

Browse files
lhutton1svkeerthy
authored andcommitted
[mlir][tosa] Allow unranked input/output tensors in resize ops (#141608)
This commit allows the input/output of the resize op to be unranked to account for shapes being computed during shape inference.
1 parent 346aa9b commit a200bdb

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,16 +2496,6 @@ LogicalResult tosa::ResizeOp::verify() {
24962496
const RankedTensorType outputType =
24972497
llvm::dyn_cast<RankedTensorType>(output.getType());
24982498

2499-
if (!inputType)
2500-
return emitOpError("expect a ranked input tensor");
2501-
if (!outputType)
2502-
return emitOpError("expect a ranked output tensor");
2503-
2504-
const int64_t oh = outputType.getDimSize(1);
2505-
const int64_t ow = outputType.getDimSize(2);
2506-
const int64_t ih = inputType.getDimSize(1);
2507-
const int64_t iw = inputType.getDimSize(2);
2508-
25092499
SmallVector<int64_t> scaleValues;
25102500
SmallVector<int64_t> offsetValues;
25112501
SmallVector<int64_t> borderValues;
@@ -2531,6 +2521,16 @@ LogicalResult tosa::ResizeOp::verify() {
25312521
const int64_t borderY = borderValues[0];
25322522
const int64_t borderX = borderValues[1];
25332523

2524+
if (!inputType)
2525+
return success();
2526+
if (!outputType)
2527+
return success();
2528+
2529+
const int64_t oh = outputType.getDimSize(1);
2530+
const int64_t ow = outputType.getDimSize(2);
2531+
const int64_t ih = inputType.getDimSize(1);
2532+
const int64_t iw = inputType.getDimSize(2);
2533+
25342534
// Don't check with input height that could be broadcast (ih != 1)
25352535
// since Linalg, a consumer of TOSA, expects broadcasting support
25362536
// in resize to be available. Taking the cautious approach for now,

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,26 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
743743
return %1 : tensor<1x64x64x8xf32>
744744
}
745745

746+
// -----
747+
// CHECK-LABEL: resize_unranked_output
748+
func.func @test_resize_unranked_output(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> {
749+
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
750+
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
751+
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
752+
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
753+
return %1 : tensor<*xf32>
754+
}
755+
756+
// -----
757+
// CHECK-LABEL: resize_unranked_input
758+
func.func @test_resize_unranked_input(%arg0: tensor<*xf32>) -> tensor<1x64x64x8xf32> {
759+
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
760+
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
761+
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
762+
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
763+
return %1 : tensor<1x64x64x8xf32>
764+
}
765+
746766
// -----
747767
// CHECK-LABEL: cast
748768
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)