"""

import os
-import warnings
from textwrap import dedent

import numpy as np
@@ -26,7 +25,9 @@
    expm1,
    float64,
    float_types,
+    floor,
    identity,
+    integer_types,
    isinf,
    log,
    log1p,
@@ -853,15 +854,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
        s_sign = -s_sign

        # log will cast >int16 to float64
-        log_s_inc = log_x - log(n)
-        if log_s_inc.type.dtype != log_s.type.dtype:
-            log_s_inc = log_s_inc.astype(log_s.type.dtype)
-        log_s += log_s_inc
+        log_s += log_x - log(n)
+        if log_s.type.dtype != dtype:
+            log_s = log_s.astype(dtype)

-        new_log_delta = log_s - 2 * log(n + k)
-        if new_log_delta.type.dtype != log_delta.type.dtype:
-            new_log_delta = new_log_delta.astype(log_delta.type.dtype)
-        log_delta = new_log_delta
+        log_delta = log_s - 2 * log(n + k)
+        if log_delta.type.dtype != dtype:
+            log_delta = log_delta.astype(dtype)

        n += 1
        return (
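A note on the cast above: `log` promotes integer operands (int32/int64 inputs come back as float64), so without the explicit `astype` the loop-carried `log_s` and `log_delta` would change dtype between iterations. A minimal sketch of the same promote-then-cast-back pattern in plain NumPy (not the PyTensor graph itself):

```python
import numpy as np

dtype = "float32"
log_s = np.array(0.0, dtype=dtype)   # loop-carried accumulator
n = np.array(3, dtype="int64")

log_s = log_s + np.log(n)            # np.log on int64 returns float64, promoting log_s
assert log_s.dtype == np.float64

log_s = log_s.astype(dtype)          # cast back so the carried state keeps its declared dtype
assert log_s.dtype == np.float32
```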
@@ -1581,9 +1580,9 @@ def grad(self, inputs, grads):
        a, b, c, z = inputs
        (gz,) = grads
        return [
-            gz * hyp2f1_der(a, b, c, z, wrt=0),
-            gz * hyp2f1_der(a, b, c, z, wrt=1),
-            gz * hyp2f1_der(a, b, c, z, wrt=2),
+            gz * hyp2f1_grad(a, b, c, z, wrt=0),
+            gz * hyp2f1_grad(a, b, c, z, wrt=1),
+            gz * hyp2f1_grad(a, b, c, z, wrt=2),
            gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
        ]

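The last entry above uses the closed-form derivative with respect to `z`: d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z). A quick standalone sanity check of that identity against a central finite difference (assumes SciPy; not part of this diff):

```python
import numpy as np
from scipy.special import hyp2f1

a, b, c, z = 1.3, 2.1, 3.7, 0.4
analytic = (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)

eps = 1e-6
numeric = (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)

print(analytic, numeric)  # should agree to roughly the finite-difference error
```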
@@ -1594,134 +1593,165 @@ def c_code(self, *args, **kwargs):
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")


-class Hyp2F1Der(ScalarOp):
-    """
-    Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
+def _unsafe_sign(x):
+    # Unlike scalar.sign we don't worry about x being 0 or nan
+    return switch(x > 0, 1, -1)

-    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
-    """

-    nin = 5
+def hyp2f1_grad(a, b, c, z, wrt: int):
+    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

-    def impl(self, a, b, c, z, wrt):
-        def check_2f1_converges(a, b, c, z) -> bool:
-            num_terms = 0
-            is_polynomial = False
+    def check_2f1_converges(a, b, c, z):
+        def is_nonpositive_integer(x):
+            if x.type.dtype not in integer_types:
+                return eq(floor(x), x) & (x <= 0)
+            else:
+                return x <= 0

-            def is_nonpositive_integer(x):
-                return x <= 0 and x.is_integer()
+        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
+        num_terms = switch(
+            a_is_polynomial,
+            floor(scalar_abs(a)).astype("int64"),
+            0,
+        )

-            if is_nonpositive_integer(a) and abs(a) >= num_terms:
-                is_polynomial = True
-                num_terms = int(np.floor(abs(a)))
-            if is_nonpositive_integer(b) and abs(b) >= num_terms:
-                is_polynomial = True
-                num_terms = int(np.floor(abs(b)))
+        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
+        num_terms = switch(
+            b_is_polynomial,
+            floor(scalar_abs(b)).astype("int64"),
+            num_terms,
+        )

-            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
+        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
+        is_polynomial = a_is_polynomial | b_is_polynomial

-            return not is_undefined and (
-                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
-            )
+        return (~is_undefined) & (
+            is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
+        )

-        def compute_grad_2f1(a, b, c, z, wrt):
-            """
-            Notes
-            -----
-            The algorithm can be derived by looking at the ratio of two successive terms in the series
-            β_{k+1}/β_{k} = A(k)/B(k)
-            β_{k+1} = A(k)/B(k) * β_{k}
-            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
-
-            In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
-
-            The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
-            by dropping the respective term
-            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
-            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
-            d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
-
-            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
-            tracking their signs.
-            """
+    def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
+        """
+        Notes
+        -----
+        The algorithm can be derived by looking at the ratio of two successive terms in the series
+        β_{k+1}/β_{k} = A(k)/B(k)
+        β_{k+1} = A(k)/B(k) * β_{k}
+        d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
+
+        In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
+
+        The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
+        by dropping the respective term
+        d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
+        d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
+        d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
+
+        The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
+        tracking their signs.
+        """
+
+        wrt_a = wrt_b = False
+        if wrt == 0:
+            wrt_a = True
+        elif wrt == 1:
+            wrt_b = True
+        elif wrt != 2:
+            raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
+
+        min_steps = np.array(
+            10, dtype="int32"
+        )  # https://github.com/stan-dev/math/issues/2857
+        max_steps = switch(
+            skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
+        )
+        precision = np.array(1e-14, dtype=config.floatX)

-            wrt_a = wrt_b = False
-            if wrt == 0:
-                wrt_a = True
-            elif wrt == 1:
-                wrt_b = True
-            elif wrt != 2:
-                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
-
-            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
-            max_steps = int(1e6)
-            precision = 1e-14
-
-            res = 0
-
-            if z == 0:
-                return res
-
-            log_g_old = -np.inf
-            log_t_old = 0.0
-            log_t_new = 0.0
-            sign_z = np.sign(z)
-            log_z = np.log(np.abs(z))
-
-            log_g_old_sign = 1
-            log_t_old_sign = 1
-            log_t_new_sign = 1
-            sign_zk = sign_z
-
-            for k in range(max_steps):
-                p = (a + k) * (b + k) / ((c + k) * (k + 1))
-                if p == 0:
-                    return res
-                log_t_new += np.log(np.abs(p)) + log_z
-                log_t_new_sign = np.sign(p) * log_t_new_sign
-
-                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
-                if wrt_a:
-                    term += np.reciprocal(a + k)
-                elif wrt_b:
-                    term += np.reciprocal(b + k)
-                else:
-                    term -= np.reciprocal(c + k)
-
-                log_g_old = log_t_new + np.log(np.abs(term))
-                log_g_old_sign = np.sign(term) * log_t_new_sign
-                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
-                res += g_current
-
-                log_t_old = log_t_new
-                log_t_old_sign = log_t_new_sign
-                sign_zk *= sign_z
-
-                if k >= min_steps and np.abs(g_current) <= precision:
-                    return res
-
-            warnings.warn(
-                f"hyp2f1_der did not converge after {k} iterations",
-                RuntimeWarning,
-            )
-            return np.nan
+        grad = np.array(0, dtype=dtype)
+
+        log_g = np.array(-np.inf, dtype=dtype)
+        log_g_sign = np.array(1, dtype="int8")
+
+        log_t = np.array(0.0, dtype=dtype)
+        log_t_sign = np.array(1, dtype="int8")
+
+        log_z = log(scalar_abs(z))
+        sign_z = _unsafe_sign(z)
+
+        sign_zk = sign_z
+        k = np.array(0, dtype="int32")
+
+        def inner_loop(
+            grad,
+            log_g,
+            log_g_sign,
+            log_t,
+            log_t_sign,
+            sign_zk,
+            k,
+            a,
+            b,
+            c,
+            log_z,
+            sign_z,
+        ):
+            p = (a + k) * (b + k) / ((c + k) * (k + 1))
+            if p.type.dtype != dtype:
+                p = p.astype(dtype)
+
+            term = log_g_sign * log_t_sign * exp(log_g - log_t)
+            if wrt_a:
+                term += reciprocal(a + k)
+            elif wrt_b:
+                term += reciprocal(b + k)
+            else:
+                term -= reciprocal(c + k)
+
+            if term.type.dtype != dtype:
+                term = term.astype(dtype)
+
+            log_t = log_t + log(scalar_abs(p)) + log_z
+            log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
+            log_g = log_t + log(scalar_abs(term))
+            log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
+
+            g_current = log_g_sign * exp(log_g) * sign_zk

-        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
-        if not check_2f1_converges(a, b, c, z):
-            warnings.warn(
-                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
-                RuntimeWarning,
+            # If p==0, don't update grad and get out of while loop next
+            grad = switch(
+                eq(p, 0),
+                grad,
+                grad + g_current,
            )
-            return np.nan

-        return compute_grad_2f1(a, b, c, z, wrt=wrt)
+            sign_zk *= sign_z
+            k += 1

-    def __call__(self, a, b, c, z, wrt, **kwargs):
-        # This allows wrt to be a keyword argument
-        return super().__call__(a, b, c, z, wrt, **kwargs)
+            return (
+                (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
+                (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
+            )

-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
+        init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
+        constant = [a, b, c, log_z, sign_z]
+        grad = _make_scalar_loop(
+            max_steps, init, constant, inner_loop, name="hyp2f1_grad"
+        )

+        return switch(
+            eq(z, 0),
+            0,
+            grad,
+        )

-hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
+    # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
+    z_is_zero = eq(z, 0)
+    converges = check_2f1_converges(a, b, c, z)
+    return switch(
+        z_is_zero,
+        0,
+        switch(
+            converges,
+            compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
+            np.nan,
+        ),
+    )
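For reference, the recurrence the new `ScalarLoop`-based graph unrolls symbolically is the same log-scale series with explicit sign tracking that the removed eager `impl` computed with NumPy. A standalone sketch of that recurrence (it mirrors the removed code; the helper name and the SciPy cross-check are illustrative, not part of this PR):

```python
import numpy as np
from scipy.special import hyp2f1


def grad_2f1_reference(a, b, c, z, wrt, max_steps=100_000, min_steps=10, precision=1e-14):
    """Log-scale series for d/d{a,b,c} 2F1(a, b; c; z) with explicit sign tracking."""
    if z == 0:
        return 0.0
    res = 0.0
    log_g_old, log_g_old_sign = -np.inf, 1
    log_t_old, log_t_old_sign = 0.0, 1
    log_t_new, log_t_new_sign = 0.0, 1
    log_z, sign_z = np.log(np.abs(z)), np.sign(z)
    sign_zk = sign_z
    for k in range(max_steps):
        p = (a + k) * (b + k) / ((c + k) * (k + 1))
        if p == 0:
            return res
        log_t_new += np.log(np.abs(p)) + log_z
        log_t_new_sign *= np.sign(p)
        # d[A(k)/B(k)] contribution for the selected input (see the Notes docstring above)
        term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
        term += 1 / (a + k) if wrt == 0 else (1 / (b + k) if wrt == 1 else -1 / (c + k))
        log_g_old = log_t_new + np.log(np.abs(term))
        log_g_old_sign = np.sign(term) * log_t_new_sign
        g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
        res += g_current
        log_t_old, log_t_old_sign = log_t_new, log_t_new_sign
        sign_zk *= sign_z
        if k >= min_steps and np.abs(g_current) <= precision:
            return res
    return np.nan


# Cross-check the wrt=0 (i.e. d/da) series against a finite difference of SciPy's hyp2f1.
a, b, c, z = 1.3, 2.1, 3.7, 0.4
eps = 1e-6
fd = (hyp2f1(a + eps, b, c, z) - hyp2f1(a - eps, b, c, z)) / (2 * eps)
print(grad_2f1_reference(a, b, c, z, wrt=0), fd)
```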