Skip to content

Commit 98134d2

Browse files
waytrue17Wei Chu
andauthored
feature: Introduce TensorBoard app class (#3810)
Co-authored-by: Wei Chu <[email protected]>
1 parent d54edb4 commit 98134d2

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains methods for starting up and accessing TensorBoard apps hosted on SageMaker"""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import logging
18+
import os
19+
import re
20+
21+
from typing import Optional
22+
from sagemaker.session import Session, NOTEBOOK_METADATA_FILE
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class TensorBoardApp(object):
28+
"""TensorBoardApp is a class for creating/accessing a TensorBoard app hosted on SageMaker."""
29+
30+
def __init__(self, region: Optional[str] = None):
31+
"""Initialize a TensorBoardApp object.
32+
33+
Args:
34+
region (str): The AWS Region, e.g. us-east-1. If not specified,
35+
one is created using the default AWS configuration chain.
36+
"""
37+
if region:
38+
self.region = region
39+
else:
40+
try:
41+
self.region = Session().boto_region_name
42+
except ValueError:
43+
raise ValueError(
44+
"Failed to get the Region information from the default config. Please either "
45+
"pass your Region manually as an input argument or set up the local AWS configuration."
46+
)
47+
48+
self._domain_id = None
49+
self._user_profile_name = None
50+
self._valid_domain_and_user = False
51+
self._get_domain_and_user()
52+
53+
def __str__(self):
54+
"""Return str(self)."""
55+
return f"TensorBoardApp(region={self.region})"
56+
57+
def __repr__(self):
58+
"""Return repr(self)."""
59+
return self.__str__()
60+
61+
def get_app_url(self, training_job_name: Optional[str] = None):
62+
"""Generates an unsigned URL to help access the TensorBoard application hosted in SageMaker.
63+
64+
For users that are already in SageMaker Studio, this method tries to get the domain id and the user
65+
profile from the Studio environment. If succeeded, the generated URL will direct to the TensorBoard
66+
application in SageMaker. Otherwise, it will direct to the TensorBoard landing page in the SageMaker
67+
console. For non-Studio users, the URL will direct to the TensorBoard landing page in the SageMaker
68+
console.
69+
70+
Args:
71+
training_job_name (str): Optional. The name of the training job to pre-load in TensorBoard.
72+
If nothing provided, the method still returns the TensorBoard application URL,
73+
but the application will not have any training jobs added for tracking. You can
74+
add training jobs later by using the SageMaker Data Manager UI.
75+
Default: ``None``
76+
77+
Returns:
78+
str: An unsigned URL for TensorBoard hosted on SageMaker.
79+
"""
80+
if self._valid_domain_and_user:
81+
url = "https://{}.studio.{}.sagemaker.aws/tensorboard/default".format(
82+
self._domain_id, self.region
83+
)
84+
if training_job_name is not None:
85+
self._validate_job_name(training_job_name)
86+
url += "/data/plugin/sagemaker_data_manager/add_folder_or_job?Redirect=True&Name={}".format(
87+
training_job_name
88+
)
89+
else:
90+
url += "/#sagemaker_data_manager"
91+
else:
92+
url = "https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/tensor-board-landing".format(
93+
region=self.region
94+
)
95+
if training_job_name is not None:
96+
self._validate_job_name(training_job_name)
97+
url += "/{}".format(training_job_name)
98+
99+
return url
100+
101+
def _get_domain_and_user(self):
102+
"""Get and validate studio domain id and user profile from NOTEBOOK_METADATA_FILE in studio environment.
103+
104+
Set _valid_domain_and_user to True if validation succeeded.
105+
"""
106+
if not os.path.isfile(NOTEBOOK_METADATA_FILE):
107+
return
108+
109+
with open(NOTEBOOK_METADATA_FILE, "rb") as f:
110+
metadata = json.loads(f.read())
111+
self._domain_id = metadata.get("DomainId")
112+
self._user_profile_name = metadata.get("UserProfileName")
113+
if self._validate_domain_id() is True and self._validate_user_profile_name() is True:
114+
self._valid_domain_and_user = True
115+
else:
116+
logger.warning(
117+
"NOTEBOOK_METADATA_FILE detected but failed to get valid domain and user from it."
118+
)
119+
120+
def _validate_job_name(self, job_name: str):
121+
"""Validate training job name format."""
122+
job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
123+
if not re.fullmatch(job_name_regex, job_name):
124+
raise ValueError(
125+
"Invalid job name. Job name must match regular expression {}".format(job_name_regex)
126+
)
127+
128+
def _validate_domain_id(self):
129+
"""Validate domain id format."""
130+
if self._domain_id is None or len(self._domain_id) > 63:
131+
return False
132+
return True
133+
134+
def _validate_user_profile_name(self):
135+
"""Validate user profile name format."""
136+
user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
137+
if self._user_profile_name is None or not re.fullmatch(
138+
user_profile_name_regex, self._user_profile_name
139+
):
140+
return False
141+
return True

tests/unit/test_tensorboard.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.interactive_apps.tensorboard import TensorBoardApp
16+
from unittest.mock import patch, mock_open, PropertyMock
17+
18+
import json
19+
import pytest
20+
21+
TEST_DOMAIN = "testdomain"
22+
TEST_USER_PROFILE = "testuser"
23+
TEST_REGION = "testregion"
24+
TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE})
25+
TEST_TRAINING_JOB = "testjob"
26+
27+
BASE_URL_STUDIO_FORMAT = "https://{}.studio.{}.sagemaker.aws/tensorboard/default"
28+
REDIRECT_STUDIO_FORMAT = (
29+
"/data/plugin/sagemaker_data_manager/add_folder_or_job?Redirect=True&Name={}"
30+
)
31+
BASE_URL_NON_STUDIO_FORMAT = (
32+
"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/tensor-board-landing"
33+
)
34+
REDIRECT_NON_STUDIO_FORMAT = "/{}"
35+
36+
37+
@patch("os.path.isfile")
38+
def test_tb_init_and_url_non_studio_user(mock_file_exists):
39+
"""
40+
Test TensorBoardApp for non Studio users.
41+
"""
42+
mock_file_exists.return_value = False
43+
tb_app = TensorBoardApp(TEST_REGION)
44+
assert tb_app.region == TEST_REGION
45+
assert tb_app._domain_id is None
46+
assert tb_app._user_profile_name is None
47+
assert tb_app._valid_domain_and_user is False
48+
49+
# test url without job redirect
50+
assert tb_app.get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION)
51+
52+
# test url with valid job redirect
53+
assert tb_app.get_app_url(TEST_TRAINING_JOB) == BASE_URL_NON_STUDIO_FORMAT.format(
54+
region=TEST_REGION
55+
) + REDIRECT_NON_STUDIO_FORMAT.format(TEST_TRAINING_JOB)
56+
57+
# test url with invalid job redirect
58+
with pytest.raises(ValueError):
59+
tb_app.get_app_url("invald_job_name!")
60+
61+
62+
@patch("os.path.isfile")
63+
def test_tb_init_and_url_studio_user_valid_medatada(mock_file_exists):
64+
"""
65+
Test TensorBoardApp for Studio user when the notebook metadata file provided by Studio is valid.
66+
"""
67+
mock_file_exists.return_value = True
68+
with patch("builtins.open", mock_open(read_data=TEST_NOTEBOOK_METADATA)):
69+
tb_app = TensorBoardApp(TEST_REGION)
70+
assert tb_app.region == TEST_REGION
71+
assert tb_app._domain_id == TEST_DOMAIN
72+
assert tb_app._user_profile_name == TEST_USER_PROFILE
73+
assert tb_app._valid_domain_and_user is True
74+
75+
# test url without job redirect
76+
assert (
77+
tb_app.get_app_url()
78+
== BASE_URL_STUDIO_FORMAT.format(TEST_DOMAIN, TEST_REGION) + "/#sagemaker_data_manager"
79+
)
80+
81+
# test url with valid job redirect
82+
assert tb_app.get_app_url(TEST_TRAINING_JOB) == BASE_URL_STUDIO_FORMAT.format(
83+
TEST_DOMAIN, TEST_REGION
84+
) + REDIRECT_STUDIO_FORMAT.format(TEST_TRAINING_JOB)
85+
86+
# test url with invalid job redirect
87+
with pytest.raises(ValueError):
88+
tb_app.get_app_url("invald_job_name!")
89+
90+
91+
@patch("os.path.isfile")
92+
def test_tb_init_and_url_studio_user_invalid_medatada(mock_file_exists):
93+
"""
94+
Test TensorBoardApp for Studio user when the notebook metadata file provided by Studio is invalid.
95+
"""
96+
mock_file_exists.return_value = True
97+
98+
# test file does not contain domain and user profle
99+
with patch("builtins.open", mock_open(read_data=json.dumps({"Fake": "Fake"}))):
100+
assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format(
101+
region=TEST_REGION
102+
)
103+
104+
# test invalid user profile name
105+
with patch(
106+
"builtins.open",
107+
mock_open(read_data=json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": "u" * 64})),
108+
):
109+
assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format(
110+
region=TEST_REGION
111+
)
112+
113+
# test invalid domain id
114+
with patch(
115+
"builtins.open",
116+
mock_open(
117+
read_data=json.dumps({"DomainId": "d" * 64, "UserProfileName": TEST_USER_PROFILE})
118+
),
119+
):
120+
assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format(
121+
region=TEST_REGION
122+
)
123+
124+
125+
def test_tb_init_with_default_region():
126+
"""
127+
Test TensorBoardApp init when user does not provide region.
128+
"""
129+
# happy case
130+
with patch("sagemaker.Session.boto_region_name", new_callable=PropertyMock) as region_mock:
131+
region_mock.return_value = TEST_REGION
132+
tb_app = TensorBoardApp()
133+
assert tb_app.region == TEST_REGION
134+
135+
# no default region configured
136+
with patch("sagemaker.Session.boto_region_name", new_callable=PropertyMock) as region_mock:
137+
region_mock.side_effect = [ValueError()]
138+
with pytest.raises(ValueError):
139+
tb_app = TensorBoardApp()

0 commit comments

Comments
 (0)