[refactor] Store and restore state along with checkpoints #4025


Merged: 72 commits from develop-lessonresume into master on Jun 3, 2020

Conversation

@ervteng (Contributor) commented May 26, 2020

Proposed change(s)

This PR adds a mechanism to store data that is written out as a JSON file (training_status.json) and can be loaded on resume. This is done through a new class (GlobalTrainingStatus) that keeps both metadata about the JSON file and the key/value pairs that need to be written, organized by behavior name. Currently, it stores just the lesson number during curriculum learning. This value is loaded when --resume is specified.

training_status.json is versioned in much the same way as the timers.json file, and the version is checked on resume. A warning is logged if the version doesn't match.

We also no longer need the --lesson CLI option, as it was only used to reset the lesson on resume.
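The mechanism described above can be sketched roughly as follows. The class name, file name, and LESSON_NUM key come from the PR; the exact method signatures, the StatusType enum values, and the omission of version-metadata handling are assumptions for illustration:

```python
import json
from enum import Enum
from typing import Any, Dict

STATUS_FILE_NAME = "training_status.json"  # file name from the PR description

class StatusType(Enum):
    LESSON_NUM = "lesson_num"  # currently the only stored value

class GlobalTrainingStatus:
    # Nested dict: {category (behavior name): {status key: value}}
    saved_state: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        GlobalTrainingStatus.saved_state.setdefault(category, {})[key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        # Unknown category or key returns None rather than raising
        return GlobalTrainingStatus.saved_state.get(category, {}).get(key.value)

    @staticmethod
    def save_state(path: str) -> None:
        with open(path, "w") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def load_state(path: str) -> None:
        try:
            with open(path) as f:
                GlobalTrainingStatus.saved_state.update(json.load(f))
        except FileNotFoundError:
            pass  # the real code logs a warning here (see review below)
```

The real implementation also stores format-version metadata alongside the per-behavior values and checks it on load, which is omitted here.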

Types of change(s)

  • Bug fix
  • New feature
  • Code refactor
  • Breaking change
  • Documentation update
  • Other (please describe)

Checklist

  • Added tests that prove my fix is effective or that my feature works
  • Updated the changelog (if applicable)
  • Updated the documentation (if applicable)
  • Updated the migration guide (if applicable)

Other comments

@ervteng ervteng marked this pull request as ready for review May 26, 2020 23:18
@@ -304,6 +343,39 @@ def __init__(self, category: str):
def add_writer(writer: StatsWriter) -> None:
StatsReporter.writers.append(writer)

@staticmethod
Contributor:

StatsReporter doesn't seem like it's the right place to be saving and loading global state. Can you make a new class for this? TrainingGlobalState?

Contributor Author:

The main reason I put it in StatsReporter was that the trainers already have one initialized with the right key (the brain_name). No problem, it can be moved.

@@ -82,6 +82,9 @@ def run_training(run_seed: int, options: RunOptions) -> None:
)
# Make run logs directory
os.makedirs(run_logs_dir, exist_ok=True)
# Load any needed states
if checkpoint_settings.resume:
StatsReporter.load_state(os.path.join(run_logs_dir, "training_status.json"))
Contributor:

Move "training_status.json" to a constant?

Contributor Author:

I've moved it to a constant and the method to a helper method similar to the timing tree and configuration file writes.
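That change might look roughly like this. The STATUS_FILE_NAME constant and the pattern of mirroring the run-options and timing-tree writers come from the thread; the helper name write_training_status is an assumption:

```python
import json
import os
import tempfile

STATUS_FILE_NAME = "training_status.json"  # single constant instead of a repeated string literal

def write_training_status(output_dir: str, state: dict) -> None:
    # Hypothetical helper mirroring write_run_options / write_timing_tree
    with open(os.path.join(output_dir, STATUS_FILE_NAME), "w") as f:
        json.dump(state, f, indent=4)

# Demo with a throwaway directory standing in for run_logs_dir
run_logs_dir = tempfile.mkdtemp()
write_training_status(run_logs_dir, {"BrainA": {"lesson_num": 2}})
```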


for brain_name, curriculum in self.brains_to_curricula.items():
# Create a temporary StatsReporter with the right brain name
_statsreporter = StatsReporter(brain_name)
Contributor:

It's unclear what's going on here; it feels really hacky (and probably brittle).

Contributor Author:

It's less of a concern now since we're no longer using StatsReporters, but we're still using the brain_name to refer to meta curriculums.

# Update saved state.
StatsReporter.saved_state.update(loaded_dict)
except FileNotFoundError:
pass
Contributor:

Should there be a warning here, or is it expected that this won't be there most of the time?

Contributor Author:

Added a warning. This should only happen if the user is loading from an older version of ML-Agents that did not save out the training_status.json file, or (as Vince mentioned below) the file was not saved out due to a crash.
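With the warning added, the load path might look roughly like this. The logger name and the exact warning wording are assumptions:

```python
import json
import logging

logger = logging.getLogger(__name__)

def load_state(path: str, saved_state: dict) -> None:
    # A missing file is expected when resuming a run started with an older
    # ML-Agents version (or one that crashed before the file was written),
    # so warn instead of raising.
    try:
        with open(path) as f:
            saved_state.update(json.load(f))
    except FileNotFoundError:
        logger.warning(
            "Training status file not found at %s; not all training state will be restored.",
            path,
        )

state: dict = {}
load_state("no_such_training_status.json", state)  # logs a warning, leaves state empty
```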

Contributor @chriselion left a comment:

Please don't overload StatsReporter for this; I don't think it's the right place to stash the data.

@@ -159,6 +162,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
env_manager.close()
write_run_options(write_path, options)
write_timing_tree(run_logs_dir)
StatsReporter.save_state(os.path.join(run_logs_dir, "training_status.json"))
Contributor:

I don't think this is the best place to store this global state. I would put it within the save model calls.

Contributor Author:

Hmm, that makes sense since we'd want to resume on the event of a crash as well.

However the save model calls will be moved to within the trainers, since it seems logical that each trainer drives its own checkpointing (at possibly different frequencies). Wouldn't want the global save to be there as well. We could have each trainer manage its own save state, or have every trainer just trigger the save_state function.

An alternative would be to just write to the JSON every time a new state is written. For stuff like the lesson number (which is very infrequently written) this is OK.
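The write-through alternative mentioned here could be sketched like this (illustrative only; the class and method names are not from the PR):

```python
import json
import os
import tempfile

class WriteThroughStatus:
    """Illustrative write-through store: every update is flushed to the JSON
    file immediately. Acceptable for rarely-written values like the lesson
    number, where the extra file I/O is negligible."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.state: dict = {}

    def set(self, category: str, key: str, value) -> None:
        self.state.setdefault(category, {})[key] = value
        with open(self.path, "w") as f:  # flush on every write
            json.dump(self.state, f, indent=4)

status_path = os.path.join(tempfile.mkdtemp(), "training_status.json")
status = WriteThroughStatus(status_path)
status.set("BrainA", "lesson_num", 1)
```

The trade-off versus saving at checkpoint time is durability (the value survives a crash immediately) against coupling every state update to disk I/O.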

@ervteng ervteng requested a review from chriselion May 28, 2020 18:26
type of status needed to be saved (e.g. Lesson Number). Finally the Value is the float value
attached to this stat.
"""
self.category: str = category
Contributor:

This still feels awkward having a mix of instance and static data. Why not drop self.category and add an extra argument to restore_parameter_state and store_parameter_state?

Contributor Author @ervteng commented May 28, 2020:

I went this way to keep it in line with StatsReporter, and so that if a trainer uses it multiple times it doesn't need to keep passing in the category.

Don't have a strong preference since it's used much less frequently than StatsReporter; I'm OK with changing it to a parameter to the class methods. We could also then make the class methods static and not need an instance of this at all.

Contributor Author:

Changed it to two args for get_parameter_state and set_parameter_state.

Contributor:

Can you remove self.category (and the initializer) now?

Contributor Author:

Yes - oversight on my part. Removed both.

with open(path, "w") as f:
json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

def store_parameter_state(self, key: StatusType, value: Any) -> None:
Contributor:

nit: set_parameter_state() and get_parameter_state()? I think restore_parameter_state is a bad name, since "restore" makes it sound like it's doing the loading.

Contributor Author:

Changed.


def restore_parameter_state(self, key: StatusType) -> Any:
"""
Stores an arbitrary-named parameter in training_status.json.
Contributor:

copy-pasted docstring.

Contributor Author:

👍 Fixed


statsreporter_new = GlobalTrainingStatus("Category1")
GlobalTrainingStatus.load_state(path_dir)
restored_val = statsreporter_new.restore_parameter_state(StatusType.LESSON_NUM)
Contributor:

Also test that restore_parameter_state() on an unknown category or StatusType returns None instead of raising an exception.

Contributor Author:

Added test for these.
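A sketch of such a test, against a minimal stand-in for the final two-argument API (the real test lives in the ML-Agents test suite; the stand-in below is not the PR's code):

```python
from enum import Enum
from typing import Any, Dict

class StatusType(Enum):
    LESSON_NUM = "lesson_num"

class GlobalTrainingStatus:
    # Minimal stand-in for the class under test, pre-seeded with one value
    saved_state: Dict[str, Dict[str, Any]] = {"Category1": {"lesson_num": 3}}

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        # An unknown category or key yields None instead of a KeyError
        return GlobalTrainingStatus.saved_state.get(category, {}).get(key.value)

def test_get_parameter_state_unknown_returns_none() -> None:
    assert GlobalTrainingStatus.get_parameter_state("Category1", StatusType.LESSON_NUM) == 3
    assert GlobalTrainingStatus.get_parameter_state("NoSuchCategory", StatusType.LESSON_NUM) is None

test_get_parameter_state_unknown_returns_none()
```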

Contributor @chriselion left a comment:

One question, looks good though

@ervteng ervteng merged commit 5d02292 into master Jun 3, 2020
@delete-merged-branch delete-merged-branch bot deleted the develop-lessonresume branch June 3, 2020 01:11