 import matplotlib.pyplot as plt

 from sklearn import datasets
-from sklearn.utils import shuffle
+from sklearn.model_selection import train_test_split
 from sklearn.metrics import mean_squared_error
 from sklearn.svm import NuSVR
 from sklearn.ensemble import GradientBoostingRegressor
 from sklearn.linear_model import SGDClassifier
 from sklearn.metrics import hamming_loss

-
 # Initialize random generator
 np.random.seed(0)

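Two of the imports kept above are worth a gloss: sparsify() converts a fitted linear model's coef_ to a scipy sparse matrix, which speeds up prediction when most coefficients are zero, and hamming_loss is the misclassification ratio this example reports for the classifier. A minimal sketch on made-up toy data (not part of the patch):

    # Sketch only: sparsify() and hamming_loss on toy data, not part of the patch.
    import numpy as np
    from sklearn.linear_model import SGDClassifier
    from sklearn.metrics import hamming_loss

    rng = np.random.RandomState(0)
    X = rng.rand(100, 20)
    y = (X[:, 0] > 0.5).astype(int)

    clf = SGDClassifier(penalty="elasticnet", random_state=0).fit(X, y)
    clf.sparsify()  # store coef_ as scipy.sparse; predictions are unchanged
    print(hamming_loss(y, clf.predict(X)))  # fraction of misclassified samples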
@@ -72,12 +71,14 @@ def generate_data(case):
     """Generate regression/classification data."""
     if case == "regression":
         X, y = datasets.load_diabetes(return_X_y=True)
+        train_size = 0.8
     elif case == "classification":
         X, y = datasets.fetch_20newsgroups_vectorized(subset="all", return_X_y=True)
-    X, y = shuffle(X, y)
-    offset = int(X.shape[0] * 0.8)
-    X_train, y_train = X[:offset], y[:offset]
-    X_test, y_test = X[offset:], y[offset:]
+        train_size = 0.4  # to make the example run faster
+
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, train_size=train_size, random_state=0
+    )

     data = {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}
     return data
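For reference, the adopted train_test_split call in isolation: it shuffles by default, and random_state=0 makes the split reproducible, which is what the removed shuffle/offset slicing did by hand. A sketch on made-up data:

    # Sketch only: train_test_split replacing the manual shuffle + slicing.
    import numpy as np
    from sklearn.model_selection import train_test_split

    X = np.arange(20).reshape(10, 2)
    y = np.arange(10)

    # shuffle=True is the default; train_size=0.4 keeps 4 of 10 samples for training.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.4, random_state=0
    )
    print(X_train.shape, X_test.shape)  # (4, 2) (6, 2)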
@@ -174,33 +175,37 @@ def _count_nonzero_coefficients(estimator):
         "prediction_performance_label": "Hamming Loss (Misclassification Ratio)",
         "postfit_hook": lambda x: x.sparsify(),
         "data": classification_data,
-        "n_samples": 30,
+        "n_samples": 5,
     },
     {
         "estimator": NuSVR,
         "tuned_params": {"C": 1e3, "gamma": 2**-15},
         "changing_param": "nu",
-        "changing_param_values": [0.1, 0.25, 0.5, 0.75, 0.9],
+        "changing_param_values": [0.05, 0.1, 0.2, 0.35, 0.5],
         "complexity_label": "n_support_vectors",
         "complexity_computer": lambda x: len(x.support_vectors_),
         "data": regression_data,
         "postfit_hook": lambda x: x,
         "prediction_performance_computer": mean_squared_error,
         "prediction_performance_label": "MSE",
-        "n_samples": 30,
+        "n_samples": 15,
     },
     {
         "estimator": GradientBoostingRegressor,
-        "tuned_params": {"loss": "squared_error"},
+        "tuned_params": {
+            "loss": "squared_error",
+            "learning_rate": 0.05,
+            "max_depth": 2,
+        },
         "changing_param": "n_estimators",
-        "changing_param_values": [10, 50, 100, 200, 500],
+        "changing_param_values": [10, 25, 50, 75, 100],
         "complexity_label": "n_trees",
         "complexity_computer": lambda x: x.n_estimators,
         "data": regression_data,
         "postfit_hook": lambda x: x,
         "prediction_performance_computer": mean_squared_error,
         "prediction_performance_label": "MSE",
-        "n_samples": 30,
+        "n_samples": 15,
     },
 ]
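Each dict above is consumed by the example's benchmark_influence() (outside this hunk). Roughly, every entry drives a loop like the sketch below; the helper name run_configuration is hypothetical, and the prediction-latency timing the real function performs is elided:

    # Sketch only: how one `configurations` entry is consumed. The helper name
    # is hypothetical; the real benchmark_influence() also times prediction.
    def run_configuration(conf):
        results = []
        for value in conf["changing_param_values"]:
            # Merge the fixed hyperparameters with the one being varied.
            params = {**conf["tuned_params"], conf["changing_param"]: value}
            estimator = conf["estimator"](**params)
            estimator.fit(conf["data"]["X_train"], conf["data"]["y_train"])
            estimator = conf["postfit_hook"](estimator)  # e.g. x.sparsify()
            y_pred = estimator.predict(conf["data"]["X_test"])
            score = conf["prediction_performance_computer"](
                conf["data"]["y_test"], y_pred
            )
            results.append((conf["complexity_computer"](estimator), score))
        return results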
@@ -255,7 +260,9 @@ def plot_influence(conf, mse_values, prediction_times, complexities):
     ax2.yaxis.label.set_color(line2.get_color())
     ax2.tick_params(axis="y", colors=line2.get_color())

-    plt.legend((line1, line2), ("prediction error", "latency"), loc="upper right")
+    plt.legend(
+        (line1, line2), ("prediction error", "prediction latency"), loc="upper right"
+    )

     plt.title(
         "Influence of varying '%s' on %s"
@@ -268,7 +275,6 @@ def plot_influence(conf, mse_values, prediction_times, complexities):
     plot_influence(conf, prediction_performances, prediction_times, complexities)
 plt.show()

-
 ##############################################################################
 # Conclusion
 # ----------