Skip to content

REF: make plotting less stateful #55837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from collections.abc import Collection

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

from pandas._typing import MatplotlibColor
Expand Down Expand Up @@ -177,7 +178,7 @@ def maybe_color_bp(self, bp) -> None:
if not self.kwds.get("capprops"):
setp(bp["caps"], color=caps, alpha=1)

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
if self.subplots:
self._return_obj = pd.Series(dtype=object)

Expand Down
50 changes: 26 additions & 24 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure

from pandas._typing import (
IndexLabel,
Expand Down Expand Up @@ -241,7 +242,8 @@ def __init__(
self.stacked = kwds.pop("stacked", False)

self.ax = ax
self.fig = fig
# TODO: deprecate fig keyword as it is ignored, not passed in tests
# as of 2023-11-05
self.axes = np.array([], dtype=object) # "real" version get set in `generate`

# parse errorbar input if given
Expand Down Expand Up @@ -449,11 +451,11 @@ def draw(self) -> None:
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
self._setup_subplots()
self._make_plot()
fig = self._setup_subplots()
self._make_plot(fig)
self._add_table()
self._make_legend()
self._adorn_subplots()
self._adorn_subplots(fig)

for ax in self.axes:
self._post_plot_logic_common(ax, self.data)
Expand Down Expand Up @@ -495,7 +497,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
new_ax.set_yscale("symlog")
return new_ax

def _setup_subplots(self):
def _setup_subplots(self) -> Figure:
if self.subplots:
naxes = (
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
Expand Down Expand Up @@ -538,8 +540,8 @@ def _setup_subplots(self):
elif self.logy == "sym" or self.loglog == "sym":
[a.set_yscale("symlog") for a in axes]

self.fig = fig
self.axes = axes
return fig

@property
def result(self):
Expand Down Expand Up @@ -637,7 +639,7 @@ def _compute_plot_data(self):

self.data = numeric_data.apply(self._convert_to_ndarray)

def _make_plot(self):
def _make_plot(self, fig: Figure):
raise AbstractMethodError(self)

def _add_table(self) -> None:
Expand Down Expand Up @@ -672,11 +674,11 @@ def _post_plot_logic_common(self, ax, data):
def _post_plot_logic(self, ax, data) -> None:
"""Post process for each axes. Overridden in child classes"""

def _adorn_subplots(self):
def _adorn_subplots(self, fig: Figure):
"""Common post process unrelated to data"""
if len(self.axes) > 0:
all_axes = self._get_subplots()
nrows, ncols = self._get_axes_layout()
all_axes = self._get_subplots(fig)
nrows, ncols = self._get_axes_layout(fig)
handle_shared_axes(
axarr=all_axes,
nplots=len(all_axes),
Expand Down Expand Up @@ -723,7 +725,7 @@ def _adorn_subplots(self):
for ax, title in zip(self.axes, self.title):
ax.set_title(title)
else:
self.fig.suptitle(self.title)
fig.suptitle(self.title)
else:
if is_list_like(self.title):
msg = (
Expand Down Expand Up @@ -1114,17 +1116,17 @@ def _get_errorbars(
errors[kw] = err
return errors

def _get_subplots(self):
def _get_subplots(self, fig: Figure):
from matplotlib.axes import Subplot

return [
ax
for ax in self.fig.get_axes()
for ax in fig.get_axes()
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
]

def _get_axes_layout(self) -> tuple[int, int]:
axes = self._get_subplots()
def _get_axes_layout(self, fig: Figure) -> tuple[int, int]:
axes = self._get_subplots(fig)
x_set = set()
y_set = set()
for ax in axes:
Expand Down Expand Up @@ -1172,7 +1174,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

def _plot_colorbar(self, ax: Axes, **kwds):
def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds):
# Addresses issues #10611 and #10678:
# When plotting scatterplots and hexbinplots in IPython
# inline backend the colorbar axis height tends not to
Expand All @@ -1189,7 +1191,7 @@ def _plot_colorbar(self, ax: Axes, **kwds):
# use the last one which contains the latest information
# about the ax
img = ax.collections[-1]
return self.fig.colorbar(img, ax=ax, **kwds)
return fig.colorbar(img, ax=ax, **kwds)


class ScatterPlot(PlanePlot):
Expand All @@ -1209,7 +1211,7 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
c = self.data.columns[c]
self.c = c

def _make_plot(self):
def _make_plot(self, fig: Figure):
x, y, c, data = self.x, self.y, self.c, self.data
ax = self.axes[0]

Expand Down Expand Up @@ -1274,7 +1276,7 @@ def _make_plot(self):
)
if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, label=cbar_label)
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
if color_by_categorical:
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
cbar.ax.set_yticklabels(self.data[c].cat.categories)
Expand Down Expand Up @@ -1306,7 +1308,7 @@ def __init__(self, data, x, y, C=None, **kwargs) -> None:
C = self.data.columns[C]
self.C = C

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
x, y, data, C = self.x, self.y, self.data, self.C
ax = self.axes[0]
# pandas uses colormap, matplotlib uses cmap.
Expand All @@ -1321,7 +1323,7 @@ def _make_plot(self) -> None:

ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds)
if cb:
self._plot_colorbar(ax)
self._plot_colorbar(ax, fig=fig)

def _make_legend(self) -> None:
pass
Expand Down Expand Up @@ -1358,7 +1360,7 @@ def _is_ts_plot(self) -> bool:
def _use_dynamic_x(self):
return use_dynamic_x(self._get_ax(0), self.data)

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
if self._is_ts_plot():
data = maybe_convert_index(self._get_ax(0), self.data)

Expand Down Expand Up @@ -1680,7 +1682,7 @@ def _plot( # type: ignore[override]
def _start_base(self):
return self.bottom

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors()
ncolors = len(colors)

Expand Down Expand Up @@ -1842,7 +1844,7 @@ def _args_adjust(self) -> None:
def _validate_color_args(self) -> None:
pass

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
self.kwds.setdefault("colors", colors)

Expand Down
3 changes: 2 additions & 1 deletion pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from pandas._typing import PlottingOrientation

Expand Down Expand Up @@ -113,7 +114,7 @@ def _plot( # type: ignore[override]
cls._update_stacker(ax, stacking_id, n)
return patches

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors()
stacking_id = self._get_stacking_id()

Expand Down