Skip to content

Commit 23c708a

Browse files
ColtAllentwiecki
authored andcommitted
Refactored Poch and Factorial into helper functions for at.Gamma. Rewrote tests.
1 parent 0122903 commit 23c708a

File tree

3 files changed

+44
-101
lines changed

3 files changed

+44
-101
lines changed

pytensor/scalar/math.py

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
upgrade_to_float,
3434
upgrade_to_float64,
3535
upgrade_to_float_no_complex,
36+
ScalarType,
37+
ScalarVariable
3638
)
3739

3840

@@ -1494,7 +1496,11 @@ class Hyp2F1(ScalarOp):
14941496

14951497
@staticmethod
14961498
def st_impl(a, b, c, z):
1497-
return scipy.special.hyp2f1(a, b, c, z)
1499+
1500+
if abs(z) >= 1:
1501+
raise NotImplementedError("hyp2f1 only supported for z < 1.")
1502+
else:
1503+
return scipy.special.hyp2f1(a, b, c, z)
14981504

14991505
def impl(self, a, b, c, z):
15001506
return Hyp2F1.st_impl(a, b, c, z)
@@ -1551,10 +1557,10 @@ def _hyp2f1_da(a, b, c, z):
15511557
else:
15521558

15531559
term1 = _infinisum(
1554-
lambda k: (scipy.special.poch(a, k) * scipy.special.poch(b, k) * scipy.special.digamma(a + k) * (z**k))
1555-
/ (scipy.special.poch(c, k) * scipy.special.factorial(k))
1560+
lambda k: (poch(a, k) * poch(b, k) * psi(a + k) * (z**k))
1561+
/ (poch(c, k) * factorial(k))
15561562
)
1557-
term2 = scipy.special.digamma(a) * scipy.special.hyp2f1(a, b, c, z)
1563+
term2 = psi(a) * hyp2f1(a, b, c, z)
15581564

15591565
return term1 - term2
15601566

@@ -1568,10 +1574,10 @@ def _hyp2f1_db(a, b, c, z):
15681574

15691575
else:
15701576
term1 = _infinisum(
1571-
lambda k: (scipy.special.poch(a, k) * scipy.special.poch(b, k) * scipy.special.digamma(b + k) * (z**k))
1572-
/ (scipy.special.poch(c, k) * scipy.special.factorial(k))
1577+
lambda k: (poch(a, k) * poch(b, k) * psi(b + k) * (z**k))
1578+
/ (poch(c, k) * factorial(k))
15731579
)
1574-
term2 = scipy.special.digamma(b) * scipy.special.hyp2f1(a, b, c, z)
1580+
term2 = psi(b) * hyp2f1(a, b, c, z)
15751581

15761582
return term1 - term2
15771583

@@ -1583,10 +1589,10 @@ def _hyp2f1_dc(a, b, c, z):
15831589
raise NotImplementedError('Gradient not supported for |z| >= 1')
15841590

15851591
else:
1586-
term1 = scipy.special.digamma(c) * scipy.special.hyp2f1(a, b, c, z)
1592+
term1 = psi(c) * hyp2f1(a, b, c, z)
15871593
term2 = _infinisum(
1588-
lambda k: (scipy.special.poch(a, k) * scipy.special.poch(b, k) * scipy.special.digamma(c + k) * (z**k))
1589-
/ (scipy.special.poch(c, k) * scipy.special.factorial(k))
1594+
lambda k: (poch(a, k) * poch(b, k) * psi(c + k) * (z**k))
1595+
/ (poch(c, k) * factorial(k))
15901596
)
15911597
return term1 - term2
15921598

@@ -1595,7 +1601,7 @@ def _hyp2f1_dz(a, b, c, z):
15951601
Derivative of hyp2f1 wrt z
15961602
"""
15971603

1598-
return ((a * b) / c) * scipy.special.hyp2f1(a + 1, b + 1, c + 1, z)
1604+
return ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
15991605

16001606
if wrt == 0:
16011607
return _hyp2f1_da(a, b, c, z)
@@ -1613,58 +1619,17 @@ def c_code(self, *args, **kwargs):
16131619
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
16141620

16151621

1616-
class Poch(BinaryScalarOp):
1622+
def poch(z: ScalarType, m: ScalarType) -> ScalarVariable:
16171623
"""
16181624
Pochhammer symbol (rising factorial) function.
16191625
16201626
"""
1621-
1622-
nfunc_spec = ("scipy.special.poch", 2, 1)
1623-
1624-
@staticmethod
1625-
def st_impl(z, m):
1626-
return gamma(z+m) / gamma(z)
1627-
1628-
def impl(self, z, m):
1629-
return Poch.st_impl(z, m)
1630-
1631-
def grad(self, inputs, grads):
1632-
z, m = inputs
1633-
(gz,) = grads
1634-
return [
1635-
gz * poch(z, m) * (tri_gamma(z + m) - tri_gamma(z)),
1636-
gz * poch(z, m) * tri_gamma(z + m)
1637-
]
1638-
1639-
def c_code(self, *args, **kwargs):
1640-
raise NotImplementedError()
1627+
return gamma(z+m) / gamma(z)
16411628

16421629

1643-
poch = Poch(upgrade_to_float, name="poch")
1644-
1645-
1646-
class Factorial(UnaryScalarOp):
1630+
def factorial(n: ScalarType) -> ScalarVariable:
16471631
"""
16481632
Factorial function of a scalar or array of numbers.
16491633
16501634
"""
1651-
1652-
nfunc_spec = ("scipy.special.factorial", 1, 1)
1653-
1654-
@staticmethod
1655-
def st_impl(n):
1656-
return gamma(n+1)
1657-
1658-
def impl(self, n):
1659-
return Factorial.st_impl(n)
1660-
1661-
def grad(self, inputs, grads):
1662-
(n,) = inputs
1663-
(gz,) = grads
1664-
return [gz * gamma(n+1) * tri_gamma(n+1)]
1665-
1666-
def c_code(self, *args, **kwargs):
1667-
raise NotImplementedError()
1668-
1669-
1670-
factorial = Factorial(upgrade_to_float, name="factorial")
1635+
return gamma(n + 1)

pytensor/tensor/inplace.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -397,16 +397,6 @@ def hyp2f1_inplace(a, b, c, z):
397397
"""gaussian hypergeometric function"""
398398

399399

400-
@scalar_elemwise
401-
def poch_inplace(z, m):
402-
"""pochhammer symbol (rising factorial) function"""
403-
404-
405-
@scalar_elemwise
406-
def factorial_inplace(n):
407-
"""factorial function"""
408-
409-
410400
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
411401
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
412402
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

tests/tensor/test_math_scipy.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ def scipy_special_gammal(k, x):
7272
expected_erfcx = scipy.special.erfcx
7373
expected_sigmoid = scipy.special.expit
7474
expected_hyp2f1 = scipy.special.hyp2f1
75-
expected_poch = scipy.special.poch
76-
expected_factorial = scipy.special.factorial
7775

7876
TestErfBroadcast = makeBroadcastTester(
7977
op=at.erf,
@@ -786,43 +784,33 @@ def test_deprecated_module():
786784
)
787785
)
788786

789-
TestPochBroadcast = makeBroadcastTester(
790-
op=at.poch,
791-
expected=expected_poch,
792-
good=_good_broadcast_binary_poch,
793-
eps=2e-10,
794-
mode=mode_no_scipy,
795-
)
796787

797-
TestPochInplaceBroadcast = makeBroadcastTester(
798-
op=inplace.poch_inplace,
799-
expected=expected_poch,
800-
good=_good_broadcast_binary_poch,
801-
mode=mode_no_scipy,
802-
inplace=True,
803-
)
804788

805-
_good_broadcast_unary_factorial = dict(
806-
normal=(
807-
random_ranged(0, 5, (2, 1), rng=rng),
808-
)
809-
)
789+
@pytest.mark.parametrize("z, m",[random_ranged(0, 5, (2, 1), rng=rng),random_ranged(0, 5, (2, 1), rng=rng)])
790+
def test_poch(z,m):
810791

811-
TestFactorialBroadcast = makeBroadcastTester(
812-
op=at.factorial,
813-
expected=expected_factorial,
814-
good=_good_broadcast_unary_factorial,
815-
eps=2e-10,
816-
mode=mode_no_scipy,
817-
)
792+
z, m = at.scalars("z", "m")
818793

819-
TestFactorialInplaceBroadcast = makeBroadcastTester(
820-
op=inplace.factorial_inplace,
821-
expected=expected_factorial,
822-
good=_good_broadcast_unary_factorial,
823-
mode=mode_no_scipy,
824-
inplace=True,
825-
)
794+
poch = at.poch(z, m)
795+
796+
actual = function([z, m], poch)
797+
expected = scipy.special.poch(z, m)
798+
799+
assert actual == expected
800+
801+
802+
803+
@pytest.mark.parametrize("n",random_ranged(0, 5, (2, 1)))
804+
def test_factorial(n):
805+
806+
n = at.scalars("n")
807+
808+
factorial = at.factorial(n)
809+
810+
actual = function([n], factorial)
811+
expected = scipy.special.factorial(n)
812+
813+
assert actual == expected
826814

827815

828816
class TestBetaIncGrad:

0 commit comments

Comments
 (0)