Skip to content

Commit dde7b80

Browse files
[mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes (llvm#69192)
In TOSA MLIR dialect, fix the definition of the Clamp op to accept fp16 & bf16 datatype for the min_fp and max_fp attributes. Add ClampOp verifier to check attributes types compatibility. Add related test cases in Tosa/ops.mlir. Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 2e4161d commit dde7b80

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,16 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
380380
Tosa_Tensor:$input,
381381
I64Attr:$min_int,
382382
I64Attr:$max_int,
383-
F32Attr:$min_fp,
384-
F32Attr:$max_fp
383+
Tosa_FloatAttr:$min_fp,
384+
Tosa_FloatAttr:$max_fp
385385
);
386386

387387
let results = (outs
388388
Tosa_Tensor:$output
389389
);
390390

391391
let hasCanonicalizer = 1;
392+
let hasVerifier = 1;
392393
}
393394

394395
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
197197
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
198198
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
199199

200+
def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
201+
"arbitrary float attribute"> {
202+
let storageType = [{ ::mlir::FloatAttr }];
203+
let returnType = [{ ::mlir::APFloat }];
204+
}
205+
200206
//===----------------------------------------------------------------------===//
201207
// Iterable attributes.
202208
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,32 @@ LogicalResult tosa::AvgPool2dOp::verify() {
309309
return emitOpError("input/output element types are incompatible.");
310310
}
311311

312+
LogicalResult tosa::ClampOp::verify() {
313+
mlir::Type inputETy =
314+
llvm::cast<ShapedType>(getInput().getType()).getElementType();
315+
mlir::Type maxFpType = getMaxFpAttr().getType();
316+
mlir::Type minFpType = getMinFpAttr().getType();
317+
mlir::Type outputETy =
318+
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
319+
unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
320+
321+
if (inputETy != outputETy)
322+
return emitOpError("input/output element types are incompatible.");
323+
324+
// if input datatype is float, check that the two min/max_fp attributes share
325+
// the same type and that their type is either the same of the input's
326+
// datatype, or a float type whose bitwidth > input datatype bitwidth
327+
if (!inputETy.isInteger(dataTypeBitWidth)) {
328+
if (((maxFpType != minFpType) ||
329+
(maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
330+
inputETy.getIntOrFloatBitWidth())))
331+
return emitOpError("min/max attributes types are incompatible with "
332+
"input/output element types.");
333+
}
334+
335+
return success();
336+
}
337+
312338
//===----------------------------------------------------------------------===//
313339
// TOSA Operator Quantization Builders.
314340
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
138138
return %0 : tensor<13x21x3xf32>
139139
}
140140

141+
// -----
142+
// CHECK-LABEL: clamp_f16
143+
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
144+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
145+
return %0 : tensor<13x21x3xf16>
146+
}
147+
148+
// -----
149+
// CHECK-LABEL: clamp_bf16
150+
func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
151+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
152+
return %0 : tensor<13x21x3xbf16>
153+
}
154+
141155
// -----
142156
// CHECK-LABEL: sigmoid
143157
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)