@@ -633,10 +633,243 @@ static DynamicType BiggerType(DynamicType type) {
633
633
return type;
634
634
}
635
635
636
+ // / Structure to register intrinsic argument checks that must be performed.
637
+ using ArgumentVerifierFunc = bool (*)(
638
+ const std::vector<Expr<SomeType>> &, FoldingContext &);
639
+ struct ArgumentVerifier {
640
+ using Key = std::string_view;
641
+ // Needed for implicit compare with keys.
642
+ constexpr operator Key () const { return key; }
643
+ Key key;
644
+ ArgumentVerifierFunc verifier;
645
+ };
646
+
647
+ static constexpr int lastArg{-1 };
648
+ static constexpr int firstArg{0 };
649
+
650
+ static const Expr<SomeType> &GetArg (
651
+ int position, const std::vector<Expr<SomeType>> &args) {
652
+ if (position == lastArg) {
653
+ CHECK (!args.empty ());
654
+ return args.back ();
655
+ }
656
+ CHECK (position >= 0 && static_cast <std::size_t >(position) < args.size ());
657
+ return args[position];
658
+ }
659
+
660
+ template <typename T>
661
+ static bool IsInRange (const Expr<T> &expr, int lb, int ub) {
662
+ if (auto scalar{GetScalarConstantValue<T>(expr)}) {
663
+ auto lbValue{Scalar<T>::FromInteger (value::Integer<8 >{lb}).value };
664
+ auto ubValue{Scalar<T>::FromInteger (value::Integer<8 >{ub}).value };
665
+ return Satisfies (RelationalOperator::LE, lbValue.Compare (*scalar)) &&
666
+ Satisfies (RelationalOperator::LE, scalar->Compare (ubValue));
667
+ }
668
+ return true ;
669
+ }
670
+
671
+ // / Verify that the argument in an intrinsic call belongs to [lb, ub] if is
672
+ // / real.
673
+ template <int lb, int ub>
674
+ static bool VerifyInRangeIfReal (
675
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
676
+ if (const auto *someReal{
677
+ std::get_if<Expr<SomeReal>>(&GetArg (firstArg, args).u )}) {
678
+ bool isInRange{
679
+ std::visit ([&](const auto &x) -> bool { return IsInRange (x, lb, ub); },
680
+ someReal->u )};
681
+ if (!isInRange) {
682
+ context.messages ().Say (
683
+ " argument is out of range [%d., %d.]" _warn_en_US, lb, ub);
684
+ }
685
+ return isInRange;
686
+ }
687
+ return true ;
688
+ }
689
+
690
+ template <int argPosition, const char *argName>
691
+ static bool VerifyStrictlyPositiveIfReal (
692
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
693
+ if (const auto *someReal =
694
+ std::get_if<Expr<SomeReal>>(&GetArg (argPosition, args).u )) {
695
+ const bool isStrictlyPositive{std::visit (
696
+ [&](const auto &x) -> bool {
697
+ using T = typename std::decay_t <decltype (x)>::Result;
698
+ auto scalar{GetScalarConstantValue<T>(x)};
699
+ return Satisfies (
700
+ RelationalOperator::LT, Scalar<T>{}.Compare (*scalar));
701
+ },
702
+ someReal->u )};
703
+ if (!isStrictlyPositive) {
704
+ context.messages ().Say (
705
+ " argument '%s' must be strictly positive" _warn_en_US, argName);
706
+ }
707
+ return isStrictlyPositive;
708
+ }
709
+ return true ;
710
+ }
711
+
712
+ // / Verify that an intrinsic call argument is not zero if it is real.
713
+ template <int argPosition, const char *argName>
714
+ static bool VerifyNotZeroIfReal (
715
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
716
+ if (const auto *someReal =
717
+ std::get_if<Expr<SomeReal>>(&GetArg (argPosition, args).u )) {
718
+ const bool isNotZero{std::visit (
719
+ [&](const auto &x) -> bool {
720
+ using T = typename std::decay_t <decltype (x)>::Result;
721
+ auto scalar{GetScalarConstantValue<T>(x)};
722
+ return !scalar || !scalar->IsZero ();
723
+ },
724
+ someReal->u )};
725
+ if (!isNotZero) {
726
+ context.messages ().Say (
727
+ " argument '%s' must be different from zero" _warn_en_US, argName);
728
+ }
729
+ return isNotZero;
730
+ }
731
+ return true ;
732
+ }
733
+
734
+ // / Verify that the argument in an intrinsic call is not zero if is complex.
735
+ static bool VerifyNotZeroIfComplex (
736
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
737
+ if (const auto *someComplex =
738
+ std::get_if<Expr<SomeComplex>>(&GetArg (firstArg, args).u )) {
739
+ const bool isNotZero{std::visit (
740
+ [&](const auto &z) -> bool {
741
+ using T = typename std::decay_t <decltype (z)>::Result;
742
+ auto scalar{GetScalarConstantValue<T>(z)};
743
+ return !scalar || !scalar->IsZero ();
744
+ },
745
+ someComplex->u )};
746
+ if (!isNotZero) {
747
+ context.messages ().Say (
748
+ " complex argument must be different from zero" _warn_en_US);
749
+ }
750
+ return isNotZero;
751
+ }
752
+ return true ;
753
+ }
754
+
755
+ // Verify that the argument in an intrinsic call is not zero and not a negative
756
+ // integer.
757
+ static bool VerifyGammaLikeArgument (
758
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
759
+ if (const auto *someReal =
760
+ std::get_if<Expr<SomeReal>>(&GetArg (firstArg, args).u )) {
761
+ const bool isValid{std::visit (
762
+ [&](const auto &x) -> bool {
763
+ using T = typename std::decay_t <decltype (x)>::Result;
764
+ auto scalar{GetScalarConstantValue<T>(x)};
765
+ if (scalar) {
766
+ return !scalar->IsZero () &&
767
+ !(scalar->IsNegative () &&
768
+ scalar->ToWholeNumber ().value == scalar);
769
+ }
770
+ return true ;
771
+ },
772
+ someReal->u )};
773
+ if (!isValid) {
774
+ context.messages ().Say (
775
+ " argument must not be a negative integer or zero" _warn_en_US);
776
+ }
777
+ return isValid;
778
+ }
779
+ return true ;
780
+ }
781
+
782
+ // Verify that two real arguments are not both zero.
783
+ static bool VerifyAtan2LikeArguments (
784
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
785
+ if (const auto *someReal =
786
+ std::get_if<Expr<SomeReal>>(&GetArg (firstArg, args).u )) {
787
+ const bool isValid{std::visit (
788
+ [&](const auto &typedExpr) -> bool {
789
+ using T = typename std::decay_t <decltype (typedExpr)>::Result;
790
+ auto x{GetScalarConstantValue<T>(typedExpr)};
791
+ auto y{GetScalarConstantValue<T>(GetArg (lastArg, args))};
792
+ if (x && y) {
793
+ return !(x->IsZero () && y->IsZero ());
794
+ }
795
+ return true ;
796
+ },
797
+ someReal->u )};
798
+ if (!isValid) {
799
+ context.messages ().Say (
800
+ " 'x' and 'y' arguments must not be both zero" _warn_en_US);
801
+ }
802
+ return isValid;
803
+ }
804
+ return true ;
805
+ }
806
+
807
+ template <ArgumentVerifierFunc... F>
808
+ static bool CombineVerifiers (
809
+ const std::vector<Expr<SomeType>> &args, FoldingContext &context) {
810
+ return (... & F (args, context));
811
+ }
812
+
813
+ // / Define argument names to be used error messages when the intrinsic have
814
+ // / several arguments.
815
+ static constexpr char xName[]{" x" };
816
+ static constexpr char pName[]{" p" };
817
+
818
+ // / Register argument verifiers for all intrinsics folded with runtime.
819
+ static constexpr ArgumentVerifier intrinsicArgumentVerifiers[]{
820
+ {" acos" , VerifyInRangeIfReal<-1 , 1 >},
821
+ {" asin" , VerifyInRangeIfReal<-1 , 1 >},
822
+ {" atan2" , VerifyAtan2LikeArguments},
823
+ {" bessel_y0" , VerifyStrictlyPositiveIfReal<firstArg, xName>},
824
+ {" bessel_y1" , VerifyStrictlyPositiveIfReal<firstArg, xName>},
825
+ {" bessel_yn" , VerifyStrictlyPositiveIfReal<lastArg, xName>},
826
+ {" gamma" , VerifyGammaLikeArgument},
827
+ {" log" ,
828
+ CombineVerifiers<VerifyStrictlyPositiveIfReal<firstArg, xName>,
829
+ VerifyNotZeroIfComplex>},
830
+ {" log10" , VerifyStrictlyPositiveIfReal<firstArg, xName>},
831
+ {" log_gamma" , VerifyGammaLikeArgument},
832
+ {" mod" , VerifyNotZeroIfReal<lastArg, pName>},
833
+ };
834
+
835
+ const ArgumentVerifierFunc *findVerifier (const std::string &intrinsicName) {
836
+ static constexpr Fortran::common::StaticMultimapView<ArgumentVerifier>
837
+ verifiers (intrinsicArgumentVerifiers);
838
+ static_assert (verifiers.Verify (), " map must be sorted" );
839
+ auto range{verifiers.equal_range (intrinsicName)};
840
+ if (range.first != range.second ) {
841
+ return &range.first ->verifier ;
842
+ }
843
+ return nullptr ;
844
+ }
845
+
846
+ // / Ensure argument verifiers, if any, are run before calling the runtime
847
+ // / wrapper to fold an intrinsic.
848
+ static HostRuntimeWrapper AddArgumentVerifierIfAny (
849
+ const std::string &intrinsicName, const HostRuntimeFunction &hostFunction) {
850
+ if (const auto *verifier{findVerifier (intrinsicName)}) {
851
+ const HostRuntimeFunction *hostFunctionPtr = &hostFunction;
852
+ return [hostFunctionPtr, verifier](
853
+ FoldingContext &context, std::vector<Expr<SomeType>> &&args) {
854
+ const bool validArguments{(*verifier)(args, context)};
855
+ if (!validArguments) {
856
+ // Silence fp signal warnings since a more detailed warning about
857
+ // invalid arguments was already emitted.
858
+ parser::Messages localBuffer;
859
+ parser::ContextualMessages localMessages{&localBuffer};
860
+ FoldingContext localContext{context, localMessages};
861
+ return hostFunctionPtr->folder (localContext, std::move (args));
862
+ }
863
+ return hostFunctionPtr->folder (context, std::move (args));
864
+ };
865
+ }
866
+ return hostFunction.folder ;
867
+ }
868
+
636
869
std::optional<HostRuntimeWrapper> GetHostRuntimeWrapper (const std::string &name,
637
870
DynamicType resultType, const std::vector<DynamicType> &argTypes) {
638
871
if (const auto *hostFunction{SearchHostRuntime (name, resultType, argTypes)}) {
639
- return hostFunction-> folder ;
872
+ return AddArgumentVerifierIfAny (name, * hostFunction) ;
640
873
}
641
874
// If no exact match, search with "bigger" types and insert type
642
875
// conversions around the folder.
@@ -647,7 +880,8 @@ std::optional<HostRuntimeWrapper> GetHostRuntimeWrapper(const std::string &name,
647
880
}
648
881
if (const auto *hostFunction{
649
882
SearchHostRuntime (name, biggerResultType, biggerArgTypes)}) {
650
- return [hostFunction, resultType](
883
+ auto hostFolderWithChecks{AddArgumentVerifierIfAny (name, *hostFunction)};
884
+ return [hostFunction, resultType, hostFolderWithChecks](
651
885
FoldingContext &context, std::vector<Expr<SomeType>> &&args) {
652
886
auto nArgs{args.size ()};
653
887
for (size_t i{0 }; i < nArgs; ++i) {
@@ -657,7 +891,7 @@ std::optional<HostRuntimeWrapper> GetHostRuntimeWrapper(const std::string &name,
657
891
}
658
892
return Fold (context,
659
893
ConvertToType (
660
- resultType, hostFunction-> folder (context, std::move (args)))
894
+ resultType, hostFolderWithChecks (context, std::move (args)))
661
895
.value ());
662
896
};
663
897
}
0 commit comments