[refactor] Store and restore state along with checkpoints #4025
Conversation
ml-agents/mlagents/trainers/stats.py
Outdated
@@ -304,6 +343,39 @@ def __init__(self, category: str):
    def add_writer(writer: StatsWriter) -> None:
        StatsReporter.writers.append(writer)

    @staticmethod
StatsReporter doesn't seem like it's the right place to be saving and loading global state. Can you make a new class for this? TrainingGlobalState?
The main reason I put it in StatsReporter was that the trainers already have one initialized with the right key (the brain_name). No problem, it can be moved.
ml-agents/mlagents/trainers/learn.py
Outdated
@@ -82,6 +82,9 @@ def run_training(run_seed: int, options: RunOptions) -> None:
    )
    # Make run logs directory
    os.makedirs(run_logs_dir, exist_ok=True)
    # Load any needed states
    if checkpoint_settings.resume:
        StatsReporter.load_state(os.path.join(run_logs_dir, "training_status.json"))
Move "training_status.json" to a constant?
I've moved it to a constant and the method to a helper method similar to the timing tree and configuration file writes.
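A sketch of what such a helper might look like, mirroring the timing-tree and run-options write helpers the reply mentions (the constant and function names here are assumptions, not the actual code):

```python
import json
import os

# Assumed constant name, per the suggestion to avoid a hard-coded string.
TRAINING_STATUS_FILE_NAME = "training_status.json"


def write_training_status(output_dir: str, saved_state: dict) -> None:
    # Write the accumulated training state next to the other run logs,
    # in the same spirit as write_timing_tree / write_run_options.
    with open(os.path.join(output_dir, TRAINING_STATUS_FILE_NAME), "w") as f:
        json.dump(saved_state, f, indent=4)
```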
for brain_name, curriculum in self.brains_to_curricula.items():
    # Create a temporary StatsReporter with the right brain name
    _statsreporter = StatsReporter(brain_name)
It's unclear what's going on here; it feels really hacky (and probably brittle).
It's less of a concern now since we're no longer using StatsReporters, but we're still using the brain_name to refer to meta curriculums.
ml-agents/mlagents/trainers/stats.py
Outdated
    # Update saved state.
    StatsReporter.saved_state.update(loaded_dict)
except FileNotFoundError:
    pass
Should there be a warning here, or is it expected that this won't be there most of the time?
Added a warning. This should only happen if the user is loading from an older version of ML-Agents that did not save out the training_status.json file, or (as Vince mentioned below) the file was not saved out due to a crash.
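A minimal sketch of the load path with the warning discussed above, assuming the function and message wording (both are illustrative, not the actual implementation):

```python
import json
import logging

logger = logging.getLogger(__name__)


def load_state(path: str) -> dict:
    # Return the saved state, warning rather than raising if the file is
    # absent, e.g. when resuming a run made with an older ML-Agents version
    # or after a crash that prevented the file from being written.
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        logger.warning(
            f"Training status file {path} not found. Not all functions will resume."
        )
        return {}
```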
Please don't overload StatsReporter for this; I don't think it's the right place to stash the data.
ml-agents/mlagents/trainers/learn.py
Outdated
@@ -159,6 +162,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
    env_manager.close()
    write_run_options(write_path, options)
    write_timing_tree(run_logs_dir)
    StatsReporter.save_state(os.path.join(run_logs_dir, "training_status.json"))
I don't think this is the best place to store this global state. I would put it within the save model calls.
Hmm, that makes sense, since we'd want to resume in the event of a crash as well.
However, the save model calls will be moved into the trainers, since it seems logical that each trainer drives its own checkpointing (at possibly different frequencies). We wouldn't want the global save to be there as well. We could have each trainer manage its own save state, or have every trainer just trigger the save_state function.
An alternative would be to write to the JSON every time a new state is written. For values like the lesson number (which is written very infrequently) this is OK.
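The write-on-every-update alternative could look like the following sketch (class and method names are assumptions for illustration; this is not the PR's code):

```python
import json
from typing import Any, Dict


class EagerStatus:
    """Illustrative sketch: persist the JSON on every write. This is cheap
    for infrequently updated values such as the curriculum lesson number."""

    saved_state: Dict[str, Dict[str, Any]] = {}

    def __init__(self, path: str):
        self.path = path

    def set_parameter_state(self, category: str, key: str, value: Any) -> None:
        # Update the in-memory state, then immediately flush it to disk so a
        # crash never loses more than the value currently being written.
        EagerStatus.saved_state.setdefault(category, {})[key] = value
        with open(self.path, "w") as f:
            json.dump(EagerStatus.saved_state, f, indent=4)
```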
Co-authored-by: Chris Elion <[email protected]>
…ml-agents into develop-lessonresume
Add warning if file not found
type of status needed to be saved (e.g. Lesson Number). Finally the Value is the float value
attached to this stat.
"""
self.category: str = category
This still feels awkward having a mix of instance and static data. Why not drop self.category and add an extra argument to restore_parameter_state and store_parameter_state?
I went this way to keep it in line with the StatsReporter, and so that if a trainer uses it multiple times it doesn't need to keep passing in the category.
Don't have a strong preference since it's used much less frequently than StatsReporter; I'm OK with changing it to a parameter to the class methods. We could also then make the class methods static and not need an instance of this at all.
Changed it to two args for get_parameter_state and set_parameter_state.
Can you remove self.category (and the initializer) now?
Yes - oversight on my part. Removed both.
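The shape agreed on in this thread, with no instance-level category and the category passed to the methods instead, might look like this sketch (names taken from the discussion; details are assumptions):

```python
from enum import Enum
from typing import Any, Dict


class StatusType(Enum):
    # One of the status kinds mentioned in the thread; others are possible.
    LESSON_NUM = "lesson_num"


class GlobalTrainingStatus:
    # All state is class-level, so no instance (and no self.category) is
    # needed; callers pass the category to each method.
    saved_state: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        GlobalTrainingStatus.saved_state.setdefault(category, {})[key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        # Unknown category or key returns None rather than raising.
        return GlobalTrainingStatus.saved_state.get(category, {}).get(key.value)
```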
with open(path, "w") as f:
    json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

def store_parameter_state(self, key: StatusType, value: Any) -> None:
nit: set_parameter_state() and get_parameter_state()? I think restore_parameter_state is a bad name since "restore" makes it sound like it's doing loading.
Changed.
def restore_parameter_state(self, key: StatusType) -> Any:
    """
    Stores an arbitrary-named parameter in training_status.json.
copy-pasted docstring.
👍 Fixed
statsreporter_new = GlobalTrainingStatus("Category1")
GlobalTrainingStatus.load_state(path_dir)
restored_val = statsreporter_new.restore_parameter_state(StatusType.LESSON_NUM)
Also test that restore_parameter_state() on an unknown category or StatusType returns None instead of raising an exception.
Added test for these.
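A self-contained pytest-style sketch of the suggested test; the tiny stub class below stands in for the real GlobalTrainingStatus so the example runs on its own (all names are assumptions):

```python
import json
import os
import tempfile


class _Status:
    """Minimal stand-in for GlobalTrainingStatus, for this test sketch only."""

    saved_state = {}

    @classmethod
    def load_state(cls, path):
        with open(path) as f:
            cls.saved_state.update(json.load(f))

    @classmethod
    def get_parameter_state(cls, category, key):
        return cls.saved_state.get(category, {}).get(key)


def test_restore_roundtrip_and_unknown():
    path = os.path.join(tempfile.mkdtemp(), "training_status.json")
    with open(path, "w") as f:
        json.dump({"Category1": {"lesson_num": 3}}, f)
    _Status.load_state(path)
    assert _Status.get_parameter_state("Category1", "lesson_num") == 3
    # Unknown category or key should return None, not raise.
    assert _Status.get_parameter_state("NoSuchCategory", "lesson_num") is None
    assert _Status.get_parameter_state("Category1", "no_such_key") is None
```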
One question, looks good though
Proposed change(s)
This PR adds a mechanism to store data that is then written out as a JSON file (training_status.json) that can be loaded on resume. This is done through a new class (GlobalTrainingStatus) that keeps both metadata about the JSON file and the key/values that need to be written, organized by behavior name. Currently, it stores just the lesson number during curriculum learning. This value is loaded when --resume is specified.

training_status.json is versioned in a very similar way to the timers.json file, and versions are checked on resume. Warnings are thrown if the versions don't match.

We also no longer need the --lesson CLI option, as that was only used to reset a lesson on resume.

Types of change(s)
Checklist
Other comments
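Since the description says training_status.json is versioned much like timers.json and keyed by behavior name, a purely illustrative shape might be (key names and the version string here are assumptions, not the actual format):

```json
{
    "metadata": {
        "version": "0.1.0"
    },
    "SomeBehaviorName": {
        "lesson_num": 2
    }
}
```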