Skip to content

Commit f4cd7d1

Browse files
authored
Merge branch 'master' into feat/jumpstart-construct-payload-utility
2 parents 60b3e9a + 03680b2 commit f4cd7d1

File tree

20 files changed

+956
-33
lines changed

20 files changed

+956
-33
lines changed

CHANGELOG.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
# Changelog
22

3+
## v2.193.0 (2023-10-18)
4+
5+
### Features
6+
7+
* jumpstart model artifact instance type variants
8+
* jumpstart instance specific hyperparameters
9+
* Feature Processor event based triggers (#1132)
10+
* Support job checkpoint in remote function
11+
* jumpstart model package arn instance type variants
12+
13+
### Bug Fixes and Other Changes
14+
15+
* Fix hyperlinks in feature_processor.scheduler parameter descriptions
16+
* add image_uris_unit_test pytest mark
17+
* bump apache-airflow to `v2.7.2`
18+
* clone distribution in validate_distribution
19+
* fix flaky Inference Recommender integration tests
20+
21+
### Documentation Changes
22+
23+
* Update PipelineModel.register documentation
24+
* specify that input_shape is no longer required for torch 2.0 mod…
25+
326
## v2.192.1 (2023-10-13)
427

528
### Bug Fixes and Other Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.192.2.dev0
1+
2.193.1.dev0

doc/amazon_sagemaker_featurestore.rst

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,11 @@ The following code from the fraud detection example shows a minimal
230230
    enable_online_store=True
231231
)
232232
233-
Creating a feature group takes time as the data is loaded. You will need
234-
to wait until it is created before you can use it. You can check status
235-
using the following method.
233+
Creating a feature group takes time as the data is loaded. You will
234+
need to wait until it is created before you can use it. You can
235+
check status using the following method. Note that it can take
236+
approximately 10-15 minutes to provision an online ``FeatureGroup``
237+
with the ``InMemory`` ``StorageType``.
236238

237239
.. code:: python
238240
@@ -480,7 +482,9 @@ Feature Store `DatasetBuilder API Reference
480482
.. rubric:: Delete a feature group
481483
:name: bCe9CA61b78
482484

483-
You can delete a feature group with the ``delete`` function.
485+
You can delete a feature group with the ``delete`` function. Note that it
486+
can take approximately 10-15 minutes to delete an online ``FeatureGroup``
487+
with the ``InMemory`` ``StorageType``.
484488

485489
.. code:: python
486490

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ awslogs==0.14.0
1212
black==22.3.0
1313
stopit==1.1.2
1414
# Update tox.ini to have correct version of airflow constraints file
15-
apache-airflow==2.7.1
15+
apache-airflow==2.7.2
1616
apache-airflow-providers-amazon==7.2.1
1717
attrs>=23.1.0,<24
1818
fabric==2.6.0

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,13 @@ def schedule(
281281
Args:
282282
pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
283283
schedule_expression (str): The expression that defines when the schedule runs. It supports
284-
at expression, rate expression and cron expression. See https://docs.aws.amazon.com/
285-
scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request
286-
-ScheduleExpression for more details.
284+
at expression, rate expression and cron expression. See '''https://docs.aws.amazon.com\
285+
/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-\
286+
request-ScheduleExpression''' for more details.
287287
state (str): Specifies whether the schedule is enabled or disabled. Valid values are
288-
ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/
289-
API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details.
290-
If not specified, it will default to ENABLED.
288+
ENABLED and DISABLED. See '''https://docs.aws.amazon.com/scheduler/latest/APIReference\
289+
/API_CreateSchedule.html#scheduler-CreateSchedule-request-State'''
290+
for more details. If not specified, it will default to ENABLED.
291291
start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
292292
invoking its target. Depending on the schedule’s recurrence expression, invocations
293293
might occur on, or after, the StartDate you specify.

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,68 @@
2828
verify_model_region_and_return_specs,
2929
)
3030
from sagemaker.session import Session
31+
from sagemaker.jumpstart.types import JumpStartModelSpecs
32+
33+
34+
def _retrieve_hosting_prepacked_artifact_key(
35+
model_specs: JumpStartModelSpecs, instance_type: str
36+
) -> str:
37+
"""Returns instance specific hosting prepacked artifact key or default one as fallback."""
38+
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
39+
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
40+
instance_type=instance_type
41+
)
42+
if instance_type
43+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
44+
else None
45+
)
46+
47+
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
48+
model_specs, "hosting_prepacked_artifact_key"
49+
)
50+
51+
return (
52+
instance_specific_prepacked_hosting_artifact_key or default_prepacked_hosting_artifact_key
53+
)
54+
55+
56+
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
57+
"""Returns instance specific hosting artifact key or default one as fallback."""
58+
instance_specific_hosting_artifact_key: Optional[str] = (
59+
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
60+
instance_type=instance_type
61+
)
62+
if instance_type
63+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
64+
else None
65+
)
66+
67+
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
68+
69+
return instance_specific_hosting_artifact_key or default_hosting_artifact_key
70+
71+
72+
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
73+
"""Returns instance specific training artifact key or default one as fallback."""
74+
instance_specific_training_artifact_key: Optional[str] = (
75+
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
76+
instance_type=instance_type
77+
)
78+
if instance_type
79+
and getattr(model_specs, "training_instance_type_variants", None) is not None
80+
else None
81+
)
82+
83+
default_training_artifact_key: str = model_specs.training_artifact_key
84+
85+
return instance_specific_training_artifact_key or default_training_artifact_key
3186

3287

3388
def _retrieve_model_uri(
3489
model_id: str,
3590
model_version: str,
3691
model_scope: Optional[str] = None,
92+
instance_type: Optional[str] = None,
3793
region: Optional[str] = None,
3894
tolerate_vulnerable_model: bool = False,
3995
tolerate_deprecated_model: bool = False,
@@ -50,6 +106,7 @@ def _retrieve_model_uri(
50106
artifact S3 URI.
51107
model_scope (str): The model type, i.e. what it is used for.
52108
Valid values: "training" and "inference".
109+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
53110
region (str): Region for which to retrieve model S3 URI. (Default: None).
54111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
55112
specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +141,21 @@ def _retrieve_model_uri(
84141
sagemaker_session=sagemaker_session,
85142
)
86143

144+
model_artifact_key: str
145+
87146
if model_scope == JumpStartScriptScope.INFERENCE:
147+
148+
is_prepacked = not model_specs.use_inference_script_uri()
149+
88150
model_artifact_key = (
89-
getattr(model_specs, "hosting_prepacked_artifact_key", None)
90-
or model_specs.hosting_artifact_key
151+
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
152+
if is_prepacked
153+
else _retrieve_hosting_artifact_key(model_specs, instance_type)
91154
)
92155

93156
elif model_scope == JumpStartScriptScope.TRAINING:
94-
model_artifact_key = model_specs.training_artifact_key
157+
158+
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
95159

96160
bucket = os.environ.get(
97161
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
475475
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
476476
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
477477
sagemaker_session=kwargs.sagemaker_session,
478+
instance_type=kwargs.instance_type,
478479
)
479480

480481
if (

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
215215
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
216216
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
217217
sagemaker_session=kwargs.sagemaker_session,
218+
instance_type=kwargs.instance_type,
218219
)
219220

220221
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):

src/sagemaker/jumpstart/types.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,62 @@ def to_json(self) -> Dict[str, Any]:
405405
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
406406
return json_obj
407407

408+
def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
409+
"""Returns instance specific model artifact key.
410+
411+
Returns None if the (model, instance type) pair does not have a specific
412+
prepacked artifact key.
413+
"""
414+
415+
return self._get_instance_specific_property(
416+
instance_type=instance_type, property_name="prepacked_artifact_key"
417+
)
418+
419+
def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
420+
"""Returns instance specific model artifact key.
421+
422+
Returns None if a model, instance type tuple does not have specific
423+
artifact key.
424+
"""
425+
426+
return self._get_instance_specific_property(
427+
instance_type=instance_type, property_name="artifact_key"
428+
)
429+
430+
def _get_instance_specific_property(
431+
self, instance_type: str, property_name: str
432+
) -> Optional[str]:
433+
"""Returns instance specific property.
434+
435+
If a value exists for both the instance family and instance type,
436+
the instance type value is chosen.
437+
438+
Returns None if a (model, instance type, property name) tuple does not have
439+
an instance specific value for the requested property.
440+
"""
441+
442+
if self.variants is None:
443+
return None
444+
445+
instance_specific_property: Optional[str] = (
446+
self.variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
447+
)
448+
449+
if instance_specific_property:
450+
return instance_specific_property
451+
452+
instance_type_family = get_instance_type_family(instance_type)
453+
454+
instance_family_property: Optional[str] = (
455+
self.variants.get(instance_type_family, {})
456+
.get("properties", {})
457+
.get(property_name, None)
458+
if instance_type_family not in {"", None}
459+
else None
460+
)
461+
462+
return instance_family_property
463+
408464
def get_instance_specific_hyperparameters(
409465
self, instance_type: str
410466
) -> List[JumpStartHyperparameter]:

src/sagemaker/model.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -623,14 +623,7 @@ def prepare_container_def(
623623
)
624624
deploy_env = copy.deepcopy(self.env)
625625
if self.source_dir or self.dependencies or self.entry_point or self.git_config:
626-
is_repack = (
627-
self.source_dir
628-
and self.entry_point
629-
and not (
630-
(self.key_prefix and issubclass(type(self), FrameworkModel)) or self.git_config
631-
)
632-
)
633-
self._upload_code(deploy_key_prefix, repack=is_repack)
626+
self._upload_code(deploy_key_prefix, repack=self.is_repack())
634627
deploy_env.update(self._script_mode_env_vars())
635628

636629
return sagemaker.container_def(
@@ -640,6 +633,14 @@ def prepare_container_def(
640633
image_config=self.image_config,
641634
)
642635

636+
def is_repack(self) -> bool:
637+
"""Whether the source code needs to be repacked before uploading to S3.
638+
639+
Returns:
640+
bool: whether the source needs to be repacked or not
641+
"""
642+
return self.source_dir and self.entry_point and not self.git_config
643+
643644
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
644645
"""Uploads code to S3 to be used with script mode with SageMaker inference.
645646
@@ -1775,6 +1776,14 @@ def __init__(
17751776
**kwargs,
17761777
)
17771778

1779+
def is_repack(self) -> bool:
1780+
"""Whether the source code needs to be repacked before uploading to S3.
1781+
1782+
Returns:
1783+
bool: whether the source needs to be repacked or not
1784+
"""
1785+
return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
1786+
17781787

17791788
# works for MODEL_PACKAGE_ARN with or without version info.
17801789
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"

src/sagemaker/model_uris.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve(
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
3232
model_scope: Optional[str] = None,
33+
instance_type: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -44,6 +45,7 @@ def retrieve(
4445
the model artifact S3 URI.
4546
model_scope (str): The model type.
4647
Valid values: "training" and "inference".
48+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
4749
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
4850
specifications should be tolerated without raising an exception. If ``False``, raises an
4951
exception if the script used by this version of the model has dependencies with known
@@ -71,11 +73,12 @@ def retrieve(
7173
)
7274

7375
return artifacts._retrieve_model_uri(
74-
model_id,
75-
model_version, # type: ignore
76-
model_scope,
77-
region,
78-
tolerate_vulnerable_model,
79-
tolerate_deprecated_model,
76+
model_id=model_id,
77+
model_version=model_version, # type: ignore
78+
model_scope=model_scope,
79+
instance_type=instance_type,
80+
region=region,
81+
tolerate_vulnerable_model=tolerate_vulnerable_model,
82+
tolerate_deprecated_model=tolerate_deprecated_model,
8083
sagemaker_session=sagemaker_session,
8184
)

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
import os
17+
import pathlib
1718

1819
import boto3
1920
import pytest
@@ -86,6 +87,16 @@
8687
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
8788

8889

90+
image_uris_unit_tests_dir = pathlib.Path("tests/unit/sagemaker/image_uris")
91+
92+
93+
def pytest_collection_modifyitems(config, items):
94+
for item in items:
95+
testmod = pathlib.Path(item.fspath)
96+
if config.rootdir / image_uris_unit_tests_dir in testmod.parents:
97+
item.add_marker(pytest.mark.image_uris_unit_test)
98+
99+
89100
def pytest_addoption(parser):
90101
parser.addoption("--sagemaker-client-config", action="store", default=None)
91102
parser.addoption("--sagemaker-runtime-config", action="store", default=None)

0 commit comments

Comments
 (0)