Skip to content

Commit be1c373

Browse files
committed
Create test for mismatch between C and python Psi implementation
1 parent 92eef5e commit be1c373

File tree

2 files changed

+62
-39
lines changed

2 files changed

+62
-39
lines changed

pytensor/scalar/math.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -378,53 +378,42 @@ def L_op(self, inputs, outputs, grads):
378378

379379
def c_support_code(self, **kwargs):
380380
return """
381-
// For GPU support
382-
#ifdef WITHIN_KERNEL
383-
#define DEVICE WITHIN_KERNEL
384-
#else
385-
#define DEVICE
386-
#endif
387-
388-
#ifndef ga_double
389-
#define ga_double double
390-
#endif
391-
392381
#ifndef _PSIFUNCDEFINED
393382
#define _PSIFUNCDEFINED
394-
DEVICE double _psi(ga_double x) {
383+
double _psi(double x) {
395384
396-
/*taken from
397-
Bernardo, J. M. (1976). Algorithm AS 103:
398-
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
399-
http://www.uv.es/~bernardo/1976AppStatist.pdf */
385+
/*taken from
386+
Bernardo, J. M. (1976). Algorithm AS 103:
387+
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
388+
http://www.uv.es/~bernardo/1976AppStatist.pdf */
400389
401-
ga_double y, R, psi_ = 0;
402-
ga_double S = 1.0e-5;
403-
ga_double C = 8.5;
404-
ga_double S3 = 8.333333333e-2;
405-
ga_double S4 = 8.333333333e-3;
406-
ga_double S5 = 3.968253968e-3;
407-
ga_double D1 = -0.5772156649;
390+
double y, R, psi_ = 0;
391+
double S = 1.0e-5;
392+
double C = 8.5;
393+
double S3 = 8.333333333e-2;
394+
double S4 = 8.333333333e-3;
395+
double S5 = 3.968253968e-3;
396+
double D1 = -0.5772156649;
408397
409-
y = x;
398+
y = x;
410399
411-
if (y <= 0.0)
412-
return psi_;
400+
if (y <= 0.0)
401+
return psi_;
413402
414-
if (y <= S)
415-
return D1 - 1.0/y;
403+
if (y <= S)
404+
return D1 - 1.0/y;
416405
417-
while (y < C) {
418-
psi_ = psi_ - 1.0 / y;
419-
y = y + 1;
420-
}
406+
while (y < C) {
407+
psi_ = psi_ - 1.0 / y;
408+
y = y + 1;
409+
}
421410
422-
R = 1.0 / y;
423-
psi_ = psi_ + log(y) - .5 * R ;
424-
R= R*R;
425-
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
411+
R = 1.0 / y;
412+
psi_ = psi_ + log(y) - .5 * R ;
413+
R= R*R;
414+
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
426415
427-
return psi_;
416+
return psi_;
428417
}
429418
#endif
430419
"""
@@ -433,8 +422,8 @@ def c_code(self, node, name, inp, out, sub):
433422
(x,) = inp
434423
(z,) = out
435424
if node.inputs[0].type in float_types:
436-
return f"""{z} =
437-
_psi({x});"""
425+
dtype = "npy_" + node.outputs[0].dtype
426+
return f"({dtype}){z} = _psi({x});"
438427
raise NotImplementedError("only floating point is implemented")
439428

440429

tests/scalar/test_math.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytest
5+
import scipy
56
import scipy.special as sp
67

78
import pytensor.tensor as pt
@@ -19,6 +20,7 @@
1920
gammal,
2021
gammau,
2122
hyp2f1,
23+
psi,
2224
)
2325
from tests.link.test_link import make_function
2426

@@ -149,3 +151,35 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
149151
(var.owner and isinstance(var.owner.op, ScalarLoop))
150152
for var in ancestors(grad)
151153
)
154+
155+
156+
@pytest.mark.parametrize(
157+
"linker",
158+
[
159+
"py",
160+
pytest.param(
161+
"c",
162+
marks=pytest.mark.xfail(
163+
reason="C implementation does not support negative inputs"
164+
),
165+
),
166+
],
167+
)
168+
def test_psi(linker):
169+
x = float64("x")
170+
out = psi(x)
171+
172+
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
173+
fn.dprint()
174+
175+
x_test = np.float64(0.5)
176+
np.testing.assert_allclose(
177+
fn(x_test),
178+
scipy.special.psi(x_test),
179+
strict=True,
180+
)
181+
np.testing.assert_allclose(
182+
fn(-x_test),
183+
scipy.special.psi(-x_test),
184+
strict=True,
185+
)

0 commit comments

Comments
 (0)