Skip to content

Commit 76dbb86

Browse files
committed
TOSA: more folds
1 parent b4da2b3 commit 76dbb86

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,8 @@ def Tosa_PowOp : Tosa_Op<"pow", [
764764
let results = (outs
765765
Tosa_Tensor:$z
766766
);
767+
768+
let hasFolder = 1;
767769
}
768770

769771
//===----------------------------------------------------------------------===//
@@ -1057,6 +1059,8 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
10571059
let results = (outs
10581060
Tosa_Tensor:$output
10591061
);
1062+
1063+
let hasFolder = 1;
10601064
}
10611065

10621066
//===----------------------------------------------------------------------===//
@@ -1080,6 +1084,8 @@ def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [
10801084
let results = (outs
10811085
Tosa_Tensor:$output
10821086
);
1087+
1088+
let hasFolder = 1;
10831089
}
10841090

10851091
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,13 +474,33 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
474474
auto resultTy = getType().dyn_cast<RankedTensorType>();
475475
if (!lhsTy || !rhsTy || !resultTy)
476476
return {};
477-
if (lhsTy != rhsTy)
478-
return {};
479-
477+
480478
auto resultETy = resultTy.getElementType();
481479
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
482480
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
483481

482+
if (lhsTy == resultTy) {
483+
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
484+
if (rhsAttr.getSplatValue<APFloat>().isZero())
485+
return getInput1();
486+
}
487+
}
488+
489+
if (lhsTy != rhsTy) {
490+
if (lhsAttr && rhsAttr) {
491+
if (lhsTy == resultTy && rhsAttr.isSplat()) {
492+
APFloat r = rhsAttr.getSplatValue<APFloat>();
493+
std::vector<APFloat> v;
494+
v.resize(lhsAttr.size(), APFloat(0.0));
495+
for(int i=0;i<lhsAttr.size(); ++i) {
496+
v[i] = lhsAttr.getValues<APFloat>()[i] + r;
497+
}
498+
return DenseElementsAttr::get(resultTy, v);
499+
}
500+
}
501+
}
502+
503+
484504
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
485505
if (lhsAttr.getSplatValue<APFloat>().isZero())
486506
return getInput2();
@@ -883,7 +903,62 @@ OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
883903
return input;
884904
}
885905

886-
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
906+
/// Constant-folds tosa.rsqrt (1 / sqrt(x), element-wise).
///
/// Folding is attempted only when all of the following hold:
///   * the input and the result are ranked tensors of the exact same type,
///   * the element type is f32,
///   * the operand is a constant DenseElementsAttr.
/// In every other case an empty OpFoldResult is returned and the op is left
/// untouched.
OpFoldResult RsqrtOp::fold(FoldAdaptor adaptor) {
  auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
  auto outputTy = getType().dyn_cast<RankedTensorType>();
  if (!inputTy || !outputTy || inputTy != outputTy)
    return {};

  if (!inputTy.getElementType().isF32())
    return {};

  // Only dense constants are folded: the element values are read through
  // getValues<float>(), which other ElementsAttr kinds are not guaranteed to
  // support.
  auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
  if (!operand || !operand.getElementType().isF32())
    return {};

  // Build the value range once and walk it, instead of re-creating the range
  // for every index.
  std::vector<float> folded;
  folded.reserve(operand.getNumElements());
  for (float value : operand.getValues<float>()) {
    // The intermediate math is done explicitly in double so the result does
    // not depend on which sqrt overload an unqualified call resolves to.
    folded.push_back(
        static_cast<float>(1.0 / sqrt(static_cast<double>(value))));
  }
  return DenseElementsAttr::get(outputTy, ArrayRef<float>(folded));
}
930+
931+
/// Constant-folds tosa.pow (input1 ^ input2, element-wise).
///
/// Folding is attempted only when all of the following hold:
///   * both inputs and the result are ranked tensors of the exact same type
///     (so no broadcasting is involved),
///   * both operands are constant DenseElementsAttr with f32 elements.
/// In every other case an empty OpFoldResult is returned and the op is left
/// untouched.
OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
  auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
  auto input2Ty = getInput2().getType().dyn_cast<RankedTensorType>();
  auto outputTy = getType().dyn_cast<RankedTensorType>();
  if (!inputTy || !input2Ty || !outputTy)
    return {};

  if (inputTy != outputTy || input2Ty != outputTy)
    return {};

  // Only dense constants are folded: the element values are read through
  // getValues<float>(), which other ElementsAttr kinds are not guaranteed to
  // support.
  auto base = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
  auto exponent = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
  if (!base || !exponent)
    return {};

  if (!base.getElementType().isF32() || !exponent.getElementType().isF32())
    return {};

  // Both operand types equal the result type (checked above), so the two
  // ranges are expected to have the same length. Build each range once and
  // advance them in lock-step.
  auto baseValues = base.getValues<float>();
  auto exponentValues = exponent.getValues<float>();
  auto exponentIt = exponentValues.begin();

  std::vector<float> folded;
  folded.reserve(base.getNumElements());
  for (float b : baseValues) {
    const float e = *exponentIt;
    ++exponentIt;
    // The intermediate math is done explicitly in double so the result does
    // not depend on which pow overload an unqualified call resolves to.
    folded.push_back(static_cast<float>(
        pow(static_cast<double>(b), static_cast<double>(e))));
  }
  return DenseElementsAttr::get(outputTy, ArrayRef<float>(folded));
}
960+
961+
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
887962
auto operand = getInput();
888963
auto operandTy = operand.getType().cast<ShapedType>();
889964
auto axis = getAxis();
@@ -898,7 +973,21 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
898973
return {};
899974
}
900975

901-
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
976+
/// Constant-folds tosa.reciprocal (1 / x, element-wise).
///
/// Folding is attempted only when the operand is a constant
/// DenseElementsAttr with f32 elements whose type is identical to the op's
/// result type. In every other case an empty OpFoldResult is returned and the
/// op is left untouched.
OpFoldResult ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto src = adaptor.getInput1().dyn_cast_or_null<mlir::DenseElementsAttr>();
  if (!src)
    return {};

  // The values below are read as float, so any other element type (f16,
  // bf16, ...) must be rejected before touching the data.
  if (!src.getElementType().isF32())
    return {};

  // The folded attribute is built with the operand's type; bail out when the
  // result type differs so the fold never yields a mismatched constant.
  if (getType() != src.getType())
    return {};

  // Build the value range once and walk it, instead of re-creating the range
  // for every index.
  std::vector<float> folded;
  folded.reserve(src.getNumElements());
  for (float value : src.getValues<float>())
    folded.push_back(static_cast<float>(1.0 / value));

  return mlir::DenseElementsAttr::get(src.getType(), ArrayRef<float>(folded));
}
989+
990+
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
902991
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
903992
auto outputTy = getType().dyn_cast<RankedTensorType>();
904993

0 commit comments

Comments
 (0)