Skip to content

Commit c115b7a

Browse files
bmouryakrMourya Baddam
authored andcommitted
feature: Auto-capture image URI (aws#861)
* pathways:add default image uri config * pathways: Add unit test for studio image uri * pathways: Add unit test for studio image uri * change: Move pathways image config to job --------- Co-authored-by: Mourya Baddam <[email protected]>
1 parent b794bb3 commit c115b7a

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

src/sagemaker/remote_function/job.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
import shutil
19+
import sys
1920
from typing import Dict, List, Tuple
2021

2122
from sagemaker.config import config_schema
@@ -120,8 +121,12 @@ def __init__(
120121
environment_variables, config_schema.ENVIRONMENT_VARIABLES
121122
)
122123

123-
# TODO: provide default image uri if not set
124-
self.image_uri = self._get_from_config(image_uri, config_schema.IMAGE_URI, required=True)
124+
_image_uri = self._get_from_config(image_uri, config_schema.IMAGE_URI)
125+
if _image_uri:
126+
self.image_uri = _image_uri
127+
else:
128+
self.image_uri = self._get_default_image(self.sagemaker_session)
129+
125130
self.dependencies = self._get_from_config(dependencies, config_schema.DEPENDENCIES)
126131

127132
self.instance_type = self._get_from_config(
@@ -171,7 +176,7 @@ def _get_from_config(
171176
default=None,
172177
required=False,
173178
):
174-
"""Get default value from sagemaker config."""
179+
"""Get value from sagemaker config."""
175180
if override_value:
176181
return override_value
177182
config_value = self.sagemaker_config.get_config_value(
@@ -189,6 +194,33 @@ def _get_from_config(
189194
raise ValueError(f"{sagemaker_config_key} is a required parameter!")
190195
return default
191196

197+
@staticmethod
198+
def _get_default_image(session):
199+
"""Return Studio notebook image, if in Studio env. Else, base python"""
200+
201+
if (
202+
"SAGEMAKER_INTERNAL_IMAGE_URI" in os.environ
203+
and os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"]
204+
):
205+
return os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"]
206+
207+
py_major_version = sys.version_info[0]
208+
py_minor_version = sys.version_info[1]
209+
210+
# TODO:Add Support for 3.8
211+
if py_major_version != 3 or py_minor_version != 10:
212+
raise ValueError("Use supported Python version or provide compatible ImageUri.")
213+
214+
# TODO: Support only supported by Studio
215+
region = session.boto_region_name
216+
217+
# TODO: Remove beta image and use public base python
218+
beta_image = (
219+
f"581474259216.dkr.ecr.{region}.amazonaws.com/"
220+
f"sagemaker-pathways-beta:basepy_3_10_latest"
221+
)
222+
return beta_image
223+
192224

193225
class _Job:
194226
"""Helper class that interacts with the SageMaker training service."""

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import sys
1617

1718
import pytest
1819
from mock import patch, Mock, ANY
@@ -123,8 +124,35 @@ def test_sagemaker_config_job_settings_missing_image_uri(get_execution_role, ses
123124
"SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", os.path.join(DATA_DIR, "remote_function")
124125
)
125126

126-
with pytest.raises(ValueError, match="ImageUri is a required parameter!"):
127-
_JobSettings()
127+
py_major_version = sys.version_info[0]
128+
py_minor_version = sys.version_info[1]
129+
if py_major_version != 3 or py_minor_version != 10:
130+
with pytest.raises(
131+
ValueError, match="Use supported Python version or provide compatible ImageUri."
132+
):
133+
_JobSettings()
134+
else:
135+
job_settings = _JobSettings()
136+
assert (
137+
job_settings.image_uri
138+
== f"581474259216.dkr.ecr.{TEST_REGION}.amazonaws.com/sagemaker-pathways-beta:basepy_3_10_latest"
139+
)
140+
141+
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
142+
143+
144+
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
145+
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
146+
def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, session, monkeypatch):
147+
monkeypatch.setenv(
148+
"SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", os.path.join(DATA_DIR, "remote_function")
149+
)
150+
monkeypatch.setenv("SAGEMAKER_INTERNAL_IMAGE_URI", "studio_image_uri")
151+
152+
job_settings = _JobSettings()
153+
assert job_settings.image_uri == "studio_image_uri"
154+
155+
monkeypatch.delenv("SAGEMAKER_INTERNAL_IMAGE_URI")
128156
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
129157

130158

0 commit comments

Comments
 (0)