Skip to content

Commit 1be48fd

Browse files
authored
[mlir][TosaToLinalg] Fix TosaToLinalg to restrict tosa.cast types to integer or float (llvm#128859)
This PR fixes a bug where `TosaToLinalg` incorrectly allows `tosa.cast` to accept types other than integer or float. Fixes llvm#116342.
1 parent c690b30 commit 1be48fd

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
524524
if (isa<tosa::CastOp>(op)) {
525525
Type srcTy = elementTy;
526526
Type dstTy = resultTypes.front();
527+
if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
528+
(void)rewriter.notifyMatchFailure(op, "unsupported type");
529+
return nullptr;
530+
}
531+
527532
bool bitExtend =
528533
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
529534

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
5454
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
5555
return %0 : tensor<2x3x4xf32>
5656
}
57+
58+
// -----
59+
60+
func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
61+
// expected-error@+1 {{failed to legalize operation 'tosa.cast'}}
62+
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
63+
return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
64+
}

0 commit comments

Comments
 (0)