Skip to content

Commit 0fc7201

Browse files
goelakashgoel-akas
andauthored
Add integration tests using Papermill library for RL notebooks. List of notebooks covered in the tests: (#1580)
1. rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb 2. rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb Co-authored-by: Akash Goel <[email protected]>
1 parent 586fb4b commit 0fc7201

File tree

7 files changed

+125
-5
lines changed

7 files changed

+125
-5
lines changed

reinforcement_learning/rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@
115115
"cell_type": "code",
116116
"execution_count": null,
117117
"metadata": {
118-
"collapsed": true
118+
"collapsed": true,
119+
"tags": [
120+
"parameters"
121+
]
119122
},
120123
"outputs": [],
121124
"source": [
@@ -573,7 +576,7 @@
573576
"name": "python",
574577
"nbconvert_exporter": "python",
575578
"pygments_lexer": "ipython3",
576-
"version": "3.7.7"
579+
"version": "3.7.9"
577580
},
578581
"notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
579582
},

reinforcement_learning/rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@
114114
{
115115
"cell_type": "code",
116116
"execution_count": null,
117-
"metadata": {},
117+
"metadata": {
118+
"tags": [
119+
"parameters"
120+
]
121+
},
118122
"outputs": [],
119123
"source": [
120124
"# run in local_mode on this machine, or as a SageMaker TrainingJob?\n",
@@ -239,6 +243,19 @@
239243
"5. Define the metrics definitions that you are interested in capturing in your logs. These can also be visualized in CloudWatch and SageMaker Notebooks. "
240244
]
241245
},
246+
{
247+
"cell_type": "code",
248+
"execution_count": null,
249+
"metadata": {
250+
"tags": [
251+
"parameters"
252+
]
253+
},
254+
"outputs": [],
255+
"source": [
256+
"train_instance_count = 1"
257+
]
258+
},
242259
{
243260
"cell_type": "code",
244261
"execution_count": null,
@@ -259,7 +276,7 @@
259276
" role=role,\n",
260277
" debugger_hook_config=False,\n",
261278
" train_instance_type=instance_type,\n",
262-
" train_instance_count=1,\n",
279+
" train_instance_count=train_instance_count,\n",
263280
" output_path=s3_output_path,\n",
264281
" base_job_name=job_name_prefix,\n",
265282
" metric_definitions=metric_definitions,\n",
@@ -556,7 +573,7 @@
556573
"name": "python",
557574
"nbconvert_exporter": "python",
558575
"pygments_lexer": "ipython3",
559-
"version": "3.7.7"
576+
"version": "3.7.9"
560577
},
561578
"notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
562579
},
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pandas
2+
papermill
3+
tabulate

reinforcement_learning/testconfig.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
train_instance_count,instance_type,local_mode
2+
1,ml.m4.xlarge,False
3+
2,ml.m4.xlarge,False
4+
1,ml.p3.2xlarge,False
5+
2,ml.p3.2xlarge,False
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb
2+
rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb

reinforcement_learning/testrunner.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
import itertools
5+
import json
6+
import os
7+
import subprocess
8+
import sys
9+
import time
10+
try:
11+
import pandas as pd
12+
import papermill
13+
from tabulate import tabulate
14+
except ImportError:
15+
sys.exit("""Some libraries are missing. Please install them by running `pip install -r test_requirements.txt`.""")
16+
17+
# CONSTANTS
18+
TEST_NOTEBOOKS_FILE = 'testnotebooks.txt'
19+
TEST_CONFIG_FILE = 'testconfig.csv'
20+
SUCCESSES = 0
21+
EXCEPTIONS = 0
22+
SUCCESSFUL_EXECUTIONS = []
23+
FAILED_EXECUTIONS = []
24+
CELL_EXECUTION_TIMEOUT_SECONDS = 1200
25+
26+
27+
# helper functions
28+
def run_notebook(nb_path, test_config):
29+
dir_name = os.path.dirname(nb_path)
30+
nb_name = os.path.basename(nb_path)
31+
output_nb_name = "output_{}.ipynb".format(nb_name)
32+
os.chdir(dir_name)
33+
print("Current directory: {}".format(os.getcwd()))
34+
global SUCCESSES
35+
global EXCEPTIONS
36+
for i in range(len(test_config)):
37+
params = json.loads(test_config.loc[i].to_json())
38+
# Coach notebooks support only single instance training, so skip the tests with multiple EC2 instances
39+
if 'coach' in nb_name.lower() and params['train_instance_count'] > 1:
40+
continue
41+
print("\nTEST: " + nb_name + " with parameters " + str(params))
42+
process = None
43+
try:
44+
papermill.execute_notebook(nb_name, output_nb_name, parameters=params, execution_timeout=CELL_EXECUTION_TIMEOUT_SECONDS, log_output=True)
45+
SUCCESSES += 1
46+
SUCCESSFUL_EXECUTIONS.append(dict({'notebook':nb_name, 'params':params}))
47+
except BaseException as error:
48+
print('An exception occurred: {}'.format(error))
49+
EXCEPTIONS += 1
50+
FAILED_EXECUTIONS.append(dict({'notebook':nb_name, 'params':params}))
51+
52+
def print_notebook_executions(nb_list_with_params):
53+
# This expects a list of dict type items.
54+
# E.g. [{'nb_name':'foo', 'params':'bar'}]
55+
if not nb_list_with_params:
56+
print("None")
57+
return
58+
vals = []
59+
for nb_dict in nb_list_with_params:
60+
val = []
61+
for k,v in nb_dict.items():
62+
val.append(v)
63+
vals.append(val)
64+
keys = [k for k in nb_list_with_params[0].keys()]
65+
print(tabulate(pd.DataFrame([v for v in vals], columns=keys), showindex=False))
66+
67+
68+
69+
notebooks_list = open(TEST_NOTEBOOKS_FILE).readlines()
70+
config = pd.read_csv(TEST_CONFIG_FILE)
71+
ROOT = os.path.abspath('.')
72+
73+
# Run tests on each notebook listed in the config.
74+
print("Test Configuration: ")
75+
print(config)
76+
for nb_path in notebooks_list:
77+
os.chdir(ROOT)
78+
print("Testing: {}".format(nb_path))
79+
run_notebook(nb_path.strip(), config)
80+
81+
# Print summary of tests ran.
82+
print("Summary: {}/{} tests passed.".format(SUCCESSES, SUCCESSES + EXCEPTIONS))
83+
print("Successful executions: ")
84+
print_notebook_executions(SUCCESSFUL_EXECUTIONS)
85+
86+
# Throw exception if any test fails, so that the CodeBuild also fails.
87+
if EXCEPTIONS > 0:
88+
print("Failed executions: ")
89+
print_notebook_executions(FAILED_EXECUTIONS)
90+
raise Exception("Test did not complete successfully")

0 commit comments

Comments
 (0)