Skip to content

Commit 348646d

Browse files
authored
Merge branch 'zwei' into json-update-script
2 parents 1e7799f + 284eddc commit 348646d

38 files changed

+222
-1520
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class Chainer(Framework):
3333
"""Handle end-to-end training and deployment of custom Chainer code."""
3434

35-
__framework_name__ = "chainer"
35+
_framework_name = "chainer"
3636

3737
# Hyperparameters
3838
_use_mpi = "sagemaker_use_mpi"
@@ -131,7 +131,7 @@ def __init__(
131131
validate_version_or_image_args(framework_version, py_version, image_uri)
132132
if py_version == "py2":
133133
logger.warning(
134-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
134+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
135135
)
136136
self.framework_version = framework_version
137137
self.py_version = py_version
@@ -272,7 +272,7 @@ class constructor
272272
init_params["image_uri"] = image_uri
273273
return init_params
274274

275-
if framework != cls.__framework_name__:
275+
if framework != cls._framework_name:
276276
raise ValueError(
277277
"Training job: {} didn't use image for requested framework".format(
278278
job_details["TrainingJobName"]

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ChainerModel(FrameworkModel):
5959
``Endpoint``.
6060
"""
6161

62-
__framework_name__ = "chainer"
62+
_framework_name = "chainer"
6363

6464
def __init__(
6565
self,
@@ -116,7 +116,7 @@ def __init__(
116116
validate_version_or_image_args(framework_version, py_version, image_uri)
117117
if py_version == "py2":
118118
logger.warning(
119-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119+
python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
120120
)
121121
self.framework_version = framework_version
122122
self.py_version = py_version
@@ -176,7 +176,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
176176
177177
"""
178178
return image_uris.retrieve(
179-
self.__framework_name__,
179+
self._framework_name,
180180
region_name,
181181
version=self.framework_version,
182182
py_version=self.py_version,

src/sagemaker/debugger.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,9 @@
2323

2424
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
2525

26-
from sagemaker.utils import get_ecr_image_uri_prefix
27-
28-
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"
29-
30-
SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP = {
31-
"eu-north-1": {RULES_ECR_REPO_NAME: "314864569078"},
32-
"me-south-1": {RULES_ECR_REPO_NAME: "986000313247"},
33-
"ap-south-1": {RULES_ECR_REPO_NAME: "904829902805"},
34-
"eu-west-3": {RULES_ECR_REPO_NAME: "447278800020"},
35-
"us-east-2": {RULES_ECR_REPO_NAME: "915447279597"},
36-
"eu-west-1": {RULES_ECR_REPO_NAME: "929884845733"},
37-
"eu-central-1": {RULES_ECR_REPO_NAME: "482524230118"},
38-
"sa-east-1": {RULES_ECR_REPO_NAME: "818342061345"},
39-
"ap-east-1": {RULES_ECR_REPO_NAME: "199566480951"},
40-
"us-east-1": {RULES_ECR_REPO_NAME: "503895931360"},
41-
"ap-northeast-2": {RULES_ECR_REPO_NAME: "578805364391"},
42-
"eu-west-2": {RULES_ECR_REPO_NAME: "250201462417"},
43-
"ap-northeast-1": {RULES_ECR_REPO_NAME: "430734990657"},
44-
"us-west-2": {RULES_ECR_REPO_NAME: "895741380848"},
45-
"us-west-1": {RULES_ECR_REPO_NAME: "685455198987"},
46-
"ap-southeast-1": {RULES_ECR_REPO_NAME: "972752614525"},
47-
"ap-southeast-2": {RULES_ECR_REPO_NAME: "184798709955"},
48-
"ca-central-1": {RULES_ECR_REPO_NAME: "519511493484"},
49-
"cn-north-1": {RULES_ECR_REPO_NAME: "618459771430"},
50-
"cn-northwest-1": {RULES_ECR_REPO_NAME: "658757709296"},
51-
}
26+
from sagemaker import image_uris
27+
28+
framework_name = "debugger"
5229

5330

5431
def get_rule_container_image_uri(region):
@@ -61,9 +38,7 @@ def get_rule_container_image_uri(region):
6138
Returns:
6239
str: Formatted image uri for the given region and the rule container type
6340
"""
64-
registry_id = SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP.get(region).get(RULES_ECR_REPO_NAME)
65-
image_uri_prefix = get_ecr_image_uri_prefix(registry_id, region)
66-
return "{}/{}:latest".format(image_uri_prefix, RULES_ECR_REPO_NAME)
41+
return image_uris.retrieve(framework_name, region)
6742

6843

6944
class Rule(object):

src/sagemaker/deserializers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,8 @@ def deserialize(self, stream, content_type):
266266
list: A list of JSON serializable objects.
267267
"""
268268
try:
269-
lines = stream.read().rstrip().split("\n")
269+
body = stream.read().decode("utf-8")
270+
lines = body.rstrip().split("\n")
270271
return [json.loads(line) for line in lines]
271272
finally:
272273
stream.close()

src/sagemaker/estimator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@
2929
from sagemaker.debugger import DebuggerHookConfig
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import get_rule_container_image_uri
32-
from sagemaker.s3 import S3Uploader
32+
from sagemaker.s3 import S3Uploader, parse_s3_url
3333

3434
from sagemaker.fw_utils import (
3535
tar_and_upload_dir,
36-
parse_s3_url,
3736
UploadedCode,
3837
validate_source_dir,
3938
_region_supports_debugger,
@@ -1418,7 +1417,7 @@ class Framework(EstimatorBase):
14181417
such as training/deployment images and predictor instances.
14191418
"""
14201419

1421-
__framework_name__ = None
1420+
_framework_name = None
14221421

14231422
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
14241423
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
@@ -1816,7 +1815,7 @@ def train_image(self):
18161815
if self.image_uri:
18171816
return self.image_uri
18181817
return image_uris.retrieve(
1819-
self.__framework_name__,
1818+
self._framework_name,
18201819
self.sagemaker_session.boto_region_name,
18211820
instance_type=self.instance_type,
18221821
version=self.framework_version, # pylint: disable=no-member

0 commit comments

Comments
 (0)