Make RL training compatible with PyTorch #1520


Merged: 4 commits, merged on Oct 9, 2020
2 changes: 2 additions & 0 deletions reinforcement_learning/README.md
@@ -6,6 +6,8 @@ These examples demonstrate how to train reinforcement learning models on SageMaker

**IMPORTANT for rllib users:** Some examples may break with latest [rllib](https://docs.ray.io/en/latest/rllib.html) due to breaking API changes. Please refer to [Amazon SageMaker RL Container](https://github.com/aws/sagemaker-rl-container) for the latest public images and modify the configs in entrypoint scripts according to [rllib algorithm config](https://docs.ray.io/en/latest/rllib-algorithms.html).

If you are using PyTorch rather than TensorFlow, please set `debugger_hook_config=False` when calling `RLEstimator()` to avoid TensorBoard conflicts.
Contributor:

What are the conflicts here? And isn't `debugger_hook_config` False by default?

Contributor Author:

If not specified, a default one is created using the estimator's output_path, unless the region does not support SageMaker Debugger (link). The tf.summary data generated by Ray cannot be picked up properly by SageMaker Debugger.
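For illustration, a minimal sketch of an estimator call with the Debugger hook disabled, as the README note above recommends. The entry point comes from this PR's notebook; the image URI, role, instance type, and S3 path below are placeholders, not values from the PR:

```python
from sagemaker.rl import RLEstimator

# Minimal sketch, assuming a PyTorch-capable Ray image.
# Image URI, role, instance type, and output path are placeholders.
estimator = RLEstimator(entry_point="train-rl-cartpole-ray.py",
                        source_dir="src",
                        image_name="<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-torch-cpu-py36",
                        role="<your-sagemaker-execution-role>",
                        debugger_hook_config=False,   # skip SageMaker Debugger so Ray's tf.summary output is not misread
                        train_instance_type="ml.m5.xlarge",
                        train_instance_count=1,
                        output_path="s3://<bucket>/rl-output")
estimator.fit()
```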


- [Contextual Bandit with Live Environment](bandits_statlog_vw_customEnv) illustrates how you can manage your own contextual multi-armed bandit workflow on SageMaker using the built-in [Vowpal Wabbit](https://github.com/VowpalWabbit/vowpal_wabbit) (VW) container to train and deploy contextual bandit models.
- [Cartpole](rl_cartpole_coach) uses SageMaker RL base [docker image](https://github.com/aws/sagemaker-rl-container) to balance a broom upright.
- [Cartpole Batch](rl_cartpole_batch_coach) uses batch RL techniques to train Cartpole with offline data.
11 changes: 8 additions & 3 deletions reinforcement_learning/common/sagemaker_rl/ray_launcher.py
@@ -243,10 +243,13 @@ def create_tf_serving_model(self, algorithm=None, env_string=None):
        agent.restore(checkpoint)
        export_tf_serving(agent, MODEL_OUTPUT_DIR)

-    def save_checkpoint_and_serving_model(self, algorithm=None, env_string=None):
+    def save_checkpoint_and_serving_model(self, algorithm=None, env_string=None, use_pytorch=False):
        self.save_experiment_config()
        self.copy_checkpoints_to_model_output()
-        self.create_tf_serving_model(algorithm, env_string)
+        if use_pytorch:
+            print("Skipped PyTorch serving.")
+        else:
+            self.create_tf_serving_model(algorithm, env_string)

        # To ensure SageMaker local mode works fine
        change_permissions_recursive(INTERMEDIATE_DIR, 0o777)
@@ -335,8 +338,10 @@ def launch(self):

        algo = experiment_config["training"]["run"]
        env_string = experiment_config["training"]["config"]["env"]
+        use_pytorch = experiment_config["training"]["config"].get("use_pytorch", False)
        self.save_checkpoint_and_serving_model(algorithm=algo,
-                                               env_string=env_string)
+                                               env_string=env_string,
+                                               use_pytorch=use_pytorch)

    @classmethod
    def train_main(cls):
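For context, `use_pytorch` is RLlib's own framework switch in the Ray 0.8.x line (later Ray releases replaced it with `"framework": "torch"`). Roughly speaking, the experiment config the launcher parses above is handed to Ray Tune along these lines; this is a hedged sketch with illustrative algorithm and environment names, not the launcher's actual code:

```python
import ray
from ray import tune

# Hedged sketch of what the experiment config drives downstream of the launcher.
# With Ray 0.8.x, "use_pytorch": True makes RLlib build the PyTorch version of the trainer.
ray.init()
tune.run(
    "PPO",                                # illustrative; the launcher takes this from experiment_config["training"]["run"]
    stop={"training_iteration": 40},
    config={
        "env": "CartPole-v0",             # illustrative; from experiment_config["training"]["config"]["env"]
        "use_pytorch": True,              # the flag this PR threads through to the launcher
    },
)
```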
@@ -17,8 +17,8 @@ def change_permissions_recursive(path, mode):
    for root, dirs, files in os.walk(path, topdown=False):
        for dir in [os.path.join(root, d) for d in dirs]:
            os.chmod(dir, mode)
-            for file in [os.path.join(root, f) for f in files]:
-                os.chmod(file, mode)
+        for file in [os.path.join(root, f) for f in files]:
Contributor:

What is this change? It seems to change the logic.

Contributor Author:

This is an existing bug dating back to the launch of SageMaker RL. There was a plan to fix it, but it got delayed.

+            os.chmod(file, mode)


def export_tf_serving(agent, output_dir):
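For clarity, this is what the corrected helper looks like after the fix above (a sketch restricted to the lines shown in this hunk): the file loop becomes a sibling of the directory loop, so files get their permissions updated even when the current directory has no subdirectories.

```python
import os

def change_permissions_recursive(path, mode):
    for root, dirs, files in os.walk(path, topdown=False):
        for dir in [os.path.join(root, d) for d in dirs]:
            os.chmod(dir, mode)
        # Previously this loop was nested inside the directory loop above, so files
        # were only touched (and touched repeatedly) when the current root had subdirectories.
        for file in [os.path.join(root, f) for f in files]:
            os.chmod(file, mode)
```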
Empty file.
Empty file.
31 changes: 31 additions & 0 deletions reinforcement_learning/common/tests/unit/test_ray_launcher.py
@@ -0,0 +1,31 @@
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import pytest
from mock import Mock, MagicMock, patch

from sagemaker_rl.ray_launcher import SageMakerRayLauncher

@patch("sagemaker_rl.ray_launcher.SageMakerRayLauncher.__init__", return_value=None)
@patch("sagemaker_rl.ray_launcher.change_permissions_recursive")
def test_pytorch_save_checkpoint_and_serving_model(change_permission, launcher_init):
    launcher = SageMakerRayLauncher()
    launcher.copy_checkpoints_to_model_output = Mock()
    launcher.create_tf_serving_model = Mock()
    launcher.save_experiment_config = Mock()

    # With PyTorch, the TF serving export must be skipped.
    launcher.save_checkpoint_and_serving_model(use_pytorch=True)
    launcher.create_tf_serving_model.assert_not_called()
    # With TensorFlow (the default), the serving model is exported exactly once.
    launcher.save_checkpoint_and_serving_model(use_pytorch=False)
    launcher.create_tf_serving_model.assert_called_once()
    # Permissions were fixed up twice per save call, i.e. four times across the two calls above.
    assert 4 == change_permission.call_count
33 changes: 16 additions & 17 deletions reinforcement_learning/rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb
@@ -9,7 +9,7 @@
"---\n",
Contributor:

Any reason to expose this parameter as compared to the previous version?

Contributor Author:

SageMaker Debugger is a new feature released at the last re:Invent. Prior to that, the summary data collected by tf.summary didn't interact with SageMaker.

"## Introduction\n",
"\n",
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib \n",
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib. You can choose either TensorFlow or PyTorch as your underlying DL framework.\n",
"\n",
"(For a similar example using Coach library, see this [link](../rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb). Another Cart-pole example using Coach library and offline data can be found [here](../rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb).)\n",
"\n",
@@ -196,7 +196,8 @@
"\n",
"cpu_or_gpu = 'gpu' if instance_type.startswith('ml.p') else 'cpu'\n",
"aws_region = boto3.Session().region_name\n",
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-tf-%s-py36\" % (aws_region, cpu_or_gpu)\n",
"framework = 'tf' # change to 'torch' for PyTorch training\n",
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-%s-%s-py36\" % (aws_region, framework, cpu_or_gpu)\n",
"custom_image_name"
]
},
@@ -206,8 +207,10 @@
"source": [
"## Write the Training Code\n",
"\n",
"The training code is written in the file “train-coach.py” which is uploaded in the /src directory. \n",
"First import the environment files and the preset files, and then define the main() function. "
"The training code is written in the file “train-rl-cartpole-ray.py” which is uploaded in the /src directory. \n",
"First import the environment files and the preset files, and then define the main() function. \n",
"\n",
"**Note**: If PyTorch is used, plese update the above training code and set `use_pytorch` to `True` in the config."
]
},
{
@@ -218,7 +221,7 @@
},
"outputs": [],
"source": [
"!pygmentize src/train-{job_name_prefix}.py"
"!pygmentize src/train-rl-cartpole-ray.py"
]
},
{
@@ -249,11 +252,12 @@
"\n",
"metric_definitions = RLEstimator.default_metric_definitions(RLToolkit.RAY)\n",
" \n",
"estimator = RLEstimator(entry_point=\"train-%s.py\" % job_name_prefix,\n",
"estimator = RLEstimator(entry_point=\"train-rl-cartpole-ray.py\",\n",
" source_dir='src',\n",
" dependencies=[\"common/sagemaker_rl\"],\n",
" image_name=custom_image_name,\n",
" role=role,\n",
" debugger_hook_config=False,\n",
" train_instance_type=instance_type,\n",
" train_instance_count=1,\n",
" output_path=s3_output_path,\n",
@@ -456,22 +460,17 @@
"print(\"Evaluation job: %s\" % job_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize the output \n",
"\n",
"Optionally, you can run the steps defined earlier to visualize the output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model deployment\n",
"\n",
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation."
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation.\n",
"\n",
"**Note**: Model deployment is supported for TensorFLow only at current stage. \n",
"\n",
"STOP HERE IF PYTORCH IS USED."
]
},
{
@@ -563,4 +562,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
@@ -26,6 +26,7 @@ def get_experiment_config(self):
"training_iteration": 40
},
"config": {
"use_pytorch": False,
Contributor:

If `use_pytorch` is False by default, the default should live in the launcher code, not in the config.

Contributor Author:

This is part of the RLlib experiment config. The launcher reads it via `use_pytorch = experiment_config["training"]["config"].get("use_pytorch", False)`.

Contributor:

If we use `get("use_pytorch", False)`, we don't need to explicitly set it to False in the config.

Contributor Author:

Correct. It's more of an illustration of how a customer can switch between frameworks easily with RLlib on SageMaker.

"gamma": 0.99,
"kl_coeff": 1.0,
"num_sgd_iter": 20,
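To make the switch concrete, a hedged sketch of the same config block with PyTorch enabled. Only the keys visible in this hunk are taken from the PR; everything else in a real experiment config stays unchanged:

```python
# Since the launcher reads the flag with .get("use_pytorch", False), the key is
# optional; omitting it (or leaving it False) keeps the TensorFlow path.
pytorch_config = {
    "use_pytorch": True,   # opt in to RLlib's PyTorch backend (Ray 0.8.x)
    "gamma": 0.99,
    "kl_coeff": 1.0,
    "num_sgd_iter": 20,
    # ... remaining RLlib settings identical to the TensorFlow run
}
```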