Skip to content

Commit 398173f

Browse files
siavrez authored and glemaitre committed
MNT Accelerate plot_model_complexity_influence.py example using a subset of dataset for classification (scikit-learn#21742)
1 parent 72f4303 commit 398173f

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

examples/applications/plot_model_complexity_influence.py

Lines changed: 20 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -42,14 +42,13 @@
4242
import matplotlib.pyplot as plt
4343

4444
from sklearn import datasets
45-
from sklearn.utils import shuffle
45+
from sklearn.model_selection import train_test_split
4646
from sklearn.metrics import mean_squared_error
4747
from sklearn.svm import NuSVR
4848
from sklearn.ensemble import GradientBoostingRegressor
4949
from sklearn.linear_model import SGDClassifier
5050
from sklearn.metrics import hamming_loss
5151

52-
5352
# Initialize random generator
5453
np.random.seed(0)
5554

@@ -72,12 +71,14 @@ def generate_data(case):
7271
"""Generate regression/classification data."""
7372
if case == "regression":
7473
X, y = datasets.load_diabetes(return_X_y=True)
74+
train_size = 0.8
7575
elif case == "classification":
7676
X, y = datasets.fetch_20newsgroups_vectorized(subset="all", return_X_y=True)
77-
X, y = shuffle(X, y)
78-
offset = int(X.shape[0] * 0.8)
79-
X_train, y_train = X[:offset], y[:offset]
80-
X_test, y_test = X[offset:], y[offset:]
77+
train_size = 0.4 # to make the example run faster
78+
79+
X_train, X_test, y_train, y_test = train_test_split(
80+
X, y, train_size=train_size, random_state=0
81+
)
8182

8283
data = {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}
8384
return data
@@ -174,33 +175,37 @@ def _count_nonzero_coefficients(estimator):
174175
"prediction_performance_label": "Hamming Loss (Misclassification Ratio)",
175176
"postfit_hook": lambda x: x.sparsify(),
176177
"data": classification_data,
177-
"n_samples": 30,
178+
"n_samples": 5,
178179
},
179180
{
180181
"estimator": NuSVR,
181182
"tuned_params": {"C": 1e3, "gamma": 2 ** -15},
182183
"changing_param": "nu",
183-
"changing_param_values": [0.1, 0.25, 0.5, 0.75, 0.9],
184+
"changing_param_values": [0.05, 0.1, 0.2, 0.35, 0.5],
184185
"complexity_label": "n_support_vectors",
185186
"complexity_computer": lambda x: len(x.support_vectors_),
186187
"data": regression_data,
187188
"postfit_hook": lambda x: x,
188189
"prediction_performance_computer": mean_squared_error,
189190
"prediction_performance_label": "MSE",
190-
"n_samples": 30,
191+
"n_samples": 15,
191192
},
192193
{
193194
"estimator": GradientBoostingRegressor,
194-
"tuned_params": {"loss": "squared_error"},
195+
"tuned_params": {
196+
"loss": "squared_error",
197+
"learning_rate": 0.05,
198+
"max_depth": 2,
199+
},
195200
"changing_param": "n_estimators",
196-
"changing_param_values": [10, 50, 100, 200, 500],
201+
"changing_param_values": [10, 25, 50, 75, 100],
197202
"complexity_label": "n_trees",
198203
"complexity_computer": lambda x: x.n_estimators,
199204
"data": regression_data,
200205
"postfit_hook": lambda x: x,
201206
"prediction_performance_computer": mean_squared_error,
202207
"prediction_performance_label": "MSE",
203-
"n_samples": 30,
208+
"n_samples": 15,
204209
},
205210
]
206211

@@ -255,7 +260,9 @@ def plot_influence(conf, mse_values, prediction_times, complexities):
255260
ax2.yaxis.label.set_color(line2.get_color())
256261
ax2.tick_params(axis="y", colors=line2.get_color())
257262

258-
plt.legend((line1, line2), ("prediction error", "latency"), loc="upper right")
263+
plt.legend(
264+
(line1, line2), ("prediction error", "prediction latency"), loc="upper right"
265+
)
259266

260267
plt.title(
261268
"Influence of varying '%s' on %s"
@@ -268,7 +275,6 @@ def plot_influence(conf, mse_values, prediction_times, complexities):
268275
plot_influence(conf, prediction_performances, prediction_times, complexities)
269276
plt.show()
270277

271-
272278
##############################################################################
273279
# Conclusion
274280
# ----------

0 commit comments

Comments
 (0)