Skip to content

Commit 924d2f1

Browse files
authored
Merge branch 'master' into master
2 parents 5d0c04d + a3f5874 commit 924d2f1

File tree

97 files changed

+1720
-226
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+1720
-226
lines changed

CHANGELOG.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,45 @@
11
# Changelog
22

3+
## v2.178.0 (2023-08-17)
4+
5+
### Features
6+
7+
* Support to get latest monitoring execution processing logs
8+
9+
### Bug Fixes and Other Changes
10+
11+
* Add context to predict_fn example
12+
* gated models unsupported region
13+
* jumpstart cache using sagemaker session s3 client
14+
* add TFS 2.13 Graviton SM images
15+
* pipeline variable kms key
16+
* integration test for gated jumpstart training model
17+
* tags for jumpstart model package models
18+
19+
## v2.177.1 (2023-08-14)
20+
21+
### Bug Fixes and Other Changes
22+
23+
* chore: excessive jumpstart bucket logging
24+
25+
## v2.177.0 (2023-08-11)
26+
27+
### Features
28+
29+
* Add TLV accounts for 1P Algorithms
30+
31+
## v2.176.0 (2023-08-10)
32+
33+
### Features
34+
35+
* Add TF 2.13 Training and Inference SM images
36+
37+
### Bug Fixes and Other Changes
38+
39+
* revert-PR_3903
40+
* skip tensorflow local mode notebook test
41+
* change instance type for huggingface test to ml.g5.8xlarge
42+
343
## v2.175.0 (2023-08-05)
444

545
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.175.1.dev0
1+
2.178.1.dev0

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ The following example is for use cases with multiple GPUs and shows an overridde
772772
import torch
773773
import numpy as np
774774
775-
def predict_fn(input_data, model):
775+
def predict_fn(input_data, model, context):
776776
device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
777777
model.to(device)
778778
model.eval()

src/sagemaker/accept_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.session import Session
1820

1921

2022
def retrieve_options(
@@ -23,6 +25,7 @@ def retrieve_options(
2325
model_version: Optional[str] = None,
2426
tolerate_vulnerable_model: bool = False,
2527
tolerate_deprecated_model: bool = False,
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2629
) -> List[str]:
2730
"""Retrieves the supported accept types for the model matching the given arguments.
2831
@@ -40,6 +43,10 @@ def retrieve_options(
4043
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
4144
(exception not raised). False if these models should raise an exception.
4245
(Default: False).
46+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
47+
object, used for SageMaker interactions. If not
48+
specified, one is created using the default AWS configuration
49+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
4350
Returns:
4451
list: The supported accept types to use for the model.
4552
@@ -57,6 +64,7 @@ def retrieve_options(
5764
region,
5865
tolerate_vulnerable_model,
5966
tolerate_deprecated_model,
67+
sagemaker_session=sagemaker_session,
6068
)
6169

6270

@@ -66,6 +74,7 @@ def retrieve_default(
6674
model_version: Optional[str] = None,
6775
tolerate_vulnerable_model: bool = False,
6876
tolerate_deprecated_model: bool = False,
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6978
) -> str:
7079
"""Retrieves the default accept type for the model matching the given arguments.
7180
@@ -83,6 +92,10 @@ def retrieve_default(
8392
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
8493
(exception not raised). False if these models should raise an exception.
8594
(Default: False).
95+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
96+
object, used for SageMaker interactions. If not
97+
specified, one is created using the default AWS configuration
98+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
8699
Returns:
87100
str: The default accept type to use for the model.
88101
@@ -100,4 +113,5 @@ def retrieve_default(
100113
region,
101114
tolerate_vulnerable_model,
102115
tolerate_deprecated_model,
116+
sagemaker_session=sagemaker_session,
103117
)

src/sagemaker/content_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.session import Session
1820

1921

2022
def retrieve_options(
@@ -23,6 +25,7 @@ def retrieve_options(
2325
model_version: Optional[str] = None,
2426
tolerate_vulnerable_model: bool = False,
2527
tolerate_deprecated_model: bool = False,
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2629
) -> List[str]:
2730
"""Retrieves the supported content types for the model matching the given arguments.
2831
@@ -40,6 +43,10 @@ def retrieve_options(
4043
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
4144
(exception not raised). False if these models should raise an exception.
4245
(Default: False).
46+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
47+
object, used for SageMaker interactions. If not
48+
specified, one is created using the default AWS configuration
49+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
4350
Returns:
4451
list: The supported content types to use for the model.
4552
@@ -57,6 +64,7 @@ def retrieve_options(
5764
region,
5865
tolerate_vulnerable_model,
5966
tolerate_deprecated_model,
67+
sagemaker_session=sagemaker_session,
6068
)
6169

6270

@@ -66,6 +74,7 @@ def retrieve_default(
6674
model_version: Optional[str] = None,
6775
tolerate_vulnerable_model: bool = False,
6876
tolerate_deprecated_model: bool = False,
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6978
) -> str:
7079
"""Retrieves the default content type for the model matching the given arguments.
7180
@@ -83,6 +92,10 @@ def retrieve_default(
8392
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
8493
(exception not raised). False if these models should raise an exception.
8594
(Default: False).
95+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
96+
object, used for SageMaker interactions. If not
97+
specified, one is created using the default AWS configuration
98+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
8699
Returns:
87100
str: The default content type to use for the model.
88101
@@ -100,6 +113,7 @@ def retrieve_default(
100113
region,
101114
tolerate_vulnerable_model,
102115
tolerate_deprecated_model,
116+
sagemaker_session=sagemaker_session,
103117
)
104118

105119

src/sagemaker/deserializers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
)
3434

3535
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
36+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
37+
from sagemaker.session import Session
3638

3739

3840
def retrieve_options(
@@ -41,6 +43,7 @@ def retrieve_options(
4143
model_version: Optional[str] = None,
4244
tolerate_vulnerable_model: bool = False,
4345
tolerate_deprecated_model: bool = False,
46+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4447
) -> List[BaseDeserializer]:
4548
"""Retrieves the supported deserializers for the model matching the given arguments.
4649
@@ -58,6 +61,10 @@ def retrieve_options(
5861
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
5962
(exception not raised). False if these models should raise an exception.
6063
(Default: False).
64+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
65+
object, used for SageMaker interactions. If not
66+
specified, one is created using the default AWS configuration
67+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6168
Returns:
6269
List[BaseDeserializer]: The supported deserializers to use for the model.
6370
@@ -76,6 +83,7 @@ def retrieve_options(
7683
region,
7784
tolerate_vulnerable_model,
7885
tolerate_deprecated_model,
86+
sagemaker_session=sagemaker_session,
7987
)
8088

8189

@@ -85,6 +93,7 @@ def retrieve_default(
8593
model_version: Optional[str] = None,
8694
tolerate_vulnerable_model: bool = False,
8795
tolerate_deprecated_model: bool = False,
96+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8897
) -> BaseDeserializer:
8998
"""Retrieves the default deserializer for the model matching the given arguments.
9099
@@ -102,6 +111,10 @@ def retrieve_default(
102111
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
103112
(exception not raised). False if these models should raise an exception.
104113
(Default: False).
114+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
115+
object, used for SageMaker interactions. If not
116+
specified, one is created using the default AWS configuration
117+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
105118
Returns:
106119
BaseDeserializer: The default deserializer to use for the model.
107120
@@ -120,4 +133,5 @@ def retrieve_default(
120133
region,
121134
tolerate_vulnerable_model,
122135
tolerate_deprecated_model,
136+
sagemaker_session=sagemaker_session,
123137
)

src/sagemaker/environment_variables.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.session import Session
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -30,6 +32,7 @@ def retrieve_default(
3032
tolerate_vulnerable_model: bool = False,
3133
tolerate_deprecated_model: bool = False,
3234
include_aws_sdk_env_vars: bool = True,
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3336
) -> Dict[str, str]:
3437
"""Retrieves the default container environment variables for the model matching the arguments.
3538
@@ -51,6 +54,10 @@ def retrieve_default(
5154
should be included. The `Model` class of the SageMaker Python SDK inserts environment
5255
variables that would be required when making the low-level AWS API call.
5356
(Default: True).
57+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
58+
object, used for SageMaker interactions. If not
59+
specified, one is created using the default AWS configuration
60+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5461
Returns:
5562
dict: The variables to use for the model.
5663
@@ -70,4 +77,5 @@ def retrieve_default(
7077
tolerate_vulnerable_model,
7178
tolerate_deprecated_model,
7279
include_aws_sdk_env_vars,
80+
sagemaker_session=sagemaker_session,
7381
)

src/sagemaker/estimator.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@
9999
)
100100
from sagemaker.workflow import is_pipeline_variable
101101
from sagemaker.workflow.entities import PipelineVariable
102-
from sagemaker.workflow.parameters import ParameterString
103102
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
104103

105104
logger = logging.getLogger(__name__)
@@ -614,16 +613,21 @@ def __init__(
614613
self.output_kms_key = resolve_value_from_config(
615614
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
616615
)
616+
use_volume_kms_config: bool = False
617617
if instance_type is None or isinstance(instance_type, str):
618618
instance_type_for_volume_kms = instance_type
619-
elif isinstance(instance_type, ParameterString):
620-
instance_type_for_volume_kms = instance_type.default_value
619+
elif isinstance(instance_type, PipelineVariable):
620+
use_volume_kms_config = True
621+
instance_type_for_volume_kms = instance_type
621622
else:
622623
raise ValueError(f"Bad value for instance type: '{instance_type}'")
623624

624625
# KMS can only be attached to supported instances
625626
use_volume_kms_config = (
626-
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
627+
use_volume_kms_config
628+
or (
629+
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
630+
)
627631
or instance_groups is not None
628632
and any(
629633
[
@@ -1425,6 +1429,24 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14251429
Instance of the calling ``Estimator`` Class with the attached
14261430
training job.
14271431
"""
1432+
return cls._attach(
1433+
training_job_name=training_job_name,
1434+
sagemaker_session=sagemaker_session,
1435+
model_channel_name=model_channel_name,
1436+
)
1437+
1438+
@classmethod
1439+
def _attach(
1440+
cls,
1441+
training_job_name: str,
1442+
sagemaker_session: Optional[str] = None,
1443+
model_channel_name: str = "model",
1444+
additional_kwargs: Optional[Dict[str, Any]] = None,
1445+
) -> "EstimatorBase":
1446+
"""Creates an Estimator bound to an existing training job.
1447+
1448+
Additional kwargs are allowed for instantiating Estimator.
1449+
"""
14281450
sagemaker_session = sagemaker_session or Session()
14291451

14301452
job_details = sagemaker_session.sagemaker_client.describe_training_job(
@@ -1436,6 +1458,9 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
14361458
)["Tags"]
14371459
init_params.update(tags=tags)
14381460

1461+
if additional_kwargs:
1462+
init_params.update(additional_kwargs)
1463+
14391464
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
14401465
estimator.latest_training_job = _TrainingJob(
14411466
sagemaker_session=sagemaker_session, job_name=training_job_name

0 commit comments

Comments
 (0)