Commit f3ad76b

ColtAllen authored and twiecki committed
Moved factorial and poch into tensor.special
1 parent 56d2cb8 commit f3ad76b

5 files changed, with 67 additions and 76 deletions.


pytensor/scalar/math.py

Lines changed: 1 addition & 22 deletions

@@ -33,8 +33,6 @@
     upgrade_to_float,
     upgrade_to_float64,
     upgrade_to_float_no_complex,
-    ScalarType,
-    ScalarVariable,
 )
 
 
@@ -1526,9 +1524,6 @@ class Hyp2F1Der(ScalarOp):
     """
     Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
 
-    Currently written in terms of gamma until poch and factorial Ops are ready:
-        poch(z, m) = (gamma(z + m) / gamma(m))
-        factorial(n) = gamma(n+1)
     """
 
     nin = 5
@@ -1637,20 +1632,4 @@ def c_code(self, *args, **kwargs):
         raise NotImplementedError()
 
 
-hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
-
-
-def poch(z: ScalarType, m: ScalarType) -> ScalarVariable:
-    """
-    Pochhammer symbol (rising factorial) function.
-
-    """
-    return gamma(z + m) / gamma(z)
-
-
-def factorial(n: ScalarType) -> ScalarVariable:
-    """
-    Factorial function of a scalar or array of numbers.
-
-    """
-    return gamma(n + 1)
+hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
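Aside (not part of the diff): both helpers are thin wrappers around the gamma function, via the identities

$$\operatorname{poch}(z, m) = \frac{\Gamma(z + m)}{\Gamma(z)}, \qquad n! = \Gamma(n + 1).$$

Note that the removed docstring's `gamma(z + m) / gamma(m)` appears to be a typo for the `gamma(z + m) / gamma(z)` form that the code actually implements.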

pytensor/tensor/math.py

Lines changed: 5 additions & 17 deletions

@@ -1384,6 +1384,11 @@ def gammal(k, x):
     """Lower incomplete gamma function."""
 
 
+@scalar_elemwise
+def hyp2f1(a, b, c, z):
+    """Gaussian hypergeometric function."""
+
+
 @scalar_elemwise
 def j0(x):
     """Bessel function of the first kind of order 0."""
@@ -1866,21 +1871,6 @@ def clip(x, min, max):
 # for grep: clamp, bound
 
 
-@scalar_elemwise
-def hyp2f1(a, b, c, z):
-    """gaussian hypergeometric function"""
-
-
-@scalar_elemwise
-def poch(z, m):
-    """pochhammer symbol (rising factorial) function"""
-
-
-@scalar_elemwise
-def factorial(n):
-    """factorial function"""
-
-
 pprint.assign(add, printing.OperatorPrinter("+", -2, "either"))
 pprint.assign(mul, printing.OperatorPrinter("*", -1, "either"))
 pprint.assign(sub, printing.OperatorPrinter("-", -2, "left"))
@@ -3148,8 +3138,6 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     "logaddexp",
     "logsumexp",
     "hyp2f1",
-    "poch",
-    "factorial",
 ]
 
 DEPRECATED_NAMES = [
pytensor/tensor/special.py

Lines changed: 24 additions & 1 deletion

@@ -1,13 +1,18 @@
 import warnings
 from textwrap import dedent
+from typing import TYPE_CHECKING
 
 import numpy as np
 import scipy
 
 from pytensor.graph.basic import Apply
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.math import neg, sum
+from pytensor.tensor.math import neg, sum, gamma
+
+
+if TYPE_CHECKING:
+    from pytensor.tensor import TensorLike, TensorVariable
 
 
 class SoftmaxGrad(COp):
@@ -768,7 +773,25 @@ def log_softmax(c, axis=UNSET_AXIS):
     return LogSoftmax(axis=axis)(c)
 
 
+def poch(z: "TensorLike", m: "TensorLike") -> "TensorVariable":
+    """
+    Pochhammer symbol (rising factorial) function.
+
+    """
+    return gamma(z + m) / gamma(z)
+
+
+def factorial(n: "TensorLike") -> "TensorVariable":
+    """
+    Factorial function of a scalar or array of numbers.
+
+    """
+    return gamma(n + 1)
+
+
 __all__ = [
     "softmax",
     "log_softmax",
+    "poch",
+    "factorial",
 ]
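Aside: a minimal sketch (not part of the commit) of the new import path, mirroring the tests added below. It assumes a working PyTensor and SciPy install; the numeric inputs are arbitrary.

# Hypothetical sketch: the relocated helpers, checked against SciPy.
import numpy as np
import scipy.special

from pytensor.compile.function import function
from pytensor.tensor import scalar, scalars
from pytensor.tensor.special import factorial, poch

z, m = scalars("z", "m")
poch_fn = function([z, m], poch(z, m))
assert np.isclose(poch_fn(2.0, 3.0), scipy.special.poch(2.0, 3.0))  # 2 * 3 * 4 = 24

n = scalar("n")
factorial_fn = function([n], factorial(n))
assert np.isclose(factorial_fn(5.0), scipy.special.factorial(5.0))  # 120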

tests/tensor/test_math_scipy.py

Lines changed: 3 additions & 35 deletions

@@ -768,50 +768,18 @@ def test_deprecated_module():
     expected=expected_hyp2f1,
     good=_good_broadcast_quaternary_hyp2f1,
     eps=2e-10,
+    mode=mode_no_scipy,
 )
 
 TestHyp2F1InplaceBroadcast = makeBroadcastTester(
     op=inplace.hyp2f1_inplace,
     expected=expected_hyp2f1,
     good=_good_broadcast_quaternary_hyp2f1,
+    eps=2e-10,
+    mode=mode_no_scipy,
     inplace=True,
 )
 
-_good_broadcast_binary_poch = dict(
-    normal=(
-        random_ranged(0, 5, (2, 1), rng=rng),
-        random_ranged(0, 5, (2, 1), rng=rng),
-    )
-)
-
-
-@pytest.mark.parametrize(
-    "z, m", [random_ranged(0, 5, (2, 1), rng=rng), random_ranged(0, 5, (2, 1), rng=rng)]
-)
-def test_poch(z, m):
-
-    z, m = at.scalars("z", "m")
-
-    poch = at.poch(z, m)
-
-    actual = function([z, m], poch)
-    expected = scipy.special.poch(z, m)
-
-    assert actual == expected
-
-
-@pytest.mark.parametrize("n", random_ranged(0, 5, (2, 1)))
-def test_factorial(n):
-
-    n = at.scalars("n")
-
-    factorial = at.factorial(n)
-
-    actual = function([n], factorial)
-    expected = scipy.special.factorial(n)
-
-    assert actual == expected
-
 
 
 class TestBetaIncGrad:
     def test_stan_grad_partial(self):

tests/tensor/test_special.py

Lines changed: 34 additions & 1 deletion

@@ -1,12 +1,16 @@
 import numpy as np
 import pytest
+from scipy.special import factorial as scipy_factorial
 from scipy.special import log_softmax as scipy_log_softmax
+from scipy.special import poch as scipy_poch
 from scipy.special import softmax as scipy_softmax
 
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
+from pytensor.tensor import scalar, scalars
+from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax, poch, factorial
 from pytensor.tensor.type import matrix, tensor3, tensor4, vector
+from tests.tensor.utils import random_ranged
 from tests import unittest_tools as utt
 
 
@@ -134,3 +138,32 @@ def test_valid_axis(self):
 
         with pytest.raises(ValueError):
             SoftmaxGrad(-4)(*x)
+
+
+@pytest.mark.parametrize(
+    "z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))]
+)
+def test_poch(z, m):
+
+    _z, _m = scalars("z", "m")
+
+    actual_fn = function([_z, _m], poch(_z, _m))
+    actual = actual_fn(z, m)
+
+    expected = scipy_poch(z, m)
+
+    assert np.allclose(actual, expected)
+
+
+@pytest.mark.parametrize("n", random_ranged(0, 5, (1,)))
+def test_factorial(n):
+
+    _n = scalar("n")
+
+    actual_fn = function([_n], factorial(_n))
+    actual = actual_fn(n)
+
+    expected = scipy_factorial(n)
+
+    assert np.allclose(actual, expected)
+
