Skip to content

Commit 4eec9a7

Browse files
authored
Merge branch 'zwei' into remove-serde-parameters
2 parents 1f2de10 + 4632611 commit 4eec9a7

File tree

13 files changed

+171
-128
lines changed

13 files changed

+171
-128
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"scope": ["monitoring"],
3+
"versions": {
4+
"": {
5+
"registries": {
6+
"eu-north-1": "895015795356",
7+
"me-south-1": "607024016150",
8+
"ap-south-1": "126357580389",
9+
"us-east-2": "777275614652",
10+
"eu-west-1": "468650794304",
11+
"eu-central-1": "048819808253",
12+
"sa-east-1": "539772159869",
13+
"ap-east-1": "001633400207",
14+
"us-east-1": "156813124566",
15+
"ap-northeast-2": "709848358524",
16+
"eu-west-2": "749857270468",
17+
"eu-west-3": "680080141114",
18+
"ap-northeast-1": "574779866223",
19+
"us-west-2": "159807026194",
20+
"us-west-1": "890145073186",
21+
"ap-southeast-1": "245545462676",
22+
"ap-southeast-2": "563025443158",
23+
"ca-central-1": "536280801234",
24+
"cn-north-1": "453000072557",
25+
"cn-northwest-1": "453252182341"
26+
},
27+
"repository": "sagemaker-model-monitor-analyzer"
28+
}
29+
}
30+
}

src/sagemaker/image_uris.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
logger = logging.getLogger(__name__)
2424

25-
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
25+
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
2626

2727

2828
def retrieve(
@@ -69,14 +69,17 @@ def retrieve(
6969
registry = _registry_from_region(region, version_config["registries"])
7070
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
7171

72+
repo = version_config["repository"]
73+
7274
processor = _processor(
7375
instance_type, config.get("processors") or version_config.get("processors")
7476
)
7577
tag = _format_tag(version_config.get("tag_prefix", version), processor, py_version)
7678

77-
repo = version_config["repository"]
79+
if tag:
80+
repo += ":{}".format(tag)
7881

79-
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)
82+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
8083

8184

8285
def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,16 @@
2626
from six.moves.urllib.parse import urlparse
2727
from botocore.exceptions import ClientError
2828

29+
from sagemaker import image_uris
2930
from sagemaker.exceptions import UnexpectedStatusException
3031
from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics
3132
from sagemaker.network import NetworkConfig
3233
from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput
3334
from sagemaker.s3 import S3Uploader
3435
from sagemaker.session import Session
35-
from sagemaker.utils import name_from_base, retries, get_ecr_image_uri_prefix
36-
37-
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS = "{}/sagemaker-model-monitor-analyzer"
38-
39-
_DEFAULT_MONITOR_IMAGE_REGION_ACCOUNT_MAPPING = {
40-
"eu-north-1": "895015795356",
41-
"me-south-1": "607024016150",
42-
"ap-south-1": "126357580389",
43-
"us-east-2": "777275614652",
44-
"eu-west-1": "468650794304",
45-
"eu-central-1": "048819808253",
46-
"sa-east-1": "539772159869",
47-
"ap-east-1": "001633400207",
48-
"us-east-1": "156813124566",
49-
"ap-northeast-2": "709848358524",
50-
"eu-west-2": "749857270468",
51-
"eu-west-3": "680080141114",
52-
"ap-northeast-1": "574779866223",
53-
"us-west-2": "159807026194",
54-
"us-west-1": "890145073186",
55-
"ap-southeast-1": "245545462676",
56-
"ap-southeast-2": "563025443158",
57-
"ca-central-1": "536280801234",
58-
"cn-north-1": "453000072557",
59-
"cn-northwest-1": "453252182341",
60-
}
36+
from sagemaker.utils import name_from_base, retries
37+
38+
DEFAULT_REPOSITORY_NAME = "sagemaker-model-monitor-analyzer"
6139

6240
STATISTICS_JSON_DEFAULT_FILE_NAME = "statistics.json"
6341
CONSTRAINTS_JSON_DEFAULT_FILE_NAME = "constraints.json"
@@ -89,6 +67,8 @@
8967

9068
_LOGGER = logging.getLogger(__name__)
9169

70+
framework_name = "model-monitor"
71+
9272

9373
class ModelMonitor(object):
9474
"""Sets up Amazon SageMaker Monitoring Schedules and baseline suggestions. Use this class when
@@ -1787,9 +1767,7 @@ def _get_default_image_uri(region):
17871767
Returns:
17881768
str: The Default Model Monitoring image uri based on the region.
17891769
"""
1790-
return _DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS.format(
1791-
get_ecr_image_uri_prefix(_DEFAULT_MONITOR_IMAGE_REGION_ACCOUNT_MAPPING[region], region)
1792-
)
1770+
return image_uris.retrieve(framework=framework_name, region=region)
17931771

17941772

17951773
class BaseliningJob(ProcessingJob):

src/sagemaker/predictor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.utils import name_from_base
2121

2222
from sagemaker.model_monitor.model_monitoring import (
23-
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS,
23+
DEFAULT_REPOSITORY_NAME,
2424
ModelMonitor,
2525
DefaultModelMonitor,
2626
)
@@ -348,10 +348,7 @@ def list_monitors(self):
348348
image_uri = schedule["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
349349
"MonitoringAppSpecification"
350350
]["ImageUri"]
351-
index_after_placeholders = _DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS.rfind("{}")
352-
if image_uri.endswith(
353-
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS[index_after_placeholders + len("{}") :]
354-
):
351+
if image_uri.endswith(DEFAULT_REPOSITORY_NAME):
355352
monitors.append(
356353
DefaultModelMonitor.attach(
357354
monitor_schedule_name=schedule_name,

src/sagemaker/rl/estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"0.11.0": {"tensorflow": "1.11", "mxnet": "1.3"},
3737
"0.11.1": {"tensorflow": "1.12"},
3838
"0.11": {"tensorflow": "1.12", "mxnet": "1.3"},
39+
"1.0.0": {"tensorflow": "1.12"},
3940
},
4041
"ray": {
4142
"0.5.3": {"tensorflow": "1.11"},
@@ -68,7 +69,7 @@ class RLEstimator(Framework):
6869

6970
COACH_LATEST_VERSION_TF = "0.11.1"
7071
COACH_LATEST_VERSION_MXNET = "0.11.0"
71-
RAY_LATEST_VERSION = "0.6.5"
72+
RAY_LATEST_VERSION = "0.8.5"
7273

7374
def __init__(
7475
self,

tests/conftest.py

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from sagemaker import Session, image_uris, utils
2424
from sagemaker.local import LocalSession
25-
from sagemaker.rl import RLEstimator
2625
import tests.integ
2726

2827
DEFAULT_REGION = "us-west-2"
@@ -41,15 +40,20 @@
4140

4241
FRAMEWORKS_FOR_GENERATED_VERSION_FIXTURES = (
4342
"chainer",
43+
"coach_mxnet",
44+
"coach_tensorflow",
4445
"inferentia_mxnet",
4546
"inferentia_tensorflow",
4647
"mxnet",
4748
"neo_mxnet",
4849
"neo_pytorch",
4950
"neo_tensorflow",
5051
"pytorch",
52+
"ray_pytorch",
53+
"ray_tensorflow",
5154
"sklearn",
5255
"tensorflow",
56+
"vw",
5357
"xgboost",
5458
)
5559

@@ -181,46 +185,6 @@ def _tf_py_version(tf_version, request):
181185
return "py37"
182186

183187

184-
@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
185-
def rl_coach_tf_version(request):
186-
return request.param
187-
188-
189-
@pytest.fixture(scope="module", params=["0.11", "0.11.0"])
190-
def rl_coach_mxnet_version(request):
191-
return request.param
192-
193-
194-
@pytest.fixture(scope="module", params=["0.5", "0.5.3", "0.6", "0.6.5", "0.8.2", "0.8.5"])
195-
def rl_ray_tf_version(request):
196-
return request.param
197-
198-
199-
@pytest.fixture(scope="module", params=["0.8.5"])
200-
def rl_ray_pytorch_version(request):
201-
return request.param
202-
203-
204-
@pytest.fixture(scope="module", params=["8.7.0"])
205-
def rl_vw_version(request):
206-
return request.param
207-
208-
209-
@pytest.fixture(scope="module")
210-
def rl_coach_mxnet_full_version():
211-
return RLEstimator.COACH_LATEST_VERSION_MXNET
212-
213-
214-
@pytest.fixture(scope="module")
215-
def rl_coach_tf_full_version():
216-
return RLEstimator.COACH_LATEST_VERSION_TF
217-
218-
219-
@pytest.fixture(scope="module")
220-
def rl_ray_full_version():
221-
return RLEstimator.RAY_LATEST_VERSION
222-
223-
224188
@pytest.fixture(scope="module")
225189
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
226190
"""Fixture for TF tests that test both training and inference.

tests/data/ray_cartpole/train_ray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from ray.tune.logger import pretty_print
66

77
# Based on https://github.com/ray-project/ray/blob/master/doc/source/rllib-training.rst#python-api
8-
ray.init(log_to_driver=False)
8+
ray.init(log_to_driver=False, webui_host="127.0.0.1")
99
config = ppo.DEFAULT_CONFIG.copy()
1010
config["num_gpus"] = int(os.environ.get("SM_NUM_GPUS", 0))
1111
checkpoint_dir = os.environ.get("SM_MODEL_DIR", "/Users/nadzeya/gym")
1212
config["num_workers"] = 1
13-
agent = ppo.PPOAgent(config=config, env="CartPole-v0")
13+
agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
1414

1515
# Can optionally call agent.restore(path) to load a checkpoint.
1616

tests/integ/test_rl.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525

2626
@pytest.mark.canary_quick
27-
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instance_type):
27+
def test_coach_mxnet(sagemaker_session, coach_mxnet_latest_version, cpu_instance_type):
2828
estimator = _test_coach(
29-
sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version, cpu_instance_type
29+
sagemaker_session, RLFramework.MXNET, coach_mxnet_latest_version, cpu_instance_type
3030
)
3131
job_name = unique_name_from_base("test-coach-mxnet")
3232

@@ -51,9 +51,12 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instanc
5151
assert 0 < action[0][1] < 1
5252

5353

54-
def test_coach_tf(sagemaker_session, rl_coach_tf_full_version, cpu_instance_type):
54+
def test_coach_tf(sagemaker_session, coach_tensorflow_latest_version, cpu_instance_type):
5555
estimator = _test_coach(
56-
sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version, cpu_instance_type
56+
sagemaker_session,
57+
RLFramework.TENSORFLOW,
58+
coach_tensorflow_latest_version,
59+
cpu_instance_type,
5760
)
5861
job_name = unique_name_from_base("test-coach-tf")
5962

@@ -96,7 +99,7 @@ def _test_coach(sagemaker_session, rl_framework, rl_coach_version, cpu_instance_
9699

97100

98101
@pytest.mark.canary_quick
99-
def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
102+
def test_ray_tf(sagemaker_session, ray_tensorflow_latest_version, cpu_instance_type):
100103
source_dir = os.path.join(DATA_DIR, "ray_cartpole")
101104
cartpole = "train_ray.py"
102105

@@ -105,7 +108,7 @@ def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
105108
source_dir=source_dir,
106109
toolkit=RLToolkit.RAY,
107110
framework=RLFramework.TENSORFLOW,
108-
toolkit_version=rl_ray_full_version,
111+
toolkit_version=ray_tensorflow_latest_version,
109112
sagemaker_session=sagemaker_session,
110113
role="SageMakerRole",
111114
instance_type=cpu_instance_type,

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
}
2020
DOMAIN = "amazonaws.com"
2121
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}"
22+
MONITOR_URI_FORMAT = "{}.dkr.ecr.{}.{}/sagemaker-model-monitor-analyzer"
2223
REGION = "us-west-2"
2324

2425

@@ -34,3 +35,8 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
3435
def algo_uri(algo, account, region, version=1):
3536
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
3637
return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
38+
39+
40+
def monitor_uri(account, region=REGION):
41+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
42+
return MONITOR_URI_FORMAT.format(account, region, domain)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris, regions
17+
18+
ACCOUNTS = {
19+
"eu-north-1": "895015795356",
20+
"me-south-1": "607024016150",
21+
"ap-south-1": "126357580389",
22+
"us-east-2": "777275614652",
23+
"eu-west-1": "468650794304",
24+
"eu-central-1": "048819808253",
25+
"sa-east-1": "539772159869",
26+
"ap-east-1": "001633400207",
27+
"us-east-1": "156813124566",
28+
"ap-northeast-2": "709848358524",
29+
"eu-west-2": "749857270468",
30+
"eu-west-3": "680080141114",
31+
"ap-northeast-1": "574779866223",
32+
"us-west-2": "159807026194",
33+
"us-west-1": "890145073186",
34+
"ap-southeast-1": "245545462676",
35+
"ap-southeast-2": "563025443158",
36+
"ca-central-1": "536280801234",
37+
"cn-north-1": "453000072557",
38+
"cn-northwest-1": "453252182341",
39+
}
40+
41+
42+
def test_model_monitor():
43+
for region in regions.regions():
44+
if region in ACCOUNTS.keys():
45+
uri = image_uris.retrieve("model-monitor", region=region)
46+
47+
expected = expected_uris.monitor_uri(ACCOUNTS[region], region)
48+
assert expected == uri

0 commit comments

Comments
 (0)