Skip to content

Commit a4803d8

Browse files
committed
[mlir][Tosa] Fix Clamp verifier to handle quantized types.
1 parent 7c7896b commit a4803d8

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
312312
LogicalResult tosa::ClampOp::verify() {
313313
mlir::Type inputETy =
314314
llvm::cast<ShapedType>(getInput().getType()).getElementType();
315+
if (auto quantType =
316+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
317+
inputETy = quantType.getStorageType();
318+
}
315319
mlir::Type maxFpType = getMaxFpAttr().getType();
316320
mlir::Type minFpType = getMinFpAttr().getType();
317321
mlir::Type outputETy =
318322
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
323+
if (auto quantType =
324+
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
325+
outputETy = quantType.getStorageType();
326+
}
319327
unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
320328

321329
if (inputETy != outputETy)

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
152152
return %0 : tensor<13x21x3xbf16>
153153
}
154154

155+
// -----
156+
// CHECK-LABEL: clamp_quantized
157+
func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>> {
158+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
159+
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
160+
}
161+
155162
// -----
156163
// CHECK-LABEL: sigmoid
157164
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)