Skip to content

Commit b8e130c

Browse files
committed
change: improve getting sm version logic
1 parent 6604acc commit b8e130c

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@
2727
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
2828

2929
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
30+
31+
PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version()

src/sagemaker/jumpstart/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ def get_formatted_manifest(
7575
def get_sagemaker_version() -> str:
7676
"""Returns sagemaker library version.
7777
78+
If the sagemaker library version has not been set yet, this function
79+
calls ``parse_sagemaker_version`` to retrive the version.
80+
"""
81+
if constants.PARSED_SAGEMAKER_VERSION == "":
82+
constants.PARSED_SAGEMAKER_VERSION = parse_sagemaker_version()
83+
return constants.PARSED_SAGEMAKER_VERSION
84+
85+
86+
def parse_sagemaker_version() -> str:
87+
"""Returns sagemaker library version. This should only be called once.
88+
7889
Function reads ``__version__`` variable in ``sagemaker`` module.
7990
In order to maintain compatibility with the ``semantic_version``
8091
library, versions with fewer than 2, or more than 3, periods are rejected.

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14-
from mock.mock import patch
14+
from mock.mock import Mock, patch
1515
import pytest
1616
from sagemaker.jumpstart import utils
1717
from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET
@@ -73,33 +73,42 @@ def test_get_formatted_manifest():
7373
assert utils.get_formatted_manifest([]) == {}
7474

7575

76-
def test_get_sagemaker_version():
76+
def test_parse_sagemaker_version():
7777

7878
with patch("sagemaker.__version__", "1.2.3"):
79-
assert utils.get_sagemaker_version() == "1.2.3"
79+
assert utils.parse_sagemaker_version() == "1.2.3"
8080

8181
with patch("sagemaker.__version__", "1.2.3.3332j"):
82-
assert utils.get_sagemaker_version() == "1.2.3"
82+
assert utils.parse_sagemaker_version() == "1.2.3"
8383

8484
with patch("sagemaker.__version__", "1.2.3."):
85-
assert utils.get_sagemaker_version() == "1.2.3"
85+
assert utils.parse_sagemaker_version() == "1.2.3"
8686

8787
with pytest.raises(ValueError):
8888
with patch("sagemaker.__version__", "1.2.3dfsdfs"):
89-
utils.get_sagemaker_version()
89+
utils.parse_sagemaker_version()
9090

9191
with pytest.raises(RuntimeError):
9292
with patch("sagemaker.__version__", "1.2"):
93-
utils.get_sagemaker_version()
93+
utils.parse_sagemaker_version()
9494

9595
with pytest.raises(RuntimeError):
9696
with patch("sagemaker.__version__", "1"):
97-
utils.get_sagemaker_version()
97+
utils.parse_sagemaker_version()
9898

9999
with pytest.raises(RuntimeError):
100100
with patch("sagemaker.__version__", ""):
101-
utils.get_sagemaker_version()
101+
utils.parse_sagemaker_version()
102102

103103
with pytest.raises(RuntimeError):
104104
with patch("sagemaker.__version__", "1.2.3.4.5"):
105-
utils.get_sagemaker_version()
105+
utils.parse_sagemaker_version()
106+
107+
108+
@patch("sagemaker.jumpstart.utils.parse_sagemaker_version")
109+
@patch("sagemaker.jumpstart.constants.PARSED_SAGEMAKER_VERSION", "")
110+
def test_get_sagemaker_version(patched_parse_sm_version: Mock):
111+
utils.get_sagemaker_version()
112+
utils.get_sagemaker_version()
113+
utils.get_sagemaker_version()
114+
assert patched_parse_sm_version.called_only_once()

0 commit comments

Comments
 (0)