@@ -236,8 +236,9 @@ def _objective(trial, checkpoint_dir=None):
 
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train
 
-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback
 
@@ -246,19 +247,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         except ModuleNotFoundError:
             pass
 
-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)
+
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                ray.train.report(metrics, checkpoint=checkpoint)
 
     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         from ray.tune import CLIReporter
 
         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )
+
     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
 
-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
-
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
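For context on the API migration: the diff replaces the legacy `ray.tune.report(...)` call and the `checkpoint_dir` trial-function argument with the `ray.train` session API from Ray 2.7, where a trial reports metrics together with a `Checkpoint` object and restores state via `ray.train.get_checkpoint()`. Below is a minimal sketch of that save/restore round trip outside of `Trainer`, assuming Ray >= 2.7; the `train_fn` name, the `step.txt` state file, and the `objective` metric are illustrative, not part of this PR.

    import tempfile
    from pathlib import Path

    import ray.train
    import ray.tune


    def train_fn(config):
        # Illustrative sketch, not this PR's code. On resume, Ray Train hands
        # back the checkpoint attached to the last report.
        checkpoint = ray.train.get_checkpoint()
        start = 0
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                start = int(Path(ckpt_dir, "step.txt").read_text()) + 1

        for step in range(start, 3):
            metrics = {"objective": float(step)}
            # Write state to a temp dir; report() persists the checkpoint along
            # with the metrics before returning, so the local copy can be
            # deleted as soon as the context manager exits.
            with tempfile.TemporaryDirectory() as tmp_dir:
                Path(tmp_dir, "step.txt").write_text(str(step))
                ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmp_dir))


    ray.tune.Tuner(train_fn, tune_config=ray.tune.TuneConfig(num_samples=2)).fit()

The temporary-directory pattern mirrors the second hunk above. The removed `keep_checkpoints_num` plumbing has no replacement inside this function; with the new API, checkpoint retention is configured on the Tune side instead (e.g. `ray.train.CheckpointConfig(num_to_keep=1)` passed via a `RunConfig`), which is presumably why the related warnings were dropped.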