-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Make RL training compatible with PyTorch #1520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,8 @@ def change_permissions_recursive(path, mode): | |
for root, dirs, files in os.walk(path, topdown=False): | ||
for dir in [os.path.join(root, d) for d in dirs]: | ||
os.chmod(dir, mode) | ||
for file in [os.path.join(root, f) for f in files]: | ||
os.chmod(file, mode) | ||
for file in [os.path.join(root, f) for f in files]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is this change, this seems changing the logic? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an existing bug since the launch of SageMaker RL. There was plan to fix it but got delayed. |
||
os.chmod(file, mode) | ||
|
||
|
||
def export_tf_serving(agent, output_dir): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
import pytest | ||
from mock import Mock, MagicMock, patch | ||
|
||
from sagemaker_rl.ray_launcher import SageMakerRayLauncher | ||
|
||
@patch("sagemaker_rl.ray_launcher.SageMakerRayLauncher.__init__", return_value=None) | ||
@patch("sagemaker_rl.ray_launcher.change_permissions_recursive") | ||
def test_pytorch_save_checkpoint_and_serving_model(change_permission, launcher_init): | ||
launcher = SageMakerRayLauncher() | ||
launcher.copy_checkpoints_to_model_output = Mock() | ||
launcher.create_tf_serving_model = Mock() | ||
launcher.save_experiment_config = Mock() | ||
|
||
launcher.save_checkpoint_and_serving_model(use_pytorch=True) | ||
launcher.create_tf_serving_model.assert_not_called() | ||
launcher.save_checkpoint_and_serving_model(use_pytorch=False) | ||
launcher.create_tf_serving_model.assert_called_once() | ||
assert 4 == change_permission.call_count |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
"---\n", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SageMaker Debugger is a new feature released in last re:Invent. Prior to that the summary data collected by |
||
"## Introduction\n", | ||
"\n", | ||
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib \n", | ||
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib. You can choose either TensorFlow or PyTorch as your underlying DL framework.\n", | ||
"\n", | ||
"(For a similar example using Coach library, see this [link](../rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb). Another Cart-pole example using Coach library and offline data can be found [here](../rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb).)\n", | ||
"\n", | ||
|
@@ -196,7 +196,8 @@ | |
"\n", | ||
"cpu_or_gpu = 'gpu' if instance_type.startswith('ml.p') else 'cpu'\n", | ||
"aws_region = boto3.Session().region_name\n", | ||
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-tf-%s-py36\" % (aws_region, cpu_or_gpu)\n", | ||
"framework = 'tf' # change to 'torch' for PyTorch training\n", | ||
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-%s-%s-py36\" % (aws_region, framework, cpu_or_gpu)\n", | ||
"custom_image_name" | ||
] | ||
}, | ||
|
@@ -206,8 +207,10 @@ | |
"source": [ | ||
"## Write the Training Code\n", | ||
"\n", | ||
"The training code is written in the file “train-coach.py” which is uploaded in the /src directory. \n", | ||
"First import the environment files and the preset files, and then define the main() function. " | ||
"The training code is written in the file “train-rl-cartpole-ray.py” which is uploaded in the /src directory. \n", | ||
"First import the environment files and the preset files, and then define the main() function. \n", | ||
"\n", | ||
"**Note**: If PyTorch is used, plese update the above training code and set `use_pytorch` to `True` in the config." | ||
] | ||
}, | ||
{ | ||
|
@@ -218,7 +221,7 @@ | |
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pygmentize src/train-{job_name_prefix}.py" | ||
"!pygmentize src/train-rl-cartpole-ray.py" | ||
] | ||
}, | ||
{ | ||
|
@@ -249,11 +252,12 @@ | |
"\n", | ||
"metric_definitions = RLEstimator.default_metric_definitions(RLToolkit.RAY)\n", | ||
" \n", | ||
"estimator = RLEstimator(entry_point=\"train-%s.py\" % job_name_prefix,\n", | ||
"estimator = RLEstimator(entry_point=\"train-rl-cartpole-ray.py\",\n", | ||
" source_dir='src',\n", | ||
" dependencies=[\"common/sagemaker_rl\"],\n", | ||
" image_name=custom_image_name,\n", | ||
" role=role,\n", | ||
" debugger_hook_config=False,\n", | ||
" train_instance_type=instance_type,\n", | ||
" train_instance_count=1,\n", | ||
" output_path=s3_output_path,\n", | ||
|
@@ -456,22 +460,17 @@ | |
"print(\"Evaluation job: %s\" % job_name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Visualize the output \n", | ||
"\n", | ||
"Optionally, you can run the steps defined earlier to visualize the output." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Model deployment\n", | ||
"\n", | ||
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation." | ||
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation.\n", | ||
"\n", | ||
"**Note**: Model deployment is supported for TensorFLow only at current stage. \n", | ||
"\n", | ||
"STOP HERE IF PYTORCH IS USED." | ||
] | ||
}, | ||
{ | ||
|
@@ -563,4 +562,4 @@ | |
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ def get_experiment_config(self): | |
"training_iteration": 40 | ||
}, | ||
"config": { | ||
"use_pytorch": False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If "use_pytorch" by default is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is part of the rllib experiment config. Launcher reads it via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. It's more like an illustration of how a customer can switch between frameworks easily with RLlib on SageMaker. |
||
"gamma": 0.99, | ||
"kl_coeff": 1.0, | ||
"num_sgd_iter": 20, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the conflicts here? And isn't
debugger_hook_config
by default isFalse
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If not specified, a default one is created using the estimator's
output_path
, unless the region does not support SageMaker Debugger (link).tf.summary
generated by ray cannot be picked up appropriately by SM Debgger.