[versioning] Save ML-Agents version in checkpoints and check on load #4035

Merged
ervteng merged 12 commits into master from develop-checktfver on May 30, 2020

@ervteng ervteng commented May 28, 2020

Proposed change(s)

This PR saves the semantic version of the trainer package (major, minor, and patch) as 3 variables in the TF graph. It is also exported into the .nn file. This lets us check whether a checkpoint is being loaded from the same version of ML-Agents, and (in the future) check which version of ML-Agents created a particular NN file.

Note that this is different from the existing version number in the NN file, which is checked on the C# side. That number corresponds to the model's input and output tensors. When the Trainer code is upgraded, an NN file can remain compatible with C# yet no longer be loadable into Python for training (e.g. if the network architecture changes).

Currently, we emit a warning (rather than raising an error) if the versions don't match when a user tries to load a checkpoint.
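As a rough illustration of the check described above, the sketch below compares a saved (major, minor, patch) triple against the running one and logs a warning on mismatch. The function name `versions_match`, the `Version` alias, and the message wording are hypothetical and are not taken from the PR's diff. In the PR itself the saved values come from three variables in the TF graph, which this sketch does not model.

```python
import logging
from typing import Tuple

logger = logging.getLogger(__name__)

# Hypothetical alias: (major, minor, patch) of the trainer package.
Version = Tuple[int, int, int]


def versions_match(saved: Version, current: Version) -> bool:
    """Return True if the checkpoint's version equals the running version.

    On a mismatch this only logs a warning, mirroring the behaviour
    described in the PR text (warn, do not fail).
    """
    if saved == current:
        return True
    logger.warning(
        "Checkpoint was saved with version %s, but the current version is %s. "
        "Loading may fail or behave unexpectedly.",
        ".".join(str(part) for part in saved),
        ".".join(str(part) for part in current),
    )
    return False
```

A caller could run `versions_match(saved, current)` just before restoring a checkpoint and continue loading either way, since a mismatch is only a warning at this stage.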

Types of change(s)

  • Bug fix
  • New feature
  • Code refactor
  • Breaking change
  • Documentation update
  • Other (please describe)

Checklist

  • Added tests that prove my fix is effective or that my feature works
  • Updated the changelog (if applicable)
  • Updated the documentation (if applicable)
  • Updated the migration guide (if applicable)

@ervteng ervteng requested review from chriselion and andrewcoh May 28, 2020 18:31


logger = get_logger(__name__)


# This is the version number of the inputs and outputs of the model, and
# determines compatibility with inference in Barracuda.
API_VERSION_NUMBER = 2
Contributor

MODEL_FORMAT_VERSION_NUMBER? "API" doesn't really feel right here.

:param version_string: The semantic-versioned version string (X.Y.Z).
:return: A Tuple containing (major_ver, minor_ver, patch_ver).
"""
split_ver = version_string.split(".")[0:3] # Remove dev tag
Contributor

I think you can use distutils.version.LooseVersion to simplify this:

>>> from distutils.version import LooseVersion
>>> v = LooseVersion("1.2.3.dev4")
>>> v
LooseVersion ('1.2.3.dev4')
>>> v.version
[1, 2, 3, 'dev', 4]
>>> v.version[0:3]
[1, 2, 3]

Contributor Author

Oh, that is very handy! Updated to use the LooseVersion class.
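For readers on newer Python versions: `distutils` was deprecated by PEP 632 and removed from the standard library in Python 3.12, so the `LooseVersion` approach above is not available there without extra packages. A minimal standard-library sketch of the same idea (keep the first three numeric components and drop any dev tag) might look like the following. The function name `parse_version` and the error message are hypothetical and are not the code merged in this PR.

```python
import re
from typing import Tuple


def parse_version(version_string: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from a string such as "1.2.3.dev4".

    Anything after the third numeric component (e.g. a dev tag) is ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"Not an X.Y.Z version string: {version_string!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch
```

Unlike slicing the result of `split(".")`, this raises an explicit error when the string has fewer than three numeric components instead of silently returning a shorter result.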

@ervteng ervteng merged commit 2b7b6e8 into master May 30, 2020
@delete-merged-branch delete-merged-branch bot deleted the develop-checktfver branch May 30, 2020 00:55
@github-actions github-actions bot locked as resolved and limited conversation to collaborators May 30, 2021