Skip to content

Commit 96f8d12

Browse files
[mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes
In TOSA MLIR dialect, fix the definition of the Clamp op to accept fp16 & bf16 datatype for the min_fp and max_fp attributes. Add ClampOp verifier to check attributes types compatibility. Add related test cases in Tosa/ops.mlir. Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 4d6fc88 commit 96f8d12

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,16 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
380380
Tosa_Tensor:$input,
381381
I64Attr:$min_int,
382382
I64Attr:$max_int,
383-
F32Attr:$min_fp,
384-
F32Attr:$max_fp
383+
Tosa_FloatAttr:$min_fp,
384+
Tosa_FloatAttr:$max_fp
385385
);
386386

387387
let results = (outs
388388
Tosa_Tensor:$output
389389
);
390390

391391
let hasCanonicalizer = 1;
392+
let hasVerifier = 1;
392393
}
393394

394395
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
197197
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
198198
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
199199

200+
def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
201+
"arbitrary float attribute"> {
202+
let storageType = [{ ::mlir::FloatAttr }];
203+
let returnType = [{ ::mlir::APFloat }];
204+
}
205+
200206
//===----------------------------------------------------------------------===//
201207
// Iterable attributes.
202208
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,32 @@ LogicalResult tosa::AvgPool2dOp::verify() {
266266
return emitOpError("input/output element types are incompatible.");
267267
}
268268

269+
LogicalResult tosa::ClampOp::verify() {
270+
mlir::Type inputETy =
271+
llvm::cast<ShapedType>(getInput().getType()).getElementType();
272+
mlir::Type maxFpType = getMaxFpAttr().getType();
273+
mlir::Type minFpType = getMinFpAttr().getType();
274+
mlir::Type outputETy =
275+
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
276+
unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
277+
278+
if (inputETy != outputETy)
279+
return emitOpError("input/output element types are incompatible.");
280+
281+
// if input datatype is float, check that the two min/max_fp attributes share
282+
// the same type and that their type is either the same of the input's
283+
// datatype, or a float type whose bitwidth > input datatype bitwidth
284+
if (!inputETy.isInteger(dataTypeBitWidth)) {
285+
if (((maxFpType != minFpType) ||
286+
(maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
287+
inputETy.getIntOrFloatBitWidth())))
288+
return emitOpError("min/max attributes types are incompatible with "
289+
"input/output element types.");
290+
}
291+
292+
return success();
293+
}
294+
269295
//===----------------------------------------------------------------------===//
270296
// TOSA Operator Quantization Builders.
271297
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
124124
return %0 : tensor<13x21x3xf32>
125125
}
126126

127+
// -----
128+
// CHECK-LABEL: clamp_f16
129+
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
130+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
131+
return %0 : tensor<13x21x3xf16>
132+
}
133+
134+
// -----
135+
// CHECK-LABEL: clamp_bf16
136+
func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
137+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
138+
return %0 : tensor<13x21x3xbf16>
139+
}
140+
127141
// -----
128142
// CHECK-LABEL: sigmoid
129143
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)