Skip to content

Commit 3a97e4b

Browse files
authored
Merge branch 'master' into master
2 parents 62a46c7 + 7213b5a commit 3a97e4b

File tree

14 files changed

+800
-20
lines changed

14 files changed

+800
-20
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ awslogs==0.14.0
1212
black==22.3.0
1313
stopit==1.1.2
1414
# Update tox.ini to have correct version of airflow constraints file
15-
apache-airflow==2.7.1
15+
apache-airflow==2.7.2
1616
apache-airflow-providers-amazon==7.2.1
1717
attrs>=23.1.0,<24
1818
fabric==2.6.0

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,13 @@ def schedule(
281281
Args:
282282
pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
283283
schedule_expression (str): The expression that defines when the schedule runs. It supports
284-
at expression, rate expression and cron expression. See https://docs.aws.amazon.com/
285-
scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request
286-
-ScheduleExpression for more details.
284+
at expression, rate expression and cron expression. See '''https://docs.aws.amazon.com\
285+
/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-\
286+
request-ScheduleExpression''' for more details.
287287
state (str): Specifies whether the schedule is enabled or disabled. Valid values are
288-
ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/
289-
API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details.
290-
If not specified, it will default to ENABLED.
288+
ENABLED and DISABLED. See '''https://docs.aws.amazon.com/scheduler/latest/APIReference\
289+
/API_CreateSchedule.html#scheduler-CreateSchedule-request-State'''
290+
for more details. If not specified, it will default to ENABLED.
291291
start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
292292
invoking its target. Depending on the schedule’s recurrence expression, invocations
293293
might occur on, or after, the StartDate you specify.

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,68 @@
2828
verify_model_region_and_return_specs,
2929
)
3030
from sagemaker.session import Session
31+
from sagemaker.jumpstart.types import JumpStartModelSpecs
32+
33+
34+
def _retrieve_hosting_prepacked_artifact_key(
35+
model_specs: JumpStartModelSpecs, instance_type: str
36+
) -> str:
37+
"""Returns instance specific hosting prepacked artifact key or default one as fallback."""
38+
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
39+
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
40+
instance_type=instance_type
41+
)
42+
if instance_type
43+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
44+
else None
45+
)
46+
47+
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
48+
model_specs, "hosting_prepacked_artifact_key"
49+
)
50+
51+
return (
52+
instance_specific_prepacked_hosting_artifact_key or default_prepacked_hosting_artifact_key
53+
)
54+
55+
56+
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
57+
"""Returns instance specific hosting artifact key or default one as fallback."""
58+
instance_specific_hosting_artifact_key: Optional[str] = (
59+
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
60+
instance_type=instance_type
61+
)
62+
if instance_type
63+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
64+
else None
65+
)
66+
67+
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
68+
69+
return instance_specific_hosting_artifact_key or default_hosting_artifact_key
70+
71+
72+
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
73+
"""Returns instance specific training artifact key or default one as fallback."""
74+
instance_specific_training_artifact_key: Optional[str] = (
75+
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
76+
instance_type=instance_type
77+
)
78+
if instance_type
79+
and getattr(model_specs, "training_instance_type_variants", None) is not None
80+
else None
81+
)
82+
83+
default_training_artifact_key: str = model_specs.training_artifact_key
84+
85+
return instance_specific_training_artifact_key or default_training_artifact_key
3186

3287

3388
def _retrieve_model_uri(
3489
model_id: str,
3590
model_version: str,
3691
model_scope: Optional[str] = None,
92+
instance_type: Optional[str] = None,
3793
region: Optional[str] = None,
3894
tolerate_vulnerable_model: bool = False,
3995
tolerate_deprecated_model: bool = False,
@@ -50,6 +106,7 @@ def _retrieve_model_uri(
50106
artifact S3 URI.
51107
model_scope (str): The model type, i.e. what it is used for.
52108
Valid values: "training" and "inference".
109+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
53110
region (str): Region for which to retrieve model S3 URI. (Default: None).
54111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
55112
specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +141,21 @@ def _retrieve_model_uri(
84141
sagemaker_session=sagemaker_session,
85142
)
86143

144+
model_artifact_key: str
145+
87146
if model_scope == JumpStartScriptScope.INFERENCE:
147+
148+
is_prepacked = not model_specs.use_inference_script_uri()
149+
88150
model_artifact_key = (
89-
getattr(model_specs, "hosting_prepacked_artifact_key", None)
90-
or model_specs.hosting_artifact_key
151+
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
152+
if is_prepacked
153+
else _retrieve_hosting_artifact_key(model_specs, instance_type)
91154
)
92155

93156
elif model_scope == JumpStartScriptScope.TRAINING:
94-
model_artifact_key = model_specs.training_artifact_key
157+
158+
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
95159

96160
bucket = os.environ.get(
97161
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
475475
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
476476
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
477477
sagemaker_session=kwargs.sagemaker_session,
478+
instance_type=kwargs.instance_type,
478479
)
479480

480481
if (

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
215215
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
216216
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
217217
sagemaker_session=kwargs.sagemaker_session,
218+
instance_type=kwargs.instance_type,
218219
)
219220

220221
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):

src/sagemaker/jumpstart/types.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,62 @@ def to_json(self) -> Dict[str, Any]:
403403
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
404404
return json_obj
405405

406+
def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
407+
"""Returns instance specific model artifact key.
408+
409+
Returns None if a model, instance type tuple does not have specific
410+
artifact key.
411+
"""
412+
413+
return self._get_instance_specific_property(
414+
instance_type=instance_type, property_name="prepacked_artifact_key"
415+
)
416+
417+
def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
418+
"""Returns instance specific model artifact key.
419+
420+
Returns None if a model, instance type tuple does not have specific
421+
artifact key.
422+
"""
423+
424+
return self._get_instance_specific_property(
425+
instance_type=instance_type, property_name="artifact_key"
426+
)
427+
428+
def _get_instance_specific_property(
429+
self, instance_type: str, property_name: str
430+
) -> Optional[str]:
431+
"""Returns instance specific property.
432+
433+
If a value exists for both the instance family and instance type,
434+
the instance type value is chosen.
435+
436+
Returns None if a (model, instance type, property name) tuple does not have
437+
specific prepacked artifact key.
438+
"""
439+
440+
if self.variants is None:
441+
return None
442+
443+
instance_specific_property: Optional[str] = (
444+
self.variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
445+
)
446+
447+
if instance_specific_property:
448+
return instance_specific_property
449+
450+
instance_type_family = get_instance_type_family(instance_type)
451+
452+
instance_family_property: Optional[str] = (
453+
self.variants.get(instance_type_family, {})
454+
.get("properties", {})
455+
.get(property_name, None)
456+
if instance_type_family not in {"", None}
457+
else None
458+
)
459+
460+
return instance_family_property
461+
406462
def get_instance_specific_hyperparameters(
407463
self, instance_type: str
408464
) -> List[JumpStartHyperparameter]:

src/sagemaker/model_uris.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve(
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
3232
model_scope: Optional[str] = None,
33+
instance_type: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -44,6 +45,7 @@ def retrieve(
4445
the model artifact S3 URI.
4546
model_scope (str): The model type.
4647
Valid values: "training" and "inference".
48+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
4749
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
4850
specifications should be tolerated without raising an exception. If ``False``, raises an
4951
exception if the script used by this version of the model has dependencies with known
@@ -71,11 +73,12 @@ def retrieve(
7173
)
7274

7375
return artifacts._retrieve_model_uri(
74-
model_id,
75-
model_version, # type: ignore
76-
model_scope,
77-
region,
78-
tolerate_vulnerable_model,
79-
tolerate_deprecated_model,
76+
model_id=model_id,
77+
model_version=model_version, # type: ignore
78+
model_scope=model_scope,
79+
instance_type=instance_type,
80+
region=region,
81+
tolerate_vulnerable_model=tolerate_vulnerable_model,
82+
tolerate_deprecated_model=tolerate_deprecated_model,
8083
sagemaker_session=sagemaker_session,
8184
)

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
import os
17+
import pathlib
1718

1819
import boto3
1920
import pytest
@@ -86,6 +87,16 @@
8687
PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
8788

8889

90+
image_uris_unit_tests_dir = pathlib.Path("tests/unit/sagemaker/image_uris")
91+
92+
93+
def pytest_collection_modifyitems(config, items):
94+
for item in items:
95+
testmod = pathlib.Path(item.fspath)
96+
if config.rootdir / image_uris_unit_tests_dir in testmod.parents:
97+
item.add_marker(pytest.mark.image_uris_unit_test)
98+
99+
89100
def pytest_addoption(parser):
90101
parser.addoption("--sagemaker-client-config", action="store", default=None)
91102
parser.addoption("--sagemaker-runtime-config", action="store", default=None)

0 commit comments

Comments
 (0)