Skip to content

Commit 3bde038

Browse files
committed
fix: sagemaker version for jumpstart
1 parent e94a87b commit 3bde038

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,3 @@
2727
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
2828

2929
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
30-
31-
PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version()

src/sagemaker/jumpstart/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
2020

2121

22+
class SageMakerSettings(object):
23+
"""Static class for storing the SageMaker settings."""
24+
25+
_PARSED_SAGEMAKER_VERSION = ""
26+
27+
@staticmethod
28+
def set_sagemaker_version(version: str) -> None:
29+
"""Set SageMaker version."""
30+
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version
31+
32+
@staticmethod
33+
def get_sagemaker_version() -> str:
34+
"""Return SageMaker version."""
35+
return SageMakerSettings._PARSED_SAGEMAKER_VERSION
36+
37+
2238
def get_jumpstart_launched_regions_message() -> str:
2339
"""Returns formatted string indicating where JumpStart is launched."""
2440
if len(constants.JUMPSTART_REGION_NAME_SET) == 0:
@@ -79,9 +95,9 @@ def get_sagemaker_version() -> str:
7995
calls ``parse_sagemaker_version`` to retrieve the version and set
8096
the constant.
8197
"""
82-
if constants.PARSED_SAGEMAKER_VERSION == "":
83-
constants.PARSED_SAGEMAKER_VERSION = parse_sagemaker_version()
84-
return constants.PARSED_SAGEMAKER_VERSION
98+
if SageMakerSettings.get_sagemaker_version() == "":
99+
SageMakerSettings.set_sagemaker_version(parse_sagemaker_version())
100+
return SageMakerSettings.get_sagemaker_version()
85101

86102

87103
def parse_sagemaker_version() -> str:

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_parse_sagemaker_version():
106106

107107

108108
@patch("sagemaker.jumpstart.utils.parse_sagemaker_version")
109-
@patch("sagemaker.jumpstart.constants.PARSED_SAGEMAKER_VERSION", "")
109+
@patch("sagemaker.jumpstart.utils.SageMakerSettings._PARSED_SAGEMAKER_VERSION", "")
110110
def test_get_sagemaker_version(patched_parse_sm_version: Mock):
111111
utils.get_sagemaker_version()
112112
utils.get_sagemaker_version()

0 commit comments

Comments
 (0)