72
72
from pandas .core .arrays import (
73
73
Categorical ,
74
74
DatetimeArray ,
75
+ ExtensionArray ,
75
76
TimedeltaArray ,
76
77
)
77
78
from pandas .core .arrays .string_ import StringDtype
108
109
SequenceNotStr ,
109
110
StorageOptions ,
110
111
WriteBuffer ,
112
+ npt ,
111
113
)
112
114
113
115
from pandas import (
@@ -1216,7 +1218,7 @@ def get_buffer(
1216
1218
1217
1219
1218
1220
def format_array (
1219
- values : Any ,
1221
+ values : ArrayLike ,
1220
1222
formatter : Callable | None ,
1221
1223
float_format : FloatFormatType | None = None ,
1222
1224
na_rep : str = "NaN" ,
@@ -1233,7 +1235,7 @@ def format_array(
1233
1235
1234
1236
Parameters
1235
1237
----------
1236
- values
1238
+ values : np.ndarray or ExtensionArray
1237
1239
formatter
1238
1240
float_format
1239
1241
na_rep
@@ -1258,10 +1260,13 @@ def format_array(
1258
1260
fmt_klass : type [GenericArrayFormatter ]
1259
1261
if lib .is_np_dtype (values .dtype , "M" ):
1260
1262
fmt_klass = Datetime64Formatter
1263
+ values = cast (DatetimeArray , values )
1261
1264
elif isinstance (values .dtype , DatetimeTZDtype ):
1262
1265
fmt_klass = Datetime64TZFormatter
1266
+ values = cast (DatetimeArray , values )
1263
1267
elif lib .is_np_dtype (values .dtype , "m" ):
1264
1268
fmt_klass = Timedelta64Formatter
1269
+ values = cast (TimedeltaArray , values )
1265
1270
elif isinstance (values .dtype , ExtensionDtype ):
1266
1271
fmt_klass = ExtensionArrayFormatter
1267
1272
elif lib .is_np_dtype (values .dtype , "fc" ):
@@ -1300,7 +1305,7 @@ def format_array(
1300
1305
class GenericArrayFormatter :
1301
1306
def __init__ (
1302
1307
self ,
1303
- values : Any ,
1308
+ values : ArrayLike ,
1304
1309
digits : int = 7 ,
1305
1310
formatter : Callable | None = None ,
1306
1311
na_rep : str = "NaN" ,
@@ -1622,9 +1627,11 @@ def _format_strings(self) -> list[str]:
1622
1627
1623
1628
1624
1629
class Datetime64Formatter (GenericArrayFormatter ):
1630
+ values : DatetimeArray
1631
+
1625
1632
def __init__ (
1626
1633
self ,
1627
- values : np . ndarray | Series | DatetimeIndex | DatetimeArray ,
1634
+ values : DatetimeArray ,
1628
1635
nat_rep : str = "NaT" ,
1629
1636
date_format : None = None ,
1630
1637
** kwargs ,
@@ -1637,21 +1644,23 @@ def _format_strings(self) -> list[str]:
1637
1644
"""we by definition have DO NOT have a TZ"""
1638
1645
values = self .values
1639
1646
1640
- if not isinstance (values , DatetimeIndex ):
1641
- values = DatetimeIndex (values )
1647
+ dti = DatetimeIndex (values )
1642
1648
1643
1649
if self .formatter is not None and callable (self .formatter ):
1644
- return [self .formatter (x ) for x in values ]
1650
+ return [self .formatter (x ) for x in dti ]
1645
1651
1646
- fmt_values = values ._data ._format_native_types (
1652
+ fmt_values = dti ._data ._format_native_types (
1647
1653
na_rep = self .nat_rep , date_format = self .date_format
1648
1654
)
1649
1655
return fmt_values .tolist ()
1650
1656
1651
1657
1652
1658
class ExtensionArrayFormatter (GenericArrayFormatter ):
1659
+ values : ExtensionArray
1660
+
1653
1661
def _format_strings (self ) -> list [str ]:
1654
1662
values = extract_array (self .values , extract_numpy = True )
1663
+ values = cast (ExtensionArray , values )
1655
1664
1656
1665
formatter = self .formatter
1657
1666
fallback_formatter = None
@@ -1813,13 +1822,10 @@ def get_format_datetime64(
1813
1822
1814
1823
1815
1824
def get_format_datetime64_from_values (
1816
- values : np . ndarray | DatetimeArray | DatetimeIndex , date_format : str | None
1825
+ values : DatetimeArray , date_format : str | None
1817
1826
) -> str | None :
1818
1827
"""given values and a date_format, return a string format"""
1819
- if isinstance (values , np .ndarray ) and values .ndim > 1 :
1820
- # We don't actually care about the order of values, and DatetimeIndex
1821
- # only accepts 1D values
1822
- values = values .ravel ()
1828
+ assert isinstance (values , DatetimeArray )
1823
1829
1824
1830
ido = is_dates_only (values )
1825
1831
if ido :
@@ -1829,6 +1835,8 @@ def get_format_datetime64_from_values(
1829
1835
1830
1836
1831
1837
class Datetime64TZFormatter (Datetime64Formatter ):
1838
+ values : DatetimeArray
1839
+
1832
1840
def _format_strings (self ) -> list [str ]:
1833
1841
"""we by definition have a TZ"""
1834
1842
ido = is_dates_only (self .values )
@@ -1842,9 +1850,11 @@ def _format_strings(self) -> list[str]:
1842
1850
1843
1851
1844
1852
class Timedelta64Formatter (GenericArrayFormatter ):
1853
+ values : TimedeltaArray
1854
+
1845
1855
def __init__ (
1846
1856
self ,
1847
- values : np . ndarray | TimedeltaIndex ,
1857
+ values : TimedeltaArray ,
1848
1858
nat_rep : str = "NaT" ,
1849
1859
box : bool = False ,
1850
1860
** kwargs ,
@@ -1861,7 +1871,7 @@ def _format_strings(self) -> list[str]:
1861
1871
1862
1872
1863
1873
def get_format_timedelta64 (
1864
- values : np . ndarray | TimedeltaIndex | TimedeltaArray ,
1874
+ values : TimedeltaArray ,
1865
1875
nat_rep : str | float = "NaT" ,
1866
1876
box : bool = False ,
1867
1877
) -> Callable :
@@ -1872,18 +1882,13 @@ def get_format_timedelta64(
1872
1882
If box, then show the return in quotes
1873
1883
"""
1874
1884
values_int = values .view (np .int64 )
1885
+ values_int = cast ("npt.NDArray[np.int64]" , values_int )
1875
1886
1876
1887
consider_values = values_int != iNaT
1877
1888
1878
1889
one_day_nanos = 86400 * 10 ** 9
1879
- # error: Unsupported operand types for % ("ExtensionArray" and "int")
1880
- not_midnight = values_int % one_day_nanos != 0 # type: ignore[operator]
1881
- # error: Argument 1 to "__call__" of "ufunc" has incompatible type
1882
- # "Union[Any, ExtensionArray, ndarray]"; expected
1883
- # "Union[Union[int, float, complex, str, bytes, generic],
1884
- # Sequence[Union[int, float, complex, str, bytes, generic]],
1885
- # Sequence[Sequence[Any]], _SupportsArray]"
1886
- both = np .logical_and (consider_values , not_midnight ) # type: ignore[arg-type]
1890
+ not_midnight = values_int % one_day_nanos != 0
1891
+ both = np .logical_and (consider_values , not_midnight )
1887
1892
even_days = both .sum () == 0
1888
1893
1889
1894
if even_days :
@@ -1941,7 +1946,7 @@ def just(x: str) -> str:
1941
1946
return result
1942
1947
1943
1948
1944
- def _trim_zeros_complex (str_complexes : np . ndarray , decimal : str = "." ) -> list [str ]:
1949
+ def _trim_zeros_complex (str_complexes : ArrayLike , decimal : str = "." ) -> list [str ]:
1945
1950
"""
1946
1951
Separates the real and imaginary parts from the complex number, and
1947
1952
executes the _trim_zeros_float method on each of those.
@@ -1987,7 +1992,7 @@ def _trim_zeros_single_float(str_float: str) -> str:
1987
1992
1988
1993
1989
1994
def _trim_zeros_float (
1990
- str_floats : np . ndarray | list [str ], decimal : str = "."
1995
+ str_floats : ArrayLike | list [str ], decimal : str = "."
1991
1996
) -> list [str ]:
1992
1997
"""
1993
1998
Trims the maximum number of trailing zeros equally from
@@ -2000,7 +2005,7 @@ def _trim_zeros_float(
2000
2005
def is_number_with_decimal (x ) -> bool :
2001
2006
return re .match (number_regex , x ) is not None
2002
2007
2003
- def should_trim (values : np . ndarray | list [str ]) -> bool :
2008
+ def should_trim (values : ArrayLike | list [str ]) -> bool :
2004
2009
"""
2005
2010
Determine if an array of strings should be trimmed.
2006
2011
0 commit comments