Skip to content

Commit d601940

Browse files
ColtAllentwiecki
authored andcommitted
Add Hyp2F1, poch, and factorial
1 parent f3ad76b commit d601940

File tree

5 files changed

+224
-7
lines changed

5 files changed

+224
-7
lines changed

pytensor/scalar/math.py.rej

Lines changed: 155 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,157 @@
11
diff a/pytensor/scalar/math.py b/pytensor/scalar/math.py (rejected hunks)
2-
@@ -7,7 +7,6 @@
3-
import os
4-
import warnings
2+
@@ -1493,3 +1493,155 @@ def c_code(self, *args, **kwargs):
53

6-
-import mpmath as mp
7-
import numpy as np
8-
import scipy.special
9-
import scipy.stats
4+
5+
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
6+
+
7+
+
8+
+class Hyp2F1(ScalarOp):
9+
+ """
10+
+ Gaussian hypergeometric function ``2F1(a, b; c; z)``.
11+
+
12+
+ """
13+
+
14+
+ nin = 4
15+
+ nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
16+
+
17+
+ @staticmethod
18+
+ def st_impl(a, b, c, z):
19+
+
20+
+ if abs(z) >= 1:
21+
+ raise NotImplementedError("hyp2f1 only supported for z < 1.")
22+
+ else:
23+
+ return scipy.special.hyp2f1(a, b, c, z)
24+
+
25+
+ def impl(self, a, b, c, z):
26+
+ return Hyp2F1.st_impl(a, b, c, z)
27+
+
28+
+ def grad(self, inputs, grads):
29+
+ a, b, c, z = inputs
30+
+ (gz,) = grads
31+
+ return [
32+
+ gz * hyp2f1_der(a, b, c, z, 0),
33+
+ gz * hyp2f1_der(a, b, c, z, 1),
34+
+ gz * hyp2f1_der(a, b, c, z, 2),
35+
+ gz * hyp2f1_der(a, b, c, z, 3),
36+
+ ]
37+
+
38+
+ def c_code(self, *args, **kwargs):
39+
+ raise NotImplementedError()
40+
+
41+
+
42+
+hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
43+
+
44+
+
45+
+class Hyp2F1Der(ScalarOp):
46+
+ """
47+
+ Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
48+
+
49+
+ """
50+
+
51+
+ nin = 5
52+
+
53+
+ def impl(self, a, b, c, z, wrt):
54+
+ def _infinisum(f):
55+
+ """
56+
+ Utility function for infinite summations.
57+
+ """
58+
+
59+
+ n, res = 0, f(0)
60+
+ while True:
61+
+ term = f(n + 1)
62+
+ if RuntimeWarning:
63+
+ break
64+
+ if (res + term) - res == 0:
65+
+ break
66+
+ n, res = n + 1, res + term
67+
+ return res
68+
+
69+
+ def _hyp2f1_da(a, b, c, z):
70+
+ """
71+
+ Derivative of hyp2f1 wrt a
72+
+
73+
+ """
74+
+
75+
+ if abs(z) >= 1:
76+
+ raise NotImplementedError("Gradient not supported for |z| >= 1")
77+
+
78+
+ else:
79+
+ term1 = _infinisum(
80+
+ lambda k: (
81+
+ (gamma(a + k) / gamma(a))
82+
+ * (gamma(b + k) / gamma(b))
83+
+ * psi(a + k)
84+
+ * (z**k)
85+
+ )
86+
+ / (gamma(c + k) / gamma(c))
87+
+ * gamma(k + 1)
88+
+ )
89+
+ term2 = psi(a) * hyp2f1(a, b, c, z)
90+
+
91+
+ return term1 - term2
92+
+
93+
+ def _hyp2f1_db(a, b, c, z):
94+
+ """
95+
+ Derivative of hyp2f1 wrt b
96+
+ """
97+
+
98+
+ if abs(z) >= 1:
99+
+ raise NotImplementedError("Gradient not supported for |z| >= 1")
100+
+
101+
+ else:
102+
+ term1 = _infinisum(
103+
+ lambda k: (
104+
+ (gamma(a + k) / gamma(a))
105+
+ * (gamma(b + k) / gamma(b))
106+
+ * psi(b + k)
107+
+ * (z**k)
108+
+ )
109+
+ / (gamma(c + k) / gamma(c))
110+
+ * gamma(k + 1)
111+
+ )
112+
+ term2 = psi(b) * hyp2f1(a, b, c, z)
113+
+
114+
+ return term1 - term2
115+
+
116+
+ def _hyp2f1_dc(a, b, c, z):
117+
+ """
118+
+ Derivative of hyp2f1 wrt c
119+
+ """
120+
+ if abs(z) >= 1:
121+
+ raise NotImplementedError("Gradient not supported for |z| >= 1")
122+
+
123+
+ else:
124+
+ term1 = psi(c) * hyp2f1(a, b, c, z)
125+
+ term2 = _infinisum(
126+
+ lambda k: (
127+
+ (gamma(a + k) / gamma(a))
128+
+ * (gamma(b + k) / gamma(b))
129+
+ * psi(c + k)
130+
+ * (z**k)
131+
+ )
132+
+ / (gamma(c + k) / gamma(c))
133+
+ * gamma(k + 1)
134+
+ )
135+
+ return term1 - term2
136+
+
137+
+ def _hyp2f1_dz(a, b, c, z):
138+
+ """
139+
+ Derivative of hyp2f1 wrt z
140+
+ """
141+
+
142+
+ return ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
143+
+
144+
+ if wrt == 0:
145+
+ return _hyp2f1_da(a, b, c, z)
146+
+ elif wrt == 1:
147+
+ return _hyp2f1_db(a, b, c, z)
148+
+ elif wrt == 2:
149+
+ return _hyp2f1_dc(a, b, c, z)
150+
+ elif wrt == 3:
151+
+ return _hyp2f1_dz(a, b, c, z)
152+
+
153+
+ def c_code(self, *args, **kwargs):
154+
+ raise NotImplementedError()
155+
+
156+
+
157+
+hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

pytensor/tensor/inplace.py.rej

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
diff a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py (rejected hunks)
2+
@@ -392,6 +392,11 @@ def conj_inplace(a):
3+
"""elementwise conjugate (inplace on `a`)"""
4+
5+
6+
+@scalar_elemwise
7+
+def hyp2f1_inplace(a, b, c, z):
8+
+ """gaussian hypergeometric function"""
9+
+
10+
+
11+
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
12+
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
13+
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py.rej

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
diff a/pytensor/tensor/math.py b/pytensor/tensor/math.py (rejected hunks)
2+
@@ -1386,6 +1386,11 @@ def gammal(k, x):
3+
"""Lower incomplete gamma function."""
4+
5+
6+
+@scalar_elemwise
7+
+def hyp2f1(a, b, c, z):
8+
+ """Gaussian hypergeometric function."""
9+
+
10+
+
11+
@scalar_elemwise
12+
def j0(x):
13+
"""Bessel function of the first kind of order 0."""
14+
@@ -3134,6 +3139,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
15+
"power",
16+
"logaddexp",
17+
"logsumexp",
18+
+ "hyp2f1",
19+
]
20+
21+
DEPRECATED_NAMES = [

tests/tensor/test_math_scipy.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,32 @@ def test_deprecated_module():
780780
inplace=True,
781781
)
782782

783+
_good_broadcast_quaternary_hyp2f1 = dict(
784+
normal=(
785+
random_ranged(0, 1000, (2, 3)),
786+
random_ranged(0, 1000, (2, 3)),
787+
random_ranged(0, 1000, (2, 3)),
788+
random_ranged(0, 0.5, (2, 3)),
789+
),
790+
)
791+
792+
TestHyp2F1Broadcast = makeBroadcastTester(
793+
op=at.hyp2f1,
794+
expected=expected_hyp2f1,
795+
good=_good_broadcast_quaternary_hyp2f1,
796+
eps=2e-10,
797+
mode=mode_no_scipy,
798+
)
799+
800+
TestHyp2F1InplaceBroadcast = makeBroadcastTester(
801+
op=inplace.hyp2f1_inplace,
802+
expected=expected_hyp2f1,
803+
good=_good_broadcast_quaternary_hyp2f1,
804+
eps=2e-10,
805+
mode=mode_no_scipy,
806+
inplace=True,
807+
)
808+
783809

784810
class TestBetaIncGrad:
785811
def test_stan_grad_partial(self):

tests/tensor/test_math_scipy.py.rej

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
diff a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py (rejected hunks)
2+
@@ -71,6 +71,7 @@ def scipy_special_gammal(k, x):
3+
expected_iv = scipy.special.iv
4+
expected_erfcx = scipy.special.erfcx
5+
expected_sigmoid = scipy.special.expit
6+
+expected_hyp2f1 = scipy.special.hyp2f1
7+
8+
TestErfBroadcast = makeBroadcastTester(
9+
op=at.erf,

0 commit comments

Comments
 (0)