Skip to content

Commit eae2fec

Browse files
committed
Keep dtype whenever possible; add _update_array; docstring fixes
1 parent e2edd1a commit eae2fec

File tree

5 files changed

+120
-29
lines changed

5 files changed

+120
-29
lines changed

pandas/core/generic.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4177,7 +4177,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
41774177
"""
41784178
Modify in place using non-NA values from another DataFrame.
41794179
4180-
Aligns on indices. There is no return value.
4180+
Series/DataFrame will be aligned on indexes, and whenever possible,
4181+
the dtype of the individual Series of the caller will be preserved.
4182+
4183+
There is no return value.
41814184
41824185
Parameters
41834186
----------
@@ -4197,7 +4200,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
41974200
* False: only update values that are NA in
41984201
the original DataFrame.
41994202
4200-
filter_func : callable(1d-array) -> boolean 1d-array, optional
4203+
filter_func : callable(1d-array) -> bool 1d-array, optional
42014204
Can choose to replace values other than NA. Return True for values
42024205
that should be updated.
42034206
errors : {'raise', 'ignore'}, default 'ignore'
@@ -4207,7 +4210,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42074210
Raises
42084211
------
42094212
ValueError
4210-
When `raise_conflict` is True and there's overlapping non-NA data.
4213+
When `errors='ignore'` and there's overlapping non-NA data.
42114214
42124215
Returns
42134216
-------
@@ -4274,10 +4277,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42744277
>>> new_df = pd.DataFrame({'B': [4, np.nan, 6]})
42754278
>>> df.update(new_df)
42764279
>>> df
4277-
A B
4278-
0 1 4.0
4279-
1 2 500.0
4280-
2 3 6.0
4280+
A B
4281+
0 1 4
4282+
1 2 500
4283+
2 3 6
42814284
"""
42824285
from pandas import Series, DataFrame
42834286
# TODO: Support other joins
@@ -4291,14 +4294,20 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
42914294
this = self.values
42924295
that = other.values
42934296

4294-
# missing.update_array returns an np.ndarray
4295-
updated_values = missing.update_array(this, that,
4297+
# will return None if "this" remains unchanged
4298+
updated_array = missing._update_array(this, that,
42964299
overwrite=overwrite,
42974300
filter_func=filter_func,
42984301
errors=errors)
42994302
# don't overwrite unnecessarily
4300-
if updated_values is not None:
4301-
self._update_inplace(Series(updated_values, index=self.index))
4303+
if updated_array is not None:
4304+
# avoid unnecessary upcasting (introduced by alignment)
4305+
try:
4306+
updated = Series(updated_array, index=self.index,
4307+
dtype=this.dtype)
4308+
except ValueError:
4309+
updated = Series(updated_array, index=self.index)
4310+
self._update_inplace(updated)
43024311
else: # DataFrame
43034312
if not isinstance(other, ABCDataFrame):
43044313
other = DataFrame(other)
@@ -4309,11 +4318,23 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
43094318
this = self[col].values
43104319
that = other[col].values
43114320

4312-
updated = missing.update_array(this, that, overwrite=overwrite,
4313-
filter_func=filter_func,
4314-
errors=errors)
4321+
# will return None if "this" remains unchanged
4322+
updated_array = missing._update_array(this, that,
4323+
overwrite=overwrite,
4324+
filter_func=filter_func,
4325+
errors=errors)
43154326
# don't overwrite unnecessarily
4316-
if updated is not None:
4327+
if updated_array is not None:
4328+
# no problem to set DataFrame column with array
4329+
updated = updated_array
4330+
4331+
if updated_array.dtype != this.dtype:
4332+
# avoid unnecessary upcasting (introduced by alignment)
4333+
try:
4334+
updated = Series(updated_array, index=self.index,
4335+
dtype=this.dtype)
4336+
except ValueError:
4337+
pass
43174338
self[col] = updated
43184339

43194340
def filter(self, items=None, like=None, regex=None, axis=None):

pandas/core/missing.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,25 @@ def update_array(this, that, overwrite=True, filter_func=None,
106106
107107
Returns
108108
-------
109-
updated : np.ndarray (one-dimensional) or None
110-
The updated array. Return None if `this` remains unchanged
109+
updated : np.ndarray (one-dimensional)
110+
The updated array.
111111
112112
See Also
113113
--------
114114
Series.update : Similar method for `Series`.
115115
DataFrame.update : Similar method for `DataFrame`.
116116
dict.update : Similar method for `dict`.
117117
"""
118+
updated = _update_array(this, that, overwrite=overwrite,
119+
filter_func=filter_func, errors=errors)
120+
return this if updated is None else updated
121+
122+
123+
def _update_array(this, that, overwrite=True, filter_func=None,
124+
errors='ignore'):
125+
"""
126+
Same as update_array, except we return None if `this` is not updated.
127+
"""
118128
import pandas.core.computation.expressions as expressions
119129

120130
if filter_func is not None:

pandas/core/series.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,7 +2396,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
23962396
"""
23972397
Modify Series in place using non-NA values from passed Series.
23982398
2399-
Aligns on index.
2399+
Series will be aligned on indexes, and whenever possible, the dtype of
2400+
the caller will be preserved.
2401+
2402+
There is no return value.
24002403
24012404
Parameters
24022405
----------
@@ -2417,7 +2420,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24172420
the original DataFrame.
24182421
24192422
.. versionadded:: 0.24.0
2420-
filter_func : callable(1d-array) -> boolean 1d-array, optional
2423+
filter_func : callable(1d-array) -> bool 1d-array, optional
24212424
Can choose to replace values other than NA. Return True for values
24222425
that should be updated.
24232426
@@ -2428,10 +2431,19 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24282431
24292432
.. versionadded:: 0.24.0
24302433
2434+
Raises
2435+
------
2436+
ValueError
2437+
When `errors='ignore'` and there's overlapping non-NA data.
2438+
2439+
Returns
2440+
-------
2441+
Nothing, the Series is modified inplace.
2442+
24312443
See Also
24322444
--------
24332445
DataFrame.update : Similar method for `DataFrame`.
2434-
dict.update : Similar method for `dict`
2446+
dict.update : Similar method for `dict`.
24352447
24362448
Examples
24372449
--------
@@ -2465,10 +2477,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
24652477
>>> s = pd.Series([1, 2, 3])
24662478
>>> s.update(pd.Series([4, np.nan, 6]))
24672479
>>> s
2468-
0 4.0
2469-
1 2.0
2470-
2 6.0
2471-
dtype: float64
2480+
0 4
2481+
1 2
2482+
2 6
2483+
dtype: int64
24722484
"""
24732485
super(Series, self).update(other, join=join, overwrite=overwrite,
24742486
filter_func=filter_func,

pandas/tests/frame/test_combine_concat.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,25 @@ def test_update_dtypes(self):
279279
columns=['A', 'B', 'bool1', 'bool2'])
280280
assert_frame_equal(df, expected)
281281

282+
df = DataFrame([[10, 100], [11, 101], [12, 102]], columns=['A', 'B'])
283+
other = DataFrame([[61, 601], [63, 603]], columns=['A', 'B'],
284+
index=[1, 3])
285+
df.update(other)
286+
287+
expected = DataFrame([[10, 100], [61, 601], [12, 102]],
288+
columns=['A', 'B'])
289+
assert_frame_equal(df, expected)
290+
291+
# we always try to keep original dtype, even if other has different one
292+
df.update(other.astype(float))
293+
assert_frame_equal(df, expected)
294+
295+
# if keeping the dtype is not possible, we allow upcasting
296+
df.update(other + 0.1)
297+
expected = DataFrame([[10., 100.], [61.1, 601.1], [12., 102.]],
298+
columns=['A', 'B'])
299+
assert_frame_equal(df, expected)
300+
282301
def test_update_nooverwrite(self):
283302
df = DataFrame([[1.5, nan, 3.],
284303
[1.5, nan, 3.],

pandas/tests/series/test_combine_concat.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
from pandas import DataFrame, DatetimeIndex, Series, compat, date_range
1212
import pandas.util.testing as tm
13-
from pandas.util.testing import assert_series_equal
13+
from pandas.util.testing import assert_series_equal, assert_frame_equal
1414

1515

1616
class TestSeriesCombine():
@@ -105,8 +105,8 @@ def test_combine_first(self):
105105
assert_series_equal(s, result)
106106

107107
def test_update(self):
108-
s = Series([1.5, nan, 3., 4., nan])
109-
s2 = Series([nan, 3.5, nan, 5.])
108+
s = Series([1.5, np.nan, 3., 4., np.nan])
109+
s2 = Series([np.nan, 3.5, np.nan, 5.])
110110
s.update(s2)
111111

112112
expected = Series([1.5, 3.5, 3., 5., np.nan])
@@ -116,8 +116,35 @@ def test_update(self):
116116
df = DataFrame([{"a": 1}, {"a": 3, "b": 2}])
117117
df['c'] = np.nan
118118

119-
# this will fail as long as series is a sub-class of ndarray
120-
# df['c'].update(Series(['foo'],index=[0])) #####
119+
df['c'].update(Series(['foo'], index=[0]))
120+
expected = DataFrame([[1, np.nan, 'foo'], [3, 2., np.nan]],
121+
columns=['a', 'b', 'c'])
122+
assert_frame_equal(df, expected)
123+
124+
def test_update_dtypes(self):
125+
s = Series([1., 2., False, True])
126+
127+
other = Series([45])
128+
s.update(other)
129+
130+
expected = Series([45., 2., False, True])
131+
assert_series_equal(s, expected)
132+
133+
s = Series([10, 11, 12])
134+
other = Series([61, 63], index=[1, 3])
135+
s.update(other)
136+
137+
expected = Series([10, 61, 12])
138+
assert_series_equal(s, expected)
139+
140+
# we always try to keep original dtype, even if other has different one
141+
s.update(other.astype(float))
142+
assert_series_equal(s, expected)
143+
144+
# if keeping the dtype is not possible, we allow upcasting
145+
s.update(other + 0.1)
146+
expected = Series([10., 61.1, 12.])
147+
assert_series_equal(s, expected)
121148

122149
def test_update_nooverwrite(self):
123150
s = Series([0, 1, 2, np.nan, np.nan, 5, 6, np.nan])
@@ -129,6 +156,8 @@ def test_update_nooverwrite(self):
129156
assert_series_equal(s, expected)
130157

131158
def test_update_filtered(self):
159+
# for small values, np.arange defaults to int32,
160+
# but pandas default (e.g. for "expected" below) is int64
132161
s = Series(np.arange(8), dtype='int64')
133162
other = Series(np.arange(8), dtype='int64') + 10
134163

0 commit comments

Comments
 (0)