Skip to content

Commit 56d2cb8

Browse files
ColtAllentwiecki
authored andcommitted
Refactored hyp2f1_der in terms of gamma and ran black formatting
1 parent 23c708a commit 56d2cb8

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

pytensor/scalar/math.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
upgrade_to_float64,
3535
upgrade_to_float_no_complex,
3636
ScalarType,
37-
ScalarVariable
37+
ScalarVariable,
3838
)
3939

4040

@@ -1526,6 +1526,9 @@ class Hyp2F1Der(ScalarOp):
15261526
"""
15271527
Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
15281528
1529+
Currently written in terms of gamma until poch and factorial Ops are ready:
1530+
poch(z, m) = (gamma(z + m) / gamma(m))
1531+
factorial(n) = gamma(n+1)
15291532
"""
15301533

15311534
nin = 5
@@ -1538,27 +1541,33 @@ def _infinisum(f):
15381541

15391542
n, res = 0, f(0)
15401543
while True:
1541-
term = f(n+1)
1544+
term = f(n + 1)
15421545
if RuntimeWarning:
15431546
break
1544-
if (res+term)-res == 0:
1547+
if (res + term) - res == 0:
15451548
break
1546-
n,res = n+1, res+term
1549+
n, res = n + 1, res + term
15471550
return res
15481551

15491552
def _hyp2f1_da(a, b, c, z):
15501553
"""
15511554
Derivative of hyp2f1 wrt a
1555+
15521556
"""
15531557

15541558
if abs(z) >= 1:
1555-
raise NotImplementedError('Gradient not supported for |z| >= 1')
1559+
raise NotImplementedError("Gradient not supported for |z| >= 1")
15561560

15571561
else:
1558-
15591562
term1 = _infinisum(
1560-
lambda k: (poch(a, k) * poch(b, k) * psi(a + k) * (z**k))
1561-
/ (poch(c, k) * factorial(k))
1563+
lambda k: (
1564+
(gamma(a + k) / gamma(a))
1565+
* (gamma(b + k) / gamma(b))
1566+
* psi(a + k)
1567+
* (z**k)
1568+
)
1569+
/ (gamma(c + k) / gamma(c))
1570+
* gamma(k + 1)
15621571
)
15631572
term2 = psi(a) * hyp2f1(a, b, c, z)
15641573

@@ -1570,12 +1579,18 @@ def _hyp2f1_db(a, b, c, z):
15701579
"""
15711580

15721581
if abs(z) >= 1:
1573-
raise NotImplementedError('Gradient not supported for |z| >= 1')
1582+
raise NotImplementedError("Gradient not supported for |z| >= 1")
15741583

15751584
else:
15761585
term1 = _infinisum(
1577-
lambda k: (poch(a, k) * poch(b, k) * psi(b + k) * (z**k))
1578-
/ (poch(c, k) * factorial(k))
1586+
lambda k: (
1587+
(gamma(a + k) / gamma(a))
1588+
* (gamma(b + k) / gamma(b))
1589+
* psi(b + k)
1590+
* (z**k)
1591+
)
1592+
/ (gamma(c + k) / gamma(c))
1593+
* gamma(k + 1)
15791594
)
15801595
term2 = psi(b) * hyp2f1(a, b, c, z)
15811596

@@ -1586,13 +1601,19 @@ def _hyp2f1_dc(a, b, c, z):
15861601
Derivative of hyp2f1 wrt c
15871602
"""
15881603
if abs(z) >= 1:
1589-
raise NotImplementedError('Gradient not supported for |z| >= 1')
1604+
raise NotImplementedError("Gradient not supported for |z| >= 1")
15901605

15911606
else:
15921607
term1 = psi(c) * hyp2f1(a, b, c, z)
15931608
term2 = _infinisum(
1594-
lambda k: (poch(a, k) * poch(b, k) * psi(c + k) * (z**k))
1595-
/ (poch(c, k) * factorial(k))
1609+
lambda k: (
1610+
(gamma(a + k) / gamma(a))
1611+
* (gamma(b + k) / gamma(b))
1612+
* psi(c + k)
1613+
* (z**k)
1614+
)
1615+
/ (gamma(c + k) / gamma(c))
1616+
* gamma(k + 1)
15961617
)
15971618
return term1 - term2
15981619

@@ -1624,7 +1645,7 @@ def poch(z: ScalarType, m: ScalarType) -> ScalarVariable:
16241645
Pochhammer symbol (rising factorial) function.
16251646
16261647
"""
1627-
return gamma(z+m) / gamma(z)
1648+
return gamma(z + m) / gamma(z)
16281649

16291650

16301651
def factorial(n: ScalarType) -> ScalarVariable:

tests/tensor/test_math_scipy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def test_deprecated_module():
759759
random_ranged(0, 1000, (2, 3)),
760760
random_ranged(0, 1000, (2, 3)),
761761
random_ranged(0, 1000, (2, 3)),
762-
random_ranged(0, .5, (2, 3)),
762+
random_ranged(0, 0.5, (2, 3)),
763763
),
764764
)
765765

@@ -785,30 +785,30 @@ def test_deprecated_module():
785785
)
786786

787787

788-
789-
@pytest.mark.parametrize("z, m",[random_ranged(0, 5, (2, 1), rng=rng),random_ranged(0, 5, (2, 1), rng=rng)])
790-
def test_poch(z,m):
788+
@pytest.mark.parametrize(
789+
"z, m", [random_ranged(0, 5, (2, 1), rng=rng), random_ranged(0, 5, (2, 1), rng=rng)]
790+
)
791+
def test_poch(z, m):
791792

792793
z, m = at.scalars("z", "m")
793794

794795
poch = at.poch(z, m)
795796

796797
actual = function([z, m], poch)
797-
expected = scipy.special.poch(z, m)
798+
expected = scipy.special.poch(z, m)
798799

799800
assert actual == expected
800801

801802

802-
803-
@pytest.mark.parametrize("n",random_ranged(0, 5, (2, 1)))
803+
@pytest.mark.parametrize("n", random_ranged(0, 5, (2, 1)))
804804
def test_factorial(n):
805805

806806
n = at.scalars("n")
807807

808808
factorial = at.factorial(n)
809809

810810
actual = function([n], factorial)
811-
expected = scipy.special.factorial(n)
811+
expected = scipy.special.factorial(n)
812812

813813
assert actual == expected
814814

0 commit comments

Comments
 (0)