Commit 5fa66df

[integration] Update Ray Tune integration for Ray 2.7 (#26499)
* fix tune integration for ray 2.7+
* add version check for ray tune backend availability
* missing import
* pin min version instead
* address comments
* some fixes
* fix unnecessary final checkpoint
* fix lint
* dep table fix
* fix lint

---------

Signed-off-by: Justin Yu <[email protected]>
1 parent ffd426e commit 5fa66df
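
For readers following along, the substance of this commit is a migration from the pre-2.7 `ray.tune` session API (`tune.report`, `tune.checkpoint_dir`, `tune.get_trial_id`) to the `ray.train` API (`ray.train.report`, `ray.train.Checkpoint`, `ray.train.get_context`). Below is a minimal sketch of the reporting pattern the diff adopts, assuming it runs inside a Ray Tune trial; `report_with_checkpoint` and `write_state` are illustrative names, not part of this commit.

import tempfile

import ray.train


def report_with_checkpoint(metrics: dict, write_state) -> None:
    # Ray 2.7+ pattern: stage the checkpoint in a temporary directory,
    # wrap it in a Checkpoint object, and attach it to the metrics report.
    # The pre-2.7 equivalent was `tune.checkpoint_dir(step=...)` plus
    # `tune.report(**metrics)`.
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        write_state(temp_checkpoint_dir)  # e.g. what Trainer._tune_save_checkpoint does
        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        ray.train.report(metrics, checkpoint=checkpoint)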

File tree

5 files changed (+52, -54 lines)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ray[tune]",
+    "ray[tune]>=2.7.0",
     "regex!=2019.12.17",
     "requests",
     "rhoknp>=1.1.0,<1.3.1",

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ray[tune]": "ray[tune]",
+    "ray[tune]": "ray[tune]>=2.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "rhoknp": "rhoknp>=1.1.0,<1.3.1",

src/transformers/hyperparameter_search.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@

 from .integrations import (
     is_optuna_available,
-    is_ray_available,
+    is_ray_tune_available,
     is_sigopt_available,
     is_wandb_available,
     run_hp_search_optuna,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):

     @staticmethod
     def is_available():
-        return is_ray_available()
+        return is_ray_tune_available()

     def run(self, trainer, n_trials: int, direction: str, **kwargs):
         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
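
The user-facing entry point is unchanged; only the availability gate behind `backend="ray"` now goes through `is_ray_tune_available()`, paired with the `ray[tune]>=2.7.0` dependency pin. A hedged usage sketch follows; `ray_hp_space` and `search` are illustrative helpers, and the `Trainer` instance is assumed to be built elsewhere.

from ray import tune

from transformers import Trainer


def ray_hp_space(trial):
    # Any Ray Tune search space works here; this one is illustrative.
    return {"learning_rate": tune.loguniform(1e-6, 1e-4)}


def search(trainer: Trainer):
    # `trainer` must be constructed with `model_init=...` (not `model=...`)
    # so each trial can re-instantiate the model.
    return trainer.hyperparameter_search(
        hp_space=ray_hp_space,
        backend="ray",  # gated by RayTuneBackend.is_available(); this commit pins ray[tune]>=2.7.0
        n_trials=4,
        direction="minimize",
    )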

src/transformers/integrations/integration_utils.py

Lines changed: 26 additions & 28 deletions
@@ -236,8 +236,9 @@ def _objective(trial, checkpoint_dir=None):

 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train

-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback

@@ -246,19 +247,34 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         except ModuleNotFoundError:
             pass

-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)
+
         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                ray.train.report(metrics, checkpoint=checkpoint)

     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
@@ -296,28 +312,10 @@ def _objective(trial, local_trainer, checkpoint_dir=None):
         from ray.tune import CLIReporter

         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )
+
     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
-
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
             kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
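
The new `_objective` follows the standard Ray 2.7 resume/report cycle: ask the session for the latest checkpoint, materialize it to a local directory if one exists, then report metrics together with a fresh directory-backed checkpoint. Below is a stripped-down standalone illustration of the same cycle; the `step.txt` marker file is a hypothetical stand-in for the Trainer's `checkpoint-*` directories.

import tempfile
from pathlib import Path

import ray.train
from ray import tune


def trainable(config: dict) -> None:
    # Resume: Ray returns the last reported checkpoint, or None on a fresh start.
    start_step = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # The real _objective globs for a "checkpoint-*" subdir here and passes
            # it to local_trainer.train(resume_from_checkpoint=...).
            start_step = int((Path(checkpoint_dir) / "step.txt").read_text())

    for step in range(start_step, 3):
        metrics = {"objective": config["x"] * (step + 1)}
        # Report: stage state in a temp dir, wrap it in a Checkpoint, attach it.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            (Path(temp_checkpoint_dir) / "step.txt").write_text(str(step + 1))
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)


tune.Tuner(trainable, param_space={"x": tune.uniform(0.0, 1.0)}).fit()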

src/transformers/trainer.py

Lines changed: 22 additions & 22 deletions
@@ -28,6 +28,7 @@
 import re
 import shutil
 import sys
+import tempfile
 import time
 import warnings
 from collections.abc import Mapping
@@ -595,7 +596,6 @@ def __init__(
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
         self.hp_search_backend = None
-        self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
             return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna

@@ -1211,24 +1212,23 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
             self.callback_handler.on_train_end(self.args, self.state, self.control)
             raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
-
-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
-
-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir, _internal_call=True)
-            if self.args.should_save:
-                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

     def call_model_init(self, trial=None):
         model_init_argcount = number_of_arguments(self.model_init)
@@ -2004,9 +2004,9 @@ def _get_output_dir(self, trial):
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             run_id = trial.number
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train

-            run_id = tune.get_trial_id()
+            run_id = ray.train.get_context().get_trial_id()
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             run_id = trial.id
         elif self.hp_search_backend == HPSearchBackend.WANDB:
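
On the Trainer side, `_get_output_dir` now derives the trial id from the per-worker context object instead of the removed `tune.get_trial_id()`. A small sketch of that lookup, modeled on `_get_output_dir`; the `trial_output_dir` helper and `run-{id}` layout are illustrative, and the call is only meaningful inside a running Ray Tune trial.

import os

import ray.train


def trial_output_dir(base_dir: str) -> str:
    # Ray 2.7+: the train context exposes the trial id; pre-2.7 code called
    # ray.tune.get_trial_id() instead. Only meaningful inside a running trial.
    run_id = ray.train.get_context().get_trial_id()
    return os.path.join(base_dir, f"run-{run_id}")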
