34
34
upgrade_to_float64 ,
35
35
upgrade_to_float_no_complex ,
36
36
ScalarType ,
37
- ScalarVariable
37
+ ScalarVariable ,
38
38
)
39
39
40
40
@@ -1526,6 +1526,9 @@ class Hyp2F1Der(ScalarOp):
1526
1526
"""
1527
1527
Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1528
1528
1529
+ Currently written in terms of gamma until poch and factorial Ops are ready:
1530
+ poch(z, m) = (gamma(z + m) / gamma(m))
1531
+ factorial(n) = gamma(n+1)
1529
1532
"""
1530
1533
1531
1534
nin = 5
@@ -1538,27 +1541,33 @@ def _infinisum(f):
1538
1541
1539
1542
n , res = 0 , f (0 )
1540
1543
while True :
1541
- term = f (n + 1 )
1544
+ term = f (n + 1 )
1542
1545
if RuntimeWarning :
1543
1546
break
1544
- if (res + term )- res == 0 :
1547
+ if (res + term ) - res == 0 :
1545
1548
break
1546
- n ,res = n + 1 , res + term
1549
+ n , res = n + 1 , res + term
1547
1550
return res
1548
1551
1549
1552
def _hyp2f1_da (a , b , c , z ):
1550
1553
"""
1551
1554
Derivative of hyp2f1 wrt a
1555
+
1552
1556
"""
1553
1557
1554
1558
if abs (z ) >= 1 :
1555
- raise NotImplementedError (' Gradient not supported for |z| >= 1' )
1559
+ raise NotImplementedError (" Gradient not supported for |z| >= 1" )
1556
1560
1557
1561
else :
1558
-
1559
1562
term1 = _infinisum (
1560
- lambda k : (poch (a , k ) * poch (b , k ) * psi (a + k ) * (z ** k ))
1561
- / (poch (c , k ) * factorial (k ))
1563
+ lambda k : (
1564
+ (gamma (a + k ) / gamma (a ))
1565
+ * (gamma (b + k ) / gamma (b ))
1566
+ * psi (a + k )
1567
+ * (z ** k )
1568
+ )
1569
+ / (gamma (c + k ) / gamma (c ))
1570
+ * gamma (k + 1 )
1562
1571
)
1563
1572
term2 = psi (a ) * hyp2f1 (a , b , c , z )
1564
1573
@@ -1570,12 +1579,18 @@ def _hyp2f1_db(a, b, c, z):
1570
1579
"""
1571
1580
1572
1581
if abs (z ) >= 1 :
1573
- raise NotImplementedError (' Gradient not supported for |z| >= 1' )
1582
+ raise NotImplementedError (" Gradient not supported for |z| >= 1" )
1574
1583
1575
1584
else :
1576
1585
term1 = _infinisum (
1577
- lambda k : (poch (a , k ) * poch (b , k ) * psi (b + k ) * (z ** k ))
1578
- / (poch (c , k ) * factorial (k ))
1586
+ lambda k : (
1587
+ (gamma (a + k ) / gamma (a ))
1588
+ * (gamma (b + k ) / gamma (b ))
1589
+ * psi (b + k )
1590
+ * (z ** k )
1591
+ )
1592
+ / (gamma (c + k ) / gamma (c ))
1593
+ * gamma (k + 1 )
1579
1594
)
1580
1595
term2 = psi (b ) * hyp2f1 (a , b , c , z )
1581
1596
@@ -1586,13 +1601,19 @@ def _hyp2f1_dc(a, b, c, z):
1586
1601
Derivative of hyp2f1 wrt c
1587
1602
"""
1588
1603
if abs (z ) >= 1 :
1589
- raise NotImplementedError (' Gradient not supported for |z| >= 1' )
1604
+ raise NotImplementedError (" Gradient not supported for |z| >= 1" )
1590
1605
1591
1606
else :
1592
1607
term1 = psi (c ) * hyp2f1 (a , b , c , z )
1593
1608
term2 = _infinisum (
1594
- lambda k : (poch (a , k ) * poch (b , k ) * psi (c + k ) * (z ** k ))
1595
- / (poch (c , k ) * factorial (k ))
1609
+ lambda k : (
1610
+ (gamma (a + k ) / gamma (a ))
1611
+ * (gamma (b + k ) / gamma (b ))
1612
+ * psi (c + k )
1613
+ * (z ** k )
1614
+ )
1615
+ / (gamma (c + k ) / gamma (c ))
1616
+ * gamma (k + 1 )
1596
1617
)
1597
1618
return term1 - term2
1598
1619
@@ -1624,7 +1645,7 @@ def poch(z: ScalarType, m: ScalarType) -> ScalarVariable:
1624
1645
Pochhammer symbol (rising factorial) function.
1625
1646
1626
1647
"""
1627
- return gamma (z + m ) / gamma (z )
1648
+ return gamma (z + m ) / gamma (z )
1628
1649
1629
1650
1630
1651
def factorial (n : ScalarType ) -> ScalarVariable :
0 commit comments