
Commit 3976cdf

ColtAllen authored and twiecki committed
local and remote branch differences resolved.
1 parent d601940 commit 3976cdf

3 files changed: +93 -182 lines changed

pytensor/scalar/math.py.rej

Lines changed: 7 additions & 154 deletions
@@ -1,157 +1,10 @@
 diff a/pytensor/scalar/math.py b/pytensor/scalar/math.py (rejected hunks)
-@@ -1493,3 +1493,155 @@ def c_code(self, *args, **kwargs):
+@@ -1644,8 +1644,4 @@ def c_code(self, *args, **kwargs):
+         raise NotImplementedError()
 
 
- betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
-+
-+
-+class Hyp2F1(ScalarOp):
-+    """
-+    Gaussian hypergeometric function ``2F1(a, b; c; z)``.
-+
-+    """
-+
-+    nin = 4
-+    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
-+
-+    @staticmethod
-+    def st_impl(a, b, c, z):
-+
-+        if abs(z) >= 1:
-+            raise NotImplementedError("hyp2f1 only supported for z < 1.")
-+        else:
-+            return scipy.special.hyp2f1(a, b, c, z)
-+
-+    def impl(self, a, b, c, z):
-+        return Hyp2F1.st_impl(a, b, c, z)
-+
-+    def grad(self, inputs, grads):
-+        a, b, c, z = inputs
-+        (gz,) = grads
-+        return [
-+            gz * hyp2f1_der(a, b, c, z, 0),
-+            gz * hyp2f1_der(a, b, c, z, 1),
-+            gz * hyp2f1_der(a, b, c, z, 2),
-+            gz * hyp2f1_der(a, b, c, z, 3),
-+        ]
-+
-+    def c_code(self, *args, **kwargs):
-+        raise NotImplementedError()
-+
-+
-+hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
-+
-+
-+class Hyp2F1Der(ScalarOp):
-+    """
-+    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
-+
-+    """
-+
-+    nin = 5
-+
-+    def impl(self, a, b, c, z, wrt):
-+        def _infinisum(f):
-+            """
-+            Utility function for infinite summations.
-+            """
-+
-+            n, res = 0, f(0)
-+            while True:
-+                term = f(n + 1)
-+                if RuntimeWarning:
-+                    break
-+                if (res + term) - res == 0:
-+                    break
-+                n, res = n + 1, res + term
-+            return res
-+
-+        def _hyp2f1_da(a, b, c, z):
-+            """
-+            Derivative of hyp2f1 wrt a
-+
-+            """
-+
-+            if abs(z) >= 1:
-+                raise NotImplementedError("Gradient not supported for |z| >= 1")
-+
-+            else:
-+                term1 = _infinisum(
-+                    lambda k: (
-+                        (gamma(a + k) / gamma(a))
-+                        * (gamma(b + k) / gamma(b))
-+                        * psi(a + k)
-+                        * (z**k)
-+                    )
-+                    / (gamma(c + k) / gamma(c))
-+                    * gamma(k + 1)
-+                )
-+                term2 = psi(a) * hyp2f1(a, b, c, z)
-+
-+                return term1 - term2
-+
-+        def _hyp2f1_db(a, b, c, z):
-+            """
-+            Derivative of hyp2f1 wrt b
-+            """
-+
-+            if abs(z) >= 1:
-+                raise NotImplementedError("Gradient not supported for |z| >= 1")
-+
-+            else:
-+                term1 = _infinisum(
-+                    lambda k: (
-+                        (gamma(a + k) / gamma(a))
-+                        * (gamma(b + k) / gamma(b))
-+                        * psi(b + k)
-+                        * (z**k)
-+                    )
-+                    / (gamma(c + k) / gamma(c))
-+                    * gamma(k + 1)
-+                )
-+                term2 = psi(b) * hyp2f1(a, b, c, z)
-+
-+                return term1 - term2
-+
-+        def _hyp2f1_dc(a, b, c, z):
-+            """
-+            Derivative of hyp2f1 wrt c
-+            """
-+            if abs(z) >= 1:
-+                raise NotImplementedError("Gradient not supported for |z| >= 1")
-+
-+            else:
-+                term1 = psi(c) * hyp2f1(a, b, c, z)
-+                term2 = _infinisum(
-+                    lambda k: (
-+                        (gamma(a + k) / gamma(a))
-+                        * (gamma(b + k) / gamma(b))
-+                        * psi(c + k)
-+                        * (z**k)
-+                    )
-+                    / (gamma(c + k) / gamma(c))
-+                    * gamma(k + 1)
-+                )
-+                return term1 - term2
-+
-+        def _hyp2f1_dz(a, b, c, z):
-+            """
-+            Derivative of hyp2f1 wrt z
-+            """
-+
-+            return ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
-+
-+        if wrt == 0:
-+            return _hyp2f1_da(a, b, c, z)
-+        elif wrt == 1:
-+            return _hyp2f1_db(a, b, c, z)
-+        elif wrt == 2:
-+            return _hyp2f1_dc(a, b, c, z)
-+        elif wrt == 3:
-+            return _hyp2f1_dz(a, b, c, z)
-+
-+    def c_code(self, *args, **kwargs):
-+        raise NotImplementedError()
-+
-+
-+hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
+-<<<<<<< HEAD
+ hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
+-=======
+-hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
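
A note on the removed hunk above, with a small standalone sketch in plain Python/SciPy (illustration only, not part of this commit): the _infinisum helper breaks out of its loop on "if RuntimeWarning:", which is always truthy, so only the first term would ever be accumulated. A convergence-based version of what its docstring describes, together with a finite-difference check of the _hyp2f1_dz rule d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1, c+1, z), could look like the following; the max_terms cap and the sample values of a, b, c, z are added assumptions, not from the original code.

import numpy as np
from scipy.special import hyp2f1

def _infinisum(f, max_terms=10_000):
    # Sum f(0) + f(1) + ... until adding the next term no longer changes the total.
    n, res = 0, f(0)
    while n < max_terms:
        term = f(n + 1)
        if (res + term) - res == 0:
            break
        n, res = n + 1, res + term
    return res

# 2F1(1, 1; 1; z) reduces to the geometric series 1 / (1 - z); quick helper check.
assert np.isclose(_infinisum(lambda k: 0.25**k), 1 / (1 - 0.25))

# Finite-difference check of the closed-form derivative used by _hyp2f1_dz.
a, b, c, z = 1.5, 2.0, 3.0, 0.25
eps = 1e-6
fd = (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)
assert np.isclose(fd, (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z), rtol=1e-5)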

pytensor/tensor/special.py.rej

Lines changed: 44 additions & 7 deletions
@@ -1,14 +1,51 @@
 diff a/pytensor/tensor/special.py b/pytensor/tensor/special.py (rejected hunks)
-@@ -7,7 +8,11 @@
+@@ -8,19 +8,11 @@
 from pytensor.graph.basic import Apply
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
--from pytensor.tensor.math import neg, sum
-+from pytensor.tensor.math import gamma, neg, sum
-+
-+
-+if TYPE_CHECKING:
-+    pass
+-<<<<<<< HEAD
+-from pytensor.tensor.math import neg, sum, gamma
+-
+-
+-if TYPE_CHECKING:
+-    from pytensor.tensor import TensorLike, TensorVariable
+-=======
+ from pytensor.tensor.math import gamma, neg, sum
+
+
+ if TYPE_CHECKING:
+     pass
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
 
 
 class SoftmaxGrad(COp):
+@@ -781,21 +773,6 @@ def log_softmax(c, axis=UNSET_AXIS):
+     return LogSoftmax(axis=axis)(c)
+
+
+-<<<<<<< HEAD
+-def poch(z: "TensorLike", m: "TensorLike") -> "TensorVariable":
+-    """
+-    Pochhammer symbol (rising factorial) function.
+-
+-    """
+-    return gamma(z + m) / gamma(z)
+-
+-
+-def factorial(n: "TensorLike") -> "TensorVariable":
+-    """
+-    Factorial function of a scalar or array of numbers.
+-
+-    """
+-=======
+ def poch(z, m):
+     """Compute the Pochhammer/rising factorial."""
+     return gamma(z + m) / gamma(z)
+@@ -803,7 +780,6 @@ def poch(z, m):
+
+ def factorial(n):
+     """Compute the factorial."""
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
+     return gamma(n + 1)
+
+
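
For reference, both sides of the conflict resolved above implement the same gamma-function identities: poch(z, m) = gamma(z + m) / gamma(z) and factorial(n) = gamma(n + 1). A standalone check of those identities against scipy.special (illustration only, not part of the commit; the sample values of z, m, and n are arbitrary):

import numpy as np
from scipy.special import factorial as scipy_factorial
from scipy.special import gamma
from scipy.special import poch as scipy_poch

# Rising factorial: gamma(z + m) / gamma(z) == poch(z, m)
z, m = 2.5, 3.0
assert np.isclose(gamma(z + m) / gamma(z), scipy_poch(z, m))

# Factorial: gamma(n + 1) == n!
n = np.array([0.0, 1.0, 2.0, 5.0])
assert np.allclose(gamma(n + 1), scipy_factorial(n))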

tests/tensor/test_special.py.rej

Lines changed: 42 additions & 21 deletions
@@ -1,28 +1,49 @@
 diff a/tests/tensor/test_special.py b/tests/tensor/test_special.py (rejected hunks)
-@@ -1,13 +1,25 @@
- import numpy as np
- import pytest
-+from scipy.special import factorial as scipy_factorial
- from scipy.special import log_softmax as scipy_log_softmax
-+from scipy.special import poch as scipy_poch
- from scipy.special import softmax as scipy_softmax
-
+@@ -8,9 +8,6 @@
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
--from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
-+from pytensor.tensor import scalar, scalars
-+from pytensor.tensor.special import (
-+    LogSoftmax,
-+    Softmax,
-+    SoftmaxGrad,
-+    factorial,
-+    log_softmax,
-+    poch,
-+    softmax,
-+)
+ from pytensor.tensor import scalar, scalars
+-<<<<<<< HEAD
+-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax, poch, factorial
+-=======
+ from pytensor.tensor.special import (
+     LogSoftmax,
+     Softmax,
+@@ -20,7 +17,6 @@
+     poch,
+     softmax,
+ )
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
 from pytensor.tensor.type import matrix, tensor3, tensor4, vector
+ from tests.tensor.utils import random_ranged
 from tests import unittest_tools as utt
-+from tests.tensor.utils import random_ranged
+@@ -153,13 +149,7 @@ def test_valid_axis(self):
+         SoftmaxGrad(-4)(*x)
+
+
+-<<<<<<< HEAD
+-@pytest.mark.parametrize(
+-    "z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))]
+-)
+-=======
+ @pytest.mark.parametrize("z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))])
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
+ def test_poch(z, m):
+
+     _z, _m = scalars("z", "m")
+@@ -179,15 +169,7 @@ def test_factorial(n):
+
+     actual_fn = function([_n], factorial(_n))
+     actual = actual_fn(n)
+-<<<<<<< HEAD
+-
+-    expected = scipy_factorial(n)
+-
+-    assert np.allclose(actual, expected)
+-
+-=======

+     expected = scipy_factorial(n)

- class TestSoftmax(utt.InferShapeTester):
+     assert np.allclose(actual, expected)
+->>>>>>> db46ad7d8e0497a3e7be10c7a7551c332a6d877e
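
Pieced together from the context lines kept in the hunk above, the resolved test_poch compiles the symbolic poch graph and compares it with scipy.special.poch, mirroring the visible test_factorial fragment. The sketch below is an assumption about the full test body; only the parametrization, the signature, and the scalars(...) line are actually shown in this diff.

import numpy as np
import pytest
from scipy.special import poch as scipy_poch

from pytensor.compile.function import function
from pytensor.tensor import scalars
from pytensor.tensor.special import poch
from tests.tensor.utils import random_ranged


@pytest.mark.parametrize("z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))])
def test_poch(z, m):
    _z, _m = scalars("z", "m")

    # Assumed body: compile the symbolic poch and compare against SciPy.
    actual_fn = function([_z, _m], poch(_z, _m))
    actual = actual_fn(z, m)

    expected = scipy_poch(z, m)

    assert np.allclose(actual, expected)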
