@@ -378,53 +378,42 @@ def L_op(self, inputs, outputs, grads):
378
378
379
379
def c_support_code(self, **kwargs):
    """Return the C helper source that defines ``_psi`` (digamma).

    The returned text is wrapped in a ``_PSIFUNCDEFINED`` include guard so
    that it can be emitted several times into one translation unit.

    For ``x > 0`` the helper uses Algorithm AS 103 (Bernardo, 1976):
    a recurrence that shifts the argument up to at least 8.5, followed by
    an asymptotic expansion.

    For ``x <= 0`` the helper no longer returns 0:

    * ``x == 0`` returns ``-INFINITY`` (the limit of the small-``x`` branch
      ``D1 - 1/x`` as ``x`` approaches 0 from above);
    * negative integers are poles and return ``NAN``;
    * every other negative ``x`` uses the reflection formula
      ``psi(x) = psi(1 - x) - pi / tan(pi * x)``.

    NOTE(review): if this op defines a C code cache version elsewhere, it
    presumably needs bumping because the generated source changed - confirm.

    Parameters
    ----------
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    str
        C source code to be placed ahead of the code that calls ``_psi``.
    """
    return """
        #ifndef M_PI
        #define M_PI 3.14159265358979323846
        #endif

        #ifndef _PSIFUNCDEFINED
        #define _PSIFUNCDEFINED
        double _psi(double x) {

            /*taken from
            Bernardo, J. M. (1976). Algorithm AS 103:
            Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
            http://www.uv.es/~bernardo/1976AppStatist.pdf */

            double y, R, psi_ = 0;
            double S = 1.0e-5;
            double C = 8.5;
            double S3 = 8.333333333e-2;
            double S4 = 8.333333333e-3;
            double S5 = 3.968253968e-3;
            double D1 = -0.5772156649;

            if (x <= 0.0) {
                /* zero and the negative integers are poles */
                if (x == 0.0)
                    return -INFINITY;
                if (x == floor(x))
                    return NAN;

                /* reflection formula; 1 - x > 1, so this recurses once */
                return _psi(1.0 - x) - M_PI / tan(M_PI * x);
            }

            y = x;

            if (y <= S)
                return D1 - 1.0/y;

            while (y < C) {
                psi_ = psi_ - 1.0 / y;
                y = y + 1;
            }

            R = 1.0 / y;
            psi_ = psi_ + log(y) - .5 * R ;
            R= R*R;
            psi_ = psi_ - R * (S3 - R * (S4 - R * S5));

            return psi_;
        }
        #endif
        """
def c_code(self, node, name, inp, out, sub):
    """Return the C statement that stores ``_psi(input)`` in the output.

    The result of ``_psi`` is cast to the output's ``npy_<dtype>`` type
    before it is assigned. The cast is applied to the value on the
    right-hand side: a cast applied to the assignment target is not an
    lvalue and does not compile.

    Parameters
    ----------
    node
        The node being compiled; ``node.inputs[0].type`` selects the
        supported case and ``node.outputs[0].dtype`` names the output
        C type.
    name
        Not used here.
    inp
        One-element sequence holding the C name of the input variable.
    out
        One-element sequence holding the C name of the output variable.
    sub
        Not used here.

    Returns
    -------
    str
        A single C assignment statement.

    Raises
    ------
    NotImplementedError
        If the input type is not one of ``float_types``.
    """
    (x,) = inp
    (z,) = out
    if node.inputs[0].type not in float_types:
        raise NotImplementedError("only floating point is implemented")

    dtype = "npy_" + node.outputs[0].dtype
    # Cast the computed value, never the assignment target.
    return f"{z} = ({dtype}) _psi({x});"
439
428
440
429
0 commit comments