Skip to content

PERF: Updated andrews_curves to use Numpy arrays for its samples #11534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions asv_bench/benchmarks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
try:
from pandas import date_range
except ImportError:

def date_range(start=None, end=None, periods=None, freq=None):
return DatetimeIndex(start, end, periods=periods, offset=freq)
from pandas.tools.plotting import andrews_curves


class plot_timeseries_period(object):
Expand All @@ -16,4 +16,17 @@ def setup(self):
self.df = DataFrame(np.random.randn(self.N, self.M), index=date_range('1/1/1975', periods=self.N))

def time_plot_timeseries_period(self):
self.df.plot()
self.df.plot()

class plot_andrews_curves(object):
goal_time = 0.6

def setup(self):
self.N = 500
self.M = 10
data_dict = {x: np.random.randn(self.N) for x in range(self.M)}
data_dict["Name"] = ["A"] * self.N
self.df = DataFrame(data_dict)

def time_plot_andrews_curves(self):
andrews_curves(self.df, "Name")
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.18.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Removal of prior version deprecations/changes
Performance Improvements
~~~~~~~~~~~~~~~~~~~~~~~~


- Improved performance of ``andrews_curves`` (:issue:`11534`)



Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/test_graphics_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,26 @@ def test_andrews_curves(self):
cmaps = lmap(cm.jet, np.linspace(0, 1, df['Name'].nunique()))
self._check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df['Name'][:10])

length = 10
df = DataFrame({"A": random.rand(length),
"B": random.rand(length),
"C": random.rand(length),
"Name": ["A"] * length})

_check_plot_works(andrews_curves, frame=df, class_column='Name')

rgba = ('#556270', '#4ECDC4', '#C7F464')
ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', color=rgba)
self._check_colors(ax.get_lines()[:10], linecolors=rgba, mapping=df['Name'][:10])

cnames = ['dodgerblue', 'aquamarine', 'seagreen']
ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', color=cnames)
self._check_colors(ax.get_lines()[:10], linecolors=cnames, mapping=df['Name'][:10])

ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', colormap=cm.jet)
cmaps = lmap(cm.jet, np.linspace(0, 1, df['Name'].nunique()))
self._check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df['Name'][:10])

colors = ['b', 'g', 'r']
df = DataFrame({"A": [1, 2, 3],
"B": [1, 2, 3],
Expand Down
41 changes: 28 additions & 13 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,15 @@ def normalize(series):
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
colormap=None, **kwds):
"""
Generates a matplotlib plot of Andrews curves, for visualising clusters of multivariate data.

Andrews curves have the functional form:

f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) + x_4 sin(2t) + x_5 cos(2t) + ...

Where x coefficients correspond to the values of each dimension and t is linearly spaced between -pi and +pi. Each
row of frame then corresponds to a single curve.

Parameters:
-----------
frame : DataFrame
Expand All @@ -527,28 +536,34 @@ def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
ax: Matplotlib axis object

"""
from math import sqrt, pi, sin, cos
from math import sqrt, pi
import matplotlib.pyplot as plt

def function(amplitudes):
def f(x):
def f(t):
x1 = amplitudes[0]
result = x1 / sqrt(2.0)
harmonic = 1.0
for x_even, x_odd in zip(amplitudes[1::2], amplitudes[2::2]):
result += (x_even * sin(harmonic * x) +
x_odd * cos(harmonic * x))
harmonic += 1.0
if len(amplitudes) % 2 != 0:
result += amplitudes[-1] * sin(harmonic * x)

# Take the rest of the coefficients and resize them appropriately. Take a copy of amplitudes as otherwise
# numpy deletes the element from amplitudes itself.
coeffs = np.delete(np.copy(amplitudes), 0)
coeffs.resize((coeffs.size + 1) / 2, 2)

# Generate the harmonics and arguments for the sin and cos functions.
harmonics = np.arange(0, coeffs.shape[0]) + 1
trig_args = np.outer(harmonics, t)

result += np.sum(coeffs[:, 0, np.newaxis] * np.sin(trig_args) +
coeffs[:, 1, np.newaxis] * np.cos(trig_args),
axis=0)
return result
return f

n = len(frame)
class_col = frame[class_column]
classes = frame[class_column].drop_duplicates()
df = frame.drop(class_column, axis=1)
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
t = np.linspace(-pi, pi, samples)
used_legends = set([])

color_values = _get_standard_colors(num_colors=len(classes),
Expand All @@ -560,14 +575,14 @@ def f(x):
for i in range(n):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the way to get a massive speedup, is do this ALL in a vectorized way, e.g.
pass in the entire numpy array (df.values), and do the calculation, then plot in a loop

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, of course. I'll do that and the asv benchmark. Will probably be over the weekend.

row = df.iloc[i].values
f = function(row)
y = [f(t) for t in x]
y = f(t)
kls = class_col.iat[i]
label = com.pprint_thing(kls)
if label not in used_legends:
used_legends.add(label)
ax.plot(x, y, color=colors[kls], label=label, **kwds)
ax.plot(t, y, color=colors[kls], label=label, **kwds)
else:
ax.plot(x, y, color=colors[kls], **kwds)
ax.plot(t, y, color=colors[kls], **kwds)

ax.legend(loc='upper right')
ax.grid()
Expand Down