Skip to content

Commit 6dd109c

Browse files
Merge branch 'master' into fix-p2
2 parents 7fac7e2 + 98134d2 commit 6dd109c

File tree

6 files changed

+338
-10
lines changed

6 files changed

+338
-10
lines changed

src/sagemaker/image_uri_config/data-wrangler.json

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,35 @@
77
"ap-east-1": "707077482487",
88
"ap-northeast-1": "649008135260",
99
"ap-northeast-2": "131546521161",
10+
"ap-northeast-3": "913387583493",
11+
"ap-south-1": "089933028263",
12+
"ap-southeast-1": "119527597002",
13+
"ap-southeast-2": "422173101802",
14+
"ca-central-1": "557239378090",
15+
"eu-central-1": "024640144536",
16+
"eu-north-1": "054986407534",
17+
"eu-south-1": "488287956546",
18+
"eu-west-1": "245179582081",
19+
"eu-west-2": "894491911112",
20+
"eu-west-3": "807237891255",
21+
"me-south-1": "376037874950",
22+
"sa-east-1": "424196993095",
23+
"us-east-1": "663277389841",
24+
"us-east-2": "415577184552",
25+
"us-west-1": "926135532090",
26+
"us-west-2": "174368400705",
27+
"cn-north-1": "245909111842",
28+
"cn-northwest-1": "249157047649"
29+
},
30+
"repository": "sagemaker-data-wrangler-container"
31+
},
32+
"2.x": {
33+
"registries": {
34+
"af-south-1": "143210264188",
35+
"ap-east-1": "707077482487",
36+
"ap-northeast-1": "649008135260",
37+
"ap-northeast-2": "131546521161",
38+
"ap-northeast-3": "913387583493",
1039
"ap-south-1": "089933028263",
1140
"ap-southeast-1": "119527597002",
1241
"ap-southeast-2": "422173101802",

src/sagemaker/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
SKLEARN_FRAMEWORK = "sklearn"
3737
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
3838
INFERENCE_GRAVITON = "inference_graviton"
39+
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
3940

4041

4142
@override_pipeline_parameter_var
@@ -461,6 +462,9 @@ def _validate_version_and_set_if_needed(version, config, framework):
461462

462463
return available_versions[0]
463464

465+
if version is None and framework in [DATA_WRANGLER_FRAMEWORK]:
466+
version = _get_latest_versions(available_versions)
467+
464468
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
465469
return version
466470

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains methods for starting up and accessing TensorBoard apps hosted on SageMaker"""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import logging
18+
import os
19+
import re
20+
21+
from typing import Optional
22+
from sagemaker.session import Session, NOTEBOOK_METADATA_FILE
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class TensorBoardApp(object):
28+
"""TensorBoardApp is a class for creating/accessing a TensorBoard app hosted on SageMaker."""
29+
30+
def __init__(self, region: Optional[str] = None):
31+
"""Initialize a TensorBoardApp object.
32+
33+
Args:
34+
region (str): The AWS Region, e.g. us-east-1. If not specified,
35+
one is created using the default AWS configuration chain.
36+
"""
37+
if region:
38+
self.region = region
39+
else:
40+
try:
41+
self.region = Session().boto_region_name
42+
except ValueError:
43+
raise ValueError(
44+
"Failed to get the Region information from the default config. Please either "
45+
"pass your Region manually as an input argument or set up the local AWS configuration."
46+
)
47+
48+
self._domain_id = None
49+
self._user_profile_name = None
50+
self._valid_domain_and_user = False
51+
self._get_domain_and_user()
52+
53+
def __str__(self):
54+
"""Return str(self)."""
55+
return f"TensorBoardApp(region={self.region})"
56+
57+
def __repr__(self):
58+
"""Return repr(self)."""
59+
return self.__str__()
60+
61+
def get_app_url(self, training_job_name: Optional[str] = None):
62+
"""Generates an unsigned URL to help access the TensorBoard application hosted in SageMaker.
63+
64+
For users that are already in SageMaker Studio, this method tries to get the domain id and the user
65+
profile from the Studio environment. If succeeded, the generated URL will direct to the TensorBoard
66+
application in SageMaker. Otherwise, it will direct to the TensorBoard landing page in the SageMaker
67+
console. For non-Studio users, the URL will direct to the TensorBoard landing page in the SageMaker
68+
console.
69+
70+
Args:
71+
training_job_name (str): Optional. The name of the training job to pre-load in TensorBoard.
72+
If nothing provided, the method still returns the TensorBoard application URL,
73+
but the application will not have any training jobs added for tracking. You can
74+
add training jobs later by using the SageMaker Data Manager UI.
75+
Default: ``None``
76+
77+
Returns:
78+
str: An unsigned URL for TensorBoard hosted on SageMaker.
79+
"""
80+
if self._valid_domain_and_user:
81+
url = "https://{}.studio.{}.sagemaker.aws/tensorboard/default".format(
82+
self._domain_id, self.region
83+
)
84+
if training_job_name is not None:
85+
self._validate_job_name(training_job_name)
86+
url += "/data/plugin/sagemaker_data_manager/add_folder_or_job?Redirect=True&Name={}".format(
87+
training_job_name
88+
)
89+
else:
90+
url += "/#sagemaker_data_manager"
91+
else:
92+
url = "https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/tensor-board-landing".format(
93+
region=self.region
94+
)
95+
if training_job_name is not None:
96+
self._validate_job_name(training_job_name)
97+
url += "/{}".format(training_job_name)
98+
99+
return url
100+
101+
def _get_domain_and_user(self):
102+
"""Get and validate studio domain id and user profile from NOTEBOOK_METADATA_FILE in studio environment.
103+
104+
Set _valid_domain_and_user to True if validation succeeded.
105+
"""
106+
if not os.path.isfile(NOTEBOOK_METADATA_FILE):
107+
return
108+
109+
with open(NOTEBOOK_METADATA_FILE, "rb") as f:
110+
metadata = json.loads(f.read())
111+
self._domain_id = metadata.get("DomainId")
112+
self._user_profile_name = metadata.get("UserProfileName")
113+
if self._validate_domain_id() is True and self._validate_user_profile_name() is True:
114+
self._valid_domain_and_user = True
115+
else:
116+
logger.warning(
117+
"NOTEBOOK_METADATA_FILE detected but failed to get valid domain and user from it."
118+
)
119+
120+
def _validate_job_name(self, job_name: str):
121+
"""Validate training job name format."""
122+
job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
123+
if not re.fullmatch(job_name_regex, job_name):
124+
raise ValueError(
125+
"Invalid job name. Job name must match regular expression {}".format(job_name_regex)
126+
)
127+
128+
def _validate_domain_id(self):
129+
"""Validate domain id format."""
130+
if self._domain_id is None or len(self._domain_id) > 63:
131+
return False
132+
return True
133+
134+
def _validate_user_profile_name(self):
135+
"""Validate user profile name format."""
136+
user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
137+
if self._user_profile_name is None or not re.fullmatch(
138+
user_profile_name_regex, self._user_profile_name
139+
):
140+
return False
141+
return True

tests/unit/sagemaker/image_uris/test_data_wrangler.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"ap-east-1": "707077482487",
2121
"ap-northeast-1": "649008135260",
2222
"ap-northeast-2": "131546521161",
23+
"ap-northeast-3": "913387583493",
2324
"ap-south-1": "089933028263",
2425
"ap-southeast-1": "119527597002",
2526
"ap-southeast-2": "422173101802",
@@ -39,15 +40,29 @@
3940
"cn-north-1": "245909111842",
4041
"cn-northwest-1": "249157047649",
4142
}
43+
VERSIONS = ["1.x", "2.x"]
4244

4345

4446
def test_data_wrangler_ecr_uri():
    """Check the retrieved Data Wrangler image URI for every version and region.

    Each entry of VERSIONS is requested explicitly, so every listed version is
    exercised against every account in DATA_WRANGLER_ACCOUNTS.
    """
    for version in VERSIONS:
        for region, account in DATA_WRANGLER_ACCOUNTS.items():
            # Pass the loop's version through to both sides; a hard-coded "1.x"
            # would leave the other versions untested.
            actual_uri = image_uris.retrieve("data-wrangler", region=region, version=version)
            expected_uri = expected_uris.algo_uri(
                "sagemaker-data-wrangler-container",
                account,
                region,
                version=version,
            )
            assert expected_uri == actual_uri
57+
58+
59+
def test_data_wrangler_ecr_uri_none():
    """Check that omitting the version yields the URI of the last entry in VERSIONS."""
    target_region = "us-west-2"

    retrieved = image_uris.retrieve("data-wrangler", region=target_region)

    account_id = DATA_WRANGLER_ACCOUNTS[target_region]
    newest_version = VERSIONS[-1]
    wanted = expected_uris.algo_uri(
        "sagemaker-data-wrangler-container", account_id, target_region, version=newest_version
    )
    assert wanted == retrieved

tests/unit/sagemaker/wrangler/test_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
REGION = "us-west-2"
2424
DATA_WRANGLER_RECIPE_SOURCE = "s3://data_wrangler_flows/flow-26-18-43-16-0b48ac2e.flow"
2525
DATA_WRANGLER_CONTAINER_URI = (
26-
"174368400705.dkr.ecr.us-west-2.amazonaws.com/sagemaker-data-wrangler-container:1.x"
26+
"174368400705.dkr.ecr.us-west-2.amazonaws.com/sagemaker-data-wrangler-container:2.x"
2727
)
2828
MOCK_S3_URI = "s3://mock_data/mock.csv"
2929

tests/unit/test_tensorboard.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.interactive_apps.tensorboard import TensorBoardApp
16+
from unittest.mock import patch, mock_open, PropertyMock
17+
18+
import json
19+
import pytest
20+
21+
# Identity values used to build the fake Studio notebook metadata.
TEST_DOMAIN = "testdomain"
TEST_USER_PROFILE = "testuser"
TEST_REGION = "testregion"
TEST_TRAINING_JOB = "testjob"

# Serialized metadata payload returned by the mocked ``open`` in the Studio tests.
TEST_NOTEBOOK_METADATA = json.dumps(dict(DomainId=TEST_DOMAIN, UserProfileName=TEST_USER_PROFILE))

# Expected URL templates when valid Studio metadata is available.
BASE_URL_STUDIO_FORMAT = "https://{}.studio.{}.sagemaker.aws/tensorboard/default"
REDIRECT_STUDIO_FORMAT = (
    "/data/plugin/sagemaker_data_manager/add_folder_or_job?Redirect=True&Name={}"
)

# Expected URL templates for the SageMaker console fallback.
BASE_URL_NON_STUDIO_FORMAT = (
    "https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/tensor-board-landing"
)
REDIRECT_NON_STUDIO_FORMAT = "/{}"
35+
36+
37+
@patch("os.path.isfile")
def test_tb_init_and_url_non_studio_user(mock_file_exists):
    """
    Check TensorBoardApp state and URLs when the notebook metadata file is reported missing.
    """
    mock_file_exists.return_value = False
    console_url = BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION)

    app = TensorBoardApp(TEST_REGION)

    # with no metadata file, the Studio attributes keep their initial values
    assert app.region == TEST_REGION
    assert app._domain_id is None
    assert app._user_profile_name is None
    assert app._valid_domain_and_user is False

    # no job name: plain console landing page
    assert app.get_app_url() == console_url

    # valid job name: appended to the landing page URL
    expected_job_url = console_url + REDIRECT_NON_STUDIO_FORMAT.format(TEST_TRAINING_JOB)
    assert app.get_app_url(TEST_TRAINING_JOB) == expected_job_url

    # invalid job name: rejected with ValueError
    with pytest.raises(ValueError):
        app.get_app_url("invald_job_name!")
60+
61+
62+
@patch("os.path.isfile")
def test_tb_init_and_url_studio_user_valid_medatada(mock_file_exists):
    """
    Check TensorBoardApp state and URLs when the mocked notebook metadata file is valid.
    """
    mock_file_exists.return_value = True
    studio_url = BASE_URL_STUDIO_FORMAT.format(TEST_DOMAIN, TEST_REGION)

    with patch("builtins.open", mock_open(read_data=TEST_NOTEBOOK_METADATA)):
        app = TensorBoardApp(TEST_REGION)

    # domain and user profile are picked up from the metadata payload
    assert app.region == TEST_REGION
    assert app._domain_id == TEST_DOMAIN
    assert app._user_profile_name == TEST_USER_PROFILE
    assert app._valid_domain_and_user is True

    # no job name: Studio URL pointing at the data manager
    assert app.get_app_url() == studio_url + "/#sagemaker_data_manager"

    # valid job name: Studio URL with the add-job redirect
    expected_job_url = studio_url + REDIRECT_STUDIO_FORMAT.format(TEST_TRAINING_JOB)
    assert app.get_app_url(TEST_TRAINING_JOB) == expected_job_url

    # invalid job name: rejected with ValueError
    with pytest.raises(ValueError):
        app.get_app_url("invald_job_name!")
89+
90+
91+
@patch("os.path.isfile")
def test_tb_init_and_url_studio_user_invalid_medatada(mock_file_exists):
    """
    Check that invalid notebook metadata makes get_app_url fall back to the console URL.
    """
    mock_file_exists.return_value = True
    fallback_url = BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION)

    invalid_payloads = [
        # neither domain id nor user profile name present
        {"Fake": "Fake"},
        # user profile name one character too long
        {"DomainId": TEST_DOMAIN, "UserProfileName": "u" * 64},
        # domain id one character too long
        {"DomainId": "d" * 64, "UserProfileName": TEST_USER_PROFILE},
    ]

    for payload in invalid_payloads:
        with patch("builtins.open", mock_open(read_data=json.dumps(payload))):
            assert TensorBoardApp(TEST_REGION).get_app_url() == fallback_url
123+
124+
125+
def test_tb_init_with_default_region():
    """
    Check TensorBoardApp construction when the caller passes no region.
    """
    region_property = "sagemaker.Session.boto_region_name"

    # the session reports a region: it becomes the app's region
    with patch(region_property, new_callable=PropertyMock) as region_mock:
        region_mock.return_value = TEST_REGION
        assert TensorBoardApp().region == TEST_REGION

    # the session raises ValueError: construction raises ValueError
    with patch(region_property, new_callable=PropertyMock) as region_mock:
        region_mock.side_effect = [ValueError()]
        with pytest.raises(ValueError):
            TensorBoardApp()

0 commit comments

Comments
 (0)