33
33
upgrade_to_float ,
34
34
upgrade_to_float64 ,
35
35
upgrade_to_float_no_complex ,
36
+ ScalarType ,
37
+ ScalarVariable
36
38
)
37
39
38
40
@@ -1494,7 +1496,11 @@ class Hyp2F1(ScalarOp):
1494
1496
1495
1497
@staticmethod
1496
1498
def st_impl (a , b , c , z ):
1497
- return scipy .special .hyp2f1 (a , b , c , z )
1499
+
1500
+ if abs (z ) >= 1 :
1501
+ raise NotImplementedError ("hyp2f1 only supported for z < 1." )
1502
+ else :
1503
+ return scipy .special .hyp2f1 (a , b , c , z )
1498
1504
1499
1505
def impl (self , a , b , c , z ):
1500
1506
return Hyp2F1 .st_impl (a , b , c , z )
@@ -1551,10 +1557,10 @@ def _hyp2f1_da(a, b, c, z):
1551
1557
else :
1552
1558
1553
1559
term1 = _infinisum (
1554
- lambda k : (scipy . special . poch (a , k ) * scipy . special . poch (b , k ) * scipy . special . digamma (a + k ) * (z ** k ))
1555
- / (scipy . special . poch (c , k ) * scipy . special . factorial (k ))
1560
+ lambda k : (poch (a , k ) * poch (b , k ) * psi (a + k ) * (z ** k ))
1561
+ / (poch (c , k ) * factorial (k ))
1556
1562
)
1557
- term2 = scipy . special . digamma (a ) * scipy . special . hyp2f1 (a , b , c , z )
1563
+ term2 = psi (a ) * hyp2f1 (a , b , c , z )
1558
1564
1559
1565
return term1 - term2
1560
1566
@@ -1568,10 +1574,10 @@ def _hyp2f1_db(a, b, c, z):
1568
1574
1569
1575
else :
1570
1576
term1 = _infinisum (
1571
- lambda k : (scipy . special . poch (a , k ) * scipy . special . poch (b , k ) * scipy . special . digamma (b + k ) * (z ** k ))
1572
- / (scipy . special . poch (c , k ) * scipy . special . factorial (k ))
1577
+ lambda k : (poch (a , k ) * poch (b , k ) * psi (b + k ) * (z ** k ))
1578
+ / (poch (c , k ) * factorial (k ))
1573
1579
)
1574
- term2 = scipy . special . digamma (b ) * scipy . special . hyp2f1 (a , b , c , z )
1580
+ term2 = psi (b ) * hyp2f1 (a , b , c , z )
1575
1581
1576
1582
return term1 - term2
1577
1583
@@ -1583,10 +1589,10 @@ def _hyp2f1_dc(a, b, c, z):
1583
1589
raise NotImplementedError ('Gradient not supported for |z| >= 1' )
1584
1590
1585
1591
else :
1586
- term1 = scipy . special . digamma (c ) * scipy . special . hyp2f1 (a , b , c , z )
1592
+ term1 = psi (c ) * hyp2f1 (a , b , c , z )
1587
1593
term2 = _infinisum (
1588
- lambda k : (scipy . special . poch (a , k ) * scipy . special . poch (b , k ) * scipy . special . digamma (c + k ) * (z ** k ))
1589
- / (scipy . special . poch (c , k ) * scipy . special . factorial (k ))
1594
+ lambda k : (poch (a , k ) * poch (b , k ) * psi (c + k ) * (z ** k ))
1595
+ / (poch (c , k ) * factorial (k ))
1590
1596
)
1591
1597
return term1 - term2
1592
1598
@@ -1595,7 +1601,7 @@ def _hyp2f1_dz(a, b, c, z):
1595
1601
Derivative of hyp2f1 wrt z
1596
1602
"""
1597
1603
1598
- return ((a * b ) / c ) * scipy . special . hyp2f1 (a + 1 , b + 1 , c + 1 , z )
1604
+ return ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z )
1599
1605
1600
1606
if wrt == 0 :
1601
1607
return _hyp2f1_da (a , b , c , z )
@@ -1613,58 +1619,17 @@ def c_code(self, *args, **kwargs):
1613
1619
hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
1614
1620
1615
1621
1616
- class Poch ( BinaryScalarOp ) :
1622
+ def poch ( z : ScalarType , m : ScalarType ) -> ScalarVariable :
1617
1623
"""
1618
1624
Pochhammer symbol (rising factorial) function.
1619
1625
1620
1626
"""
1621
-
1622
- nfunc_spec = ("scipy.special.poch" , 2 , 1 )
1623
-
1624
- @staticmethod
1625
- def st_impl (z , m ):
1626
- return gamma (z + m ) / gamma (z )
1627
-
1628
- def impl (self , z , m ):
1629
- return Poch .st_impl (z , m )
1630
-
1631
- def grad (self , inputs , grads ):
1632
- z , m = inputs
1633
- (gz ,) = grads
1634
- return [
1635
- gz * poch (z , m ) * (tri_gamma (z + m ) - tri_gamma (z )),
1636
- gz * poch (z , m ) * tri_gamma (z + m )
1637
- ]
1638
-
1639
- def c_code (self , * args , ** kwargs ):
1640
- raise NotImplementedError ()
1627
+ return gamma (z + m ) / gamma (z )
1641
1628
1642
1629
1643
- poch = Poch (upgrade_to_float , name = "poch" )
1644
-
1645
-
1646
- class Factorial (UnaryScalarOp ):
1630
+ def factorial (n : ScalarType ) -> ScalarVariable :
1647
1631
"""
1648
1632
Factorial function of a scalar or array of numbers.
1649
1633
1650
1634
"""
1651
-
1652
- nfunc_spec = ("scipy.special.factorial" , 1 , 1 )
1653
-
1654
- @staticmethod
1655
- def st_impl (n ):
1656
- return gamma (n + 1 )
1657
-
1658
- def impl (self , n ):
1659
- return Factorial .st_impl (n )
1660
-
1661
- def grad (self , inputs , grads ):
1662
- (n ,) = inputs
1663
- (gz ,) = grads
1664
- return [gz * gamma (n + 1 ) * tri_gamma (n + 1 )]
1665
-
1666
- def c_code (self , * args , ** kwargs ):
1667
- raise NotImplementedError ()
1668
-
1669
-
1670
- factorial = Factorial (upgrade_to_float , name = "factorial" )
1635
+ return gamma (n + 1 )
0 commit comments