Skip to content

Commit 4580c07

Browse files
metrizableDan Choi
authored andcommitted
feature: ensure studio project id tag is added if available (aws#518)
1 parent 2412204 commit 4580c07

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

src/sagemaker/_studio.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Provides internal tooling for studio environments."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import logging
18+
19+
from pathlib import Path
20+
21+
STUDIO_PROJECT_CONFIG = ".sagemaker-code-config"
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def _append_project_tags(working_dir=None, tags=None):
27+
"""Appends the project tag to the list of tags, if it exists.
28+
29+
Args:
30+
working_dir: the working directory to start looking.
31+
tags: the list of tags to append to.
32+
33+
Returns:
34+
A possibly extended list of tags that includes the project id.
35+
"""
36+
path = _find_config(working_dir)
37+
if path is None:
38+
return tags
39+
40+
config = _load_config(path)
41+
if config is None:
42+
return tags
43+
44+
additional_tags = _parse_tags(config)
45+
if additional_tags is None:
46+
return tags
47+
48+
all_tags = tags or []
49+
all_tags.extend(additional_tags)
50+
51+
return all_tags
52+
53+
54+
def _find_config(working_dir=None):
55+
"""Gets project config on SageMaker Studio platforms, if it exists.
56+
57+
Args:
58+
working_dir: the working directory to start looking.
59+
60+
Returns:
61+
The project config path, if it exists. Otherwise None.
62+
"""
63+
try:
64+
wd = Path(working_dir) if working_dir else Path.cwd()
65+
66+
path = None
67+
while path is None and not wd.match("/"):
68+
candidate = wd / STUDIO_PROJECT_CONFIG
69+
if Path.exists(candidate):
70+
path = candidate
71+
wd = wd.parent
72+
73+
return path
74+
except Exception as e: # pylint: disable=W0703
75+
logger.debug("Could not find the studio project config. %s", e)
76+
77+
78+
def _load_config(path):
79+
"""Parse out the projectId attribute if it exists at path.
80+
81+
Args:
82+
path: path to project config
83+
84+
Returns:
85+
Project config Json, or None if it does not exist.
86+
"""
87+
try:
88+
with open(path, "r") as f:
89+
content = f.read().strip()
90+
config = json.loads(content)
91+
92+
return config
93+
except Exception as e: # pylint: disable=W0703
94+
logger.debug("Could not load project config. %s", e)
95+
96+
97+
def _parse_tags(config):
98+
"""Parse out appropriate attributes and formats as tags.
99+
100+
Args:
101+
config: project config dict
102+
103+
Returns:
104+
List of tags
105+
"""
106+
try:
107+
return [
108+
{"Key": "sagemaker:project-id", "Value": config["sagemakerProjectId"]},
109+
{"Key": "sagemaker:project-name", "Value": config["sagemakerProjectName"]},
110+
]
111+
except Exception as e: # pylint: disable=W0703
112+
logger.debug("Could not parse project config. %s", e)

src/sagemaker/workflow/pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from botocore.exceptions import ClientError
2525

26+
from sagemaker._studio import _append_project_tags
2627
from sagemaker.session import Session
2728
from sagemaker.workflow.entities import (
2829
Entity,
@@ -90,6 +91,8 @@ def create(
9091
Returns:
9192
A response dict from the service.
9293
"""
94+
tags = _append_project_tags(tags)
95+
9396
kwargs = self._create_args(role_arn, description)
9497
update_args(
9598
kwargs,

tests/unit/sagemaker/test_studio.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# language governing permissions and limitations under the License.
14+
from __future__ import absolute_import
15+
16+
from sagemaker._studio import (
17+
_append_project_tags,
18+
_find_config,
19+
_load_config,
20+
_parse_tags,
21+
)
22+
23+
24+
def test_find_config(tmpdir):
25+
path = tmpdir.join(".sagemaker-code-config")
26+
path.write('{"sagemakerProjectId": "proj-1234"}')
27+
working_dir = tmpdir.mkdir("sub")
28+
29+
found_path = _find_config(working_dir)
30+
assert found_path == path
31+
32+
33+
def test_find_config_missing(tmpdir):
34+
working_dir = tmpdir.mkdir("sub")
35+
36+
found_path = _find_config(working_dir)
37+
assert found_path is None
38+
39+
40+
def test_load_config(tmpdir):
41+
path = tmpdir.join(".sagemaker-code-config")
42+
path.write('{"sagemakerProjectId": "proj-1234"}')
43+
44+
config = _load_config(path)
45+
assert isinstance(config, dict)
46+
47+
48+
def test_load_config_malformed(tmpdir):
49+
path = tmpdir.join(".sagemaker-code-config")
50+
path.write('{"proj')
51+
52+
config = _load_config(path)
53+
assert config is None
54+
55+
56+
def test_parse_tags():
57+
tags = _parse_tags(
58+
{
59+
"sagemakerProjectId": "proj-1234",
60+
"sagemakerProjectName": "proj-name",
61+
"foo": "abc",
62+
}
63+
)
64+
assert tags == [
65+
{"Key": "sagemaker:project-id", "Value": "proj-1234"},
66+
{"Key": "sagemaker:project-name", "Value": "proj-name"},
67+
]
68+
69+
70+
def test_parse_tags_missing():
71+
tags = _parse_tags(
72+
{
73+
"sagemakerProjectId": "proj-1234",
74+
"foo": "abc",
75+
}
76+
)
77+
assert tags is None
78+
79+
80+
def test_append_project_tags(tmpdir):
81+
config = tmpdir.join(".sagemaker-code-config")
82+
config.write('{"sagemakerProjectId": "proj-1234", "sagemakerProjectName": "proj-name"}')
83+
working_dir = tmpdir.mkdir("sub")
84+
85+
tags = _append_project_tags(working_dir, None)
86+
assert tags == [
87+
{"Key": "sagemaker:project-id", "Value": "proj-1234"},
88+
{"Key": "sagemaker:project-name", "Value": "proj-name"},
89+
]
90+
91+
tags = _append_project_tags(working_dir, [{"Key": "a", "Value": "b"}])
92+
assert tags == [
93+
{"Key": "a", "Value": "b"},
94+
{"Key": "sagemaker:project-id", "Value": "proj-1234"},
95+
{"Key": "sagemaker:project-name", "Value": "proj-name"},
96+
]

0 commit comments

Comments
 (0)