Skip to content

Commit 9d191f1

Browse files
authored
[mlir][tosa] Fix RFFT2D verifier for width=1 (#130279)
Current formula assumes width is a multiple of 2 but TOSA only requires a power of 2, which 1 is.
1 parent 995c0f7 commit 9d191f1

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ LogicalResult tosa::RFFT2dOp::verify() {
888888
// Output width dimension expected to be input_width / 2 + 1
889889
const int64_t outputWidth = outputType.getDimSize(2);
890890
if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
891-
(outputWidth - 1) * 2 != width)
891+
(outputWidth != (width / 2) + 1))
892892
return emitOpError(
893893
"expected output width to be equal to input_width / 2 + 1, got ")
894894
<< outputWidth;

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tenso
175175
return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
176176
}
177177

178+
// -----
179+
// CHECK-LABEL: rfft2d_width1
180+
func.func @test_rfft2d_width1(%arg0: tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>) {
181+
%0, %1 = tosa.rfft2d %arg0 : (tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>)
182+
return %0, %1 : tensor<1x1x1xf32>, tensor<1x1x1xf32>
183+
}
184+
178185
// -----
179186
// CHECK-LABEL: rfft2d_with_local_bound
180187
func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {

0 commit comments

Comments
 (0)