Skip to content

Commit 1f3754d

Browse files
authored
fix: jumpstart cache using sagemaker session s3 client (#4051)
1 parent 39943a8 commit 1f3754d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+814
-166
lines changed

src/sagemaker/accept_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.session import Session
1820

1921

2022
def retrieve_options(
@@ -23,6 +25,7 @@ def retrieve_options(
2325
model_version: Optional[str] = None,
2426
tolerate_vulnerable_model: bool = False,
2527
tolerate_deprecated_model: bool = False,
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2629
) -> List[str]:
2730
"""Retrieves the supported accept types for the model matching the given arguments.
2831
@@ -40,6 +43,10 @@ def retrieve_options(
4043
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
4144
(exception not raised). False if these models should raise an exception.
4245
(Default: False).
46+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
47+
object, used for SageMaker interactions. If not
48+
specified, one is created using the default AWS configuration
49+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
4350
Returns:
4451
list: The supported accept types to use for the model.
4552
@@ -57,6 +64,7 @@ def retrieve_options(
5764
region,
5865
tolerate_vulnerable_model,
5966
tolerate_deprecated_model,
67+
sagemaker_session=sagemaker_session,
6068
)
6169

6270

@@ -66,6 +74,7 @@ def retrieve_default(
6674
model_version: Optional[str] = None,
6775
tolerate_vulnerable_model: bool = False,
6876
tolerate_deprecated_model: bool = False,
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6978
) -> str:
7079
"""Retrieves the default accept type for the model matching the given arguments.
7180
@@ -83,6 +92,10 @@ def retrieve_default(
8392
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
8493
(exception not raised). False if these models should raise an exception.
8594
(Default: False).
95+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
96+
object, used for SageMaker interactions. If not
97+
specified, one is created using the default AWS configuration
98+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
8699
Returns:
87100
str: The default accept type to use for the model.
88101
@@ -100,4 +113,5 @@ def retrieve_default(
100113
region,
101114
tolerate_vulnerable_model,
102115
tolerate_deprecated_model,
116+
sagemaker_session=sagemaker_session,
103117
)

src/sagemaker/content_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.session import Session
1820

1921

2022
def retrieve_options(
@@ -23,6 +25,7 @@ def retrieve_options(
2325
model_version: Optional[str] = None,
2426
tolerate_vulnerable_model: bool = False,
2527
tolerate_deprecated_model: bool = False,
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2629
) -> List[str]:
2730
"""Retrieves the supported content types for the model matching the given arguments.
2831
@@ -40,6 +43,10 @@ def retrieve_options(
4043
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
4144
(exception not raised). False if these models should raise an exception.
4245
(Default: False).
46+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
47+
object, used for SageMaker interactions. If not
48+
specified, one is created using the default AWS configuration
49+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
4350
Returns:
4451
list: The supported content types to use for the model.
4552
@@ -57,6 +64,7 @@ def retrieve_options(
5764
region,
5865
tolerate_vulnerable_model,
5966
tolerate_deprecated_model,
67+
sagemaker_session=sagemaker_session,
6068
)
6169

6270

@@ -66,6 +74,7 @@ def retrieve_default(
6674
model_version: Optional[str] = None,
6775
tolerate_vulnerable_model: bool = False,
6876
tolerate_deprecated_model: bool = False,
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6978
) -> str:
7079
"""Retrieves the default content type for the model matching the given arguments.
7180
@@ -83,6 +92,10 @@ def retrieve_default(
8392
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
8493
(exception not raised). False if these models should raise an exception.
8594
(Default: False).
95+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
96+
object, used for SageMaker interactions. If not
97+
specified, one is created using the default AWS configuration
98+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
8699
Returns:
87100
str: The default content type to use for the model.
88101
@@ -100,6 +113,7 @@ def retrieve_default(
100113
region,
101114
tolerate_vulnerable_model,
102115
tolerate_deprecated_model,
116+
sagemaker_session=sagemaker_session,
103117
)
104118

105119

src/sagemaker/deserializers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
)
3434

3535
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
36+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
37+
from sagemaker.session import Session
3638

3739

3840
def retrieve_options(
@@ -41,6 +43,7 @@ def retrieve_options(
4143
model_version: Optional[str] = None,
4244
tolerate_vulnerable_model: bool = False,
4345
tolerate_deprecated_model: bool = False,
46+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4447
) -> List[BaseDeserializer]:
4548
"""Retrieves the supported deserializers for the model matching the given arguments.
4649
@@ -58,6 +61,10 @@ def retrieve_options(
5861
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
5962
(exception not raised). False if these models should raise an exception.
6063
(Default: False).
64+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
65+
object, used for SageMaker interactions. If not
66+
specified, one is created using the default AWS configuration
67+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6168
Returns:
6269
List[BaseDeserializer]: The supported deserializers to use for the model.
6370
@@ -76,6 +83,7 @@ def retrieve_options(
7683
region,
7784
tolerate_vulnerable_model,
7885
tolerate_deprecated_model,
86+
sagemaker_session=sagemaker_session,
7987
)
8088

8189

@@ -85,6 +93,7 @@ def retrieve_default(
8593
model_version: Optional[str] = None,
8694
tolerate_vulnerable_model: bool = False,
8795
tolerate_deprecated_model: bool = False,
96+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8897
) -> BaseDeserializer:
8998
"""Retrieves the default deserializer for the model matching the given arguments.
9099
@@ -102,6 +111,10 @@ def retrieve_default(
102111
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
103112
(exception not raised). False if these models should raise an exception.
104113
(Default: False).
114+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
115+
object, used for SageMaker interactions. If not
116+
specified, one is created using the default AWS configuration
117+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
105118
Returns:
106119
BaseDeserializer: The default deserializer to use for the model.
107120
@@ -120,4 +133,5 @@ def retrieve_default(
120133
region,
121134
tolerate_vulnerable_model,
122135
tolerate_deprecated_model,
136+
sagemaker_session=sagemaker_session,
123137
)

src/sagemaker/environment_variables.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.session import Session
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -30,6 +32,7 @@ def retrieve_default(
3032
tolerate_vulnerable_model: bool = False,
3133
tolerate_deprecated_model: bool = False,
3234
include_aws_sdk_env_vars: bool = True,
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3336
) -> Dict[str, str]:
3437
"""Retrieves the default container environment variables for the model matching the arguments.
3538
@@ -51,6 +54,10 @@ def retrieve_default(
5154
should be included. The `Model` class of the SageMaker Python SDK inserts environment
5255
variables that would be required when making the low-level AWS API call.
5356
(Default: True).
57+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
58+
object, used for SageMaker interactions. If not
59+
specified, one is created using the default AWS configuration
60+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5461
Returns:
5562
dict: The variables to use for the model.
5663
@@ -70,4 +77,5 @@ def retrieve_default(
7077
tolerate_vulnerable_model,
7178
tolerate_deprecated_model,
7279
include_aws_sdk_env_vars,
80+
sagemaker_session=sagemaker_session,
7381
)

src/sagemaker/hyperparameters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2223
from sagemaker.jumpstart.enums import HyperparameterValidationMode
2324
from sagemaker.jumpstart.validators import validate_hyperparameters
25+
from sagemaker.session import Session
2426

2527
logger = logging.getLogger(__name__)
2628

@@ -32,6 +34,7 @@ def retrieve_default(
3234
include_container_hyperparameters: bool = False,
3335
tolerate_vulnerable_model: bool = False,
3436
tolerate_deprecated_model: bool = False,
37+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3538
) -> Dict[str, str]:
3639
"""Retrieves the default training hyperparameters for the model matching the given arguments.
3740
@@ -56,6 +59,10 @@ def retrieve_default(
5659
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
5760
(exception not raised). False if these models should raise an exception.
5861
(Default: False).
62+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
63+
object, used for SageMaker interactions. If not
64+
specified, one is created using the default AWS configuration
65+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5966
Returns:
6067
dict: The hyperparameters to use for the model.
6168
@@ -74,6 +81,7 @@ def retrieve_default(
7481
include_container_hyperparameters,
7582
tolerate_vulnerable_model,
7683
tolerate_deprecated_model,
84+
sagemaker_session=sagemaker_session,
7785
)
7886

7987

@@ -83,6 +91,9 @@ def validate(
8391
model_version: Optional[str] = None,
8492
hyperparameters: Optional[dict] = None,
8593
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
94+
tolerate_vulnerable_model: bool = False,
95+
tolerate_deprecated_model: bool = False,
96+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8697
) -> None:
8798
"""Validates hyperparameters for models.
8899
@@ -100,6 +111,17 @@ def validate(
100111
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
101112
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
102113
(Default: None).
114+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
115+
specifications should be tolerated (exception not raised). If False, raises an
116+
exception if the script used by this version of the model has dependencies with known
117+
security vulnerabilities. (Default: False).
118+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
119+
(exception not raised). False if these models should raise an exception.
120+
(Default: False).
121+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
122+
object, used for SageMaker interactions. If not
123+
specified, one is created using the default AWS configuration
124+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
103125
104126
Raises:
105127
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
@@ -125,4 +147,7 @@ def validate(
125147
hyperparameters=hyperparameters,
126148
validation_mode=validation_mode,
127149
region=region,
150+
tolerate_vulnerable_model=tolerate_vulnerable_model,
151+
tolerate_deprecated_model=tolerate_deprecated_model,
152+
sagemaker_session=sagemaker_session,
128153
)

src/sagemaker/image_uris.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
24+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2425
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2526
from sagemaker.spark import defaults
2627
from sagemaker.jumpstart import artifacts
@@ -60,6 +61,7 @@ def retrieve(
6061
sdk_version=None,
6162
inference_tool=None,
6263
serverless_inference_config=None,
64+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6365
) -> str:
6466
"""Retrieves the ECR URI for the Docker image matching the given arguments.
6567
@@ -109,6 +111,10 @@ def retrieve(
109111
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
110112
Specifies configuration related to serverless endpoint. Instance type is
111113
not provided in serverless inference. So this is used to determine processor type.
114+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
115+
object, used for SageMaker interactions. If not
116+
specified, one is created using the default AWS configuration
117+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
112118
113119
Returns:
114120
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -147,6 +153,7 @@ def retrieve(
147153
training_compiler_config,
148154
tolerate_vulnerable_model,
149155
tolerate_deprecated_model,
156+
sagemaker_session=sagemaker_session,
150157
)
151158

152159
if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):

src/sagemaker/instance_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.session import Session
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -30,6 +32,7 @@ def retrieve_default(
3032
scope: Optional[str] = None,
3133
tolerate_vulnerable_model: bool = False,
3234
tolerate_deprecated_model: bool = False,
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3336
) -> str:
3437
"""Retrieves the default instance type for the model matching the given arguments.
3538
@@ -49,6 +52,10 @@ def retrieve_default(
4952
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
5053
(exception not raised). False if these models should raise an exception.
5154
(Default: False).
55+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
56+
object, used for SageMaker interactions. If not
57+
specified, one is created using the default AWS configuration
58+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5259
Returns:
5360
str: The default instance type to use for the model.
5461
@@ -70,6 +77,7 @@ def retrieve_default(
7077
region,
7178
tolerate_vulnerable_model,
7279
tolerate_deprecated_model,
80+
sagemaker_session=sagemaker_session,
7381
)
7482

7583

@@ -80,6 +88,7 @@ def retrieve(
8088
scope: Optional[str] = None,
8189
tolerate_vulnerable_model: bool = False,
8290
tolerate_deprecated_model: bool = False,
91+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8392
) -> List[str]:
8493
"""Retrieves the supported training instance types for the model matching the given arguments.
8594
@@ -97,6 +106,10 @@ def retrieve(
97106
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
98107
(exception not raised). False if these models should raise an exception.
99108
(Default: False).
109+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
110+
object, used for SageMaker interactions. If not
111+
specified, one is created using the default AWS configuration
112+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
100113
Returns:
101114
list: The supported instance types to use for the model.
102115
@@ -118,4 +131,5 @@ def retrieve(
118131
region,
119132
tolerate_vulnerable_model,
120133
tolerate_deprecated_model,
134+
sagemaker_session=sagemaker_session,
121135
)

0 commit comments

Comments
 (0)