File tree Expand file tree Collapse file tree 3 files changed +20
-6
lines changed
tests/unit/sagemaker/jumpstart Expand file tree Collapse file tree 3 files changed +20
-6
lines changed Original file line number Diff line number Diff line change 27
27
JUMPSTART_DEFAULT_REGION_NAME = boto3 .session .Session ().region_name
28
28
29
29
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
30
-
31
- PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version()
Original file line number Diff line number Diff line change 19
19
from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartVersionedModelId
20
20
21
21
22
+ class SageMakerSettings (object ):
23
+ """Static class for storing the SageMaker settings."""
24
+
25
+ _PARSED_SAGEMAKER_VERSION = ""
26
+
27
+ @staticmethod
28
+ def set_sagemaker_version (version : str ) -> None :
29
+ """Set SageMaker version."""
30
+ SageMakerSettings ._PARSED_SAGEMAKER_VERSION = version
31
+
32
+ @staticmethod
33
+ def get_sagemaker_version () -> str :
34
+ """Return SageMaker version."""
35
+ return SageMakerSettings ._PARSED_SAGEMAKER_VERSION
36
+
37
+
22
38
def get_jumpstart_launched_regions_message () -> str :
23
39
"""Returns formatted string indicating where JumpStart is launched."""
24
40
if len (constants .JUMPSTART_REGION_NAME_SET ) == 0 :
@@ -79,9 +95,9 @@ def get_sagemaker_version() -> str:
79
95
calls ``parse_sagemaker_version`` to retrieve the version and set
80
96
the constant.
81
97
"""
82
- if constants . PARSED_SAGEMAKER_VERSION == "" :
83
- constants . PARSED_SAGEMAKER_VERSION = parse_sagemaker_version ()
84
- return constants . PARSED_SAGEMAKER_VERSION
98
+ if SageMakerSettings . get_sagemaker_version () == "" :
99
+ SageMakerSettings . set_sagemaker_version ( parse_sagemaker_version () )
100
+ return SageMakerSettings . get_sagemaker_version ()
85
101
86
102
87
103
def parse_sagemaker_version () -> str :
Original file line number Diff line number Diff line change @@ -106,7 +106,7 @@ def test_parse_sagemaker_version():
106
106
107
107
108
108
@patch ("sagemaker.jumpstart.utils.parse_sagemaker_version" )
109
- @patch ("sagemaker.jumpstart.constants.PARSED_SAGEMAKER_VERSION " , "" )
109
+ @patch ("sagemaker.jumpstart.utils.SageMakerSettings._PARSED_SAGEMAKER_VERSION " , "" )
110
110
def test_get_sagemaker_version (patched_parse_sm_version : Mock ):
111
111
utils .get_sagemaker_version ()
112
112
utils .get_sagemaker_version ()
You can’t perform that action at this time.
0 commit comments