Skip to content

Commit 37665ac

Browse files
authored
Merge pull request #32 from Xilinx/christopher.FXML-1991_reciprocal_constant_folding
[FXML-1991] Reciprocal Constant Folding For a Splat Tensor
2 parents 2ec9d50 + 17885ab commit 37665ac

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,8 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
10571057
let results = (outs
10581058
Tosa_Tensor:$output
10591059
);
1060+
1061+
let hasFolder = 1;
10601062
}
10611063

10621064
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,26 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
630630
return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
631631
}
632632

633+
/// Constant-folds `tosa.reciprocal` when its operand is a constant
/// floating-point splat.
///
/// \param operands Constant attributes for the op's operands; `operands[0]`
///                 is null when the input is not a known constant.
/// \returns A splat `DenseElementsAttr` of the input tensor type holding
///          `1 / splat`, or an empty `OpFoldResult` when the op cannot be
///          folded (non-constant input, unranked input type, non-float
///          element type, or a non-splat constant).
OpFoldResult ReciprocalOp::fold(ArrayRef<Attribute> operands) {
  auto constantAttr = dyn_cast_or_null<DenseElementsAttr>(operands[0]);
  auto lhsTy = dyn_cast<RankedTensorType>(getInput1().getType());

  if (!lhsTy || !constantAttr) {
    return {};
  }

  // Both getSplatValue<APFloat>() and FloatAttr::get() below require a
  // floating-point element type; bail out for integer/quantized tensors
  // instead of tripping an assertion inside the folder.
  if (!isa<FloatType>(lhsTy.getElementType())) {
    return {};
  }

  if (!constantAttr.isSplat()) {
    return {};
  }

  auto floatVal = constantAttr.getSplatValue<llvm::APFloat>();

  // Build the constant 1.0 with the element type's float semantics so the
  // division is carried out in the tensor's own precision.
  auto recipAttr = FloatAttr::get(lhsTy.getElementType(), 1.0);
  APFloat recip = recipAttr.getValue();
  // The returned status is intentionally not inspected: the quotient left in
  // `recip` (e.g. +inf for 1/0) is used as the folded value either way.
  recip.divide(floatVal, APFloat::rmNearestTiesToEven);
  return DenseElementsAttr::get(lhsTy, recip);
}
652+
633653
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
634654
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
635655
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,29 @@ func.func @fold_mul_splat_f32() -> tensor<10xf32> {
328328

329329
// -----
330330

331+
// CHECK-LABEL: @fold_reciprocal_splat_f32
332+
// The reciprocal of a constant 0.5 splat is expected to fold to a constant 2.0.
func.func @fold_reciprocal_splat_f32() -> tensor<f32> {
  %input = "tosa.const"() {value = dense<0.5> : tensor<f32>} : () -> tensor<f32>
  %result = "tosa.reciprocal"(%input) : (tensor<f32>) -> tensor<f32>
  // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>}
  // CHECK: return %[[CST]]
  return %result : tensor<f32>
}
339+
340+
// -----
341+
342+
// CHECK-LABEL: @fold_reciprocal_splat_zero_f32
343+
// The reciprocal of a constant 0.0 splat is expected to fold to +inf.
func.func @fold_reciprocal_splat_zero_f32() -> tensor<f32> {
  %input = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %result = "tosa.reciprocal"(%input) : (tensor<f32>) -> tensor<f32>
  // The folded value is 1/0; 0x7F800000 is the f32 bit pattern of +inf.
  // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>}
  // CHECK: return %[[CST]]
  return %result : tensor<f32>
}
351+
352+
// -----
353+
331354
// CHECK-LABEL: @fold_sub_zero_rhs_f32
332355
func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
333356
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>

0 commit comments

Comments
 (0)