@@ -474,13 +474,33 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
474
474
auto resultTy = getType ().dyn_cast <RankedTensorType>();
475
475
if (!lhsTy || !rhsTy || !resultTy)
476
476
return {};
477
- if (lhsTy != rhsTy)
478
- return {};
479
-
477
+
480
478
auto resultETy = resultTy.getElementType ();
481
479
auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
482
480
auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
483
481
482
+ if (lhsTy == resultTy) {
483
+ if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
484
+ if (rhsAttr.getSplatValue <APFloat>().isZero ())
485
+ return getInput1 ();
486
+ }
487
+ }
488
+
489
+ if (lhsTy != rhsTy) {
490
+ if (lhsAttr && rhsAttr) {
491
+ if (lhsTy == resultTy && rhsAttr.isSplat ()) {
492
+ APFloat r = rhsAttr.getSplatValue <APFloat>();
493
+ std::vector<APFloat> v;
494
+ v.resize (lhsAttr.size (), APFloat (0.0 ));
495
+ for (int i=0 ;i<lhsAttr.size (); ++i) {
496
+ v[i] = lhsAttr.getValues <APFloat>()[i] + r;
497
+ }
498
+ return DenseElementsAttr::get (resultTy, v);
499
+ }
500
+ }
501
+ }
502
+
503
+
484
504
if (lhsAttr && lhsAttr.isSplat () && resultETy.isa <FloatType>()) {
485
505
if (lhsAttr.getSplatValue <APFloat>().isZero ())
486
506
return getInput2 ();
@@ -883,7 +903,62 @@ OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
883
903
return input;
884
904
}
885
905
886
- OpFoldResult ReverseOp::fold (ArrayRef<Attribute> operands) {
906
+ OpFoldResult RsqrtOp::fold (FoldAdaptor adaptor) {
907
+ auto inputTy = getInput1 ().getType ().dyn_cast <RankedTensorType>();
908
+ auto outputTy = getType ().dyn_cast <RankedTensorType>();
909
+
910
+ if (!inputTy || !outputTy)
911
+ return {};
912
+
913
+ if (inputTy != outputTy)
914
+ return {};
915
+
916
+ auto operand = adaptor.getInput1 ().dyn_cast_or_null <ElementsAttr>();
917
+ if (!operand)
918
+ return {};
919
+
920
+ if (!inputTy.getElementType ().isF32 ())
921
+ return {};
922
+
923
+ std::vector<float > v;
924
+ v.resize (operand.size ());
925
+ for (int i=0 ;i<operand.size (); ++i) {
926
+ v[i] = 1.0 / sqrt (operand.getValues <float >()[i]);
927
+ }
928
+ return DenseElementsAttr::get (outputTy, ArrayRef<float >{v});
929
+ }
930
+
931
+ OpFoldResult PowOp::fold (FoldAdaptor adaptor) {
932
+ auto inputTy = getInput1 ().getType ().dyn_cast <RankedTensorType>();
933
+ auto input2Ty = getInput2 ().getType ().dyn_cast <RankedTensorType>();
934
+ auto outputTy = getType ().dyn_cast <RankedTensorType>();
935
+
936
+ if (!inputTy || !input2Ty || !outputTy)
937
+ return {};
938
+
939
+ if (inputTy != outputTy || input2Ty != outputTy)
940
+ return {};
941
+
942
+ auto operand1 = adaptor.getInput1 ().dyn_cast_or_null <ElementsAttr>();
943
+ if (!operand1)
944
+ return {};
945
+
946
+ auto operand2 = adaptor.getInput2 ().dyn_cast_or_null <ElementsAttr>();
947
+ if (!operand2)
948
+ return {};
949
+
950
+ if (!operand1.getElementType ().isF32 () || !operand2.getElementType ().isF32 ())
951
+ return {};
952
+
953
+ std::vector<float > v;
954
+ v.resize (operand1.size ());
955
+ for (int i=0 ;i<operand1.size (); ++i) {
956
+ v[i] = pow (operand1.getValues <float >()[i], operand2.getValues <float >()[i]);
957
+ }
958
+ return DenseElementsAttr::get (outputTy, ArrayRef<float >{v});
959
+ }
960
+
961
+ OpFoldResult ReverseOp::fold (FoldAdaptor adaptor) {
887
962
auto operand = getInput ();
888
963
auto operandTy = operand.getType ().cast <ShapedType>();
889
964
auto axis = getAxis ();
@@ -898,7 +973,21 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
898
973
return {};
899
974
}
900
975
901
- OpFoldResult SliceOp::fold (ArrayRef<Attribute> operands) {
976
+ OpFoldResult ReciprocalOp::fold (FoldAdaptor adaptor) {
977
+ auto src = adaptor.getInput1 ().dyn_cast_or_null <mlir::DenseElementsAttr>();
978
+
979
+ if (!src)
980
+ return nullptr ;
981
+
982
+ std::vector<float > v;
983
+ v.resize (src.getNumElements ());
984
+ for (int i=0 ; i< src.getNumElements (); ++i)
985
+ v[i] = 1.0 / src.getValues <float >()[i];
986
+
987
+ return mlir::DenseElementsAttr::get (src.getType (), ArrayRef (v));
988
+ }
989
+
990
+ OpFoldResult SliceOp::fold (FoldAdaptor adaptor) {
902
991
auto inputTy = getInput ().getType ().dyn_cast <RankedTensorType>();
903
992
auto outputTy = getType ().dyn_cast <RankedTensorType>();
904
993
0 commit comments