Skip to content

Commit 195ceee

Browse files
committed
Merge remote-tracking branch 'origin' into feat/jumpstart-model-artifact-instance-type-variants
2 parents cba4d68 + a44a755 commit 195ceee

File tree

5 files changed

+190
-15
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
def _retrieve_model_package_arn(
3030
model_id: str,
3131
model_version: str,
32+
instance_type: Optional[str],
3233
region: Optional[str],
3334
scope: Optional[str] = None,
3435
tolerate_vulnerable_model: bool = False,
@@ -42,6 +43,8 @@ def _retrieve_model_package_arn(
4243
retrieve the model package arn.
4344
model_version (str): Version of the JumpStart model for which to retrieve the
4445
model package arn.
46+
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
47+
specific for the instance type.
4548
region (Optional[str]): Region for which to retrieve the model package arn.
4649
scope (Optional[str]): Scope for which to retrieve the model package arn.
4750
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -75,6 +78,17 @@ def _retrieve_model_package_arn(
7578

7679
if scope == JumpStartScriptScope.INFERENCE:
7780

81+
instance_specific_arn: Optional[str] = (
82+
model_specs.hosting_instance_type_variants.get_model_package_arn(
83+
region=region, instance_type=instance_type
84+
)
85+
if getattr(model_specs, "hosting_instance_type_variants", None) is not None
86+
else None
87+
)
88+
89+
if instance_specific_arn is not None:
90+
return instance_specific_arn
91+
7892
if model_specs.hosting_model_package_arns is None:
7993
return None
8094

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
330330
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
331331
model_id=kwargs.model_id,
332332
model_version=kwargs.model_version,
333+
instance_type=kwargs.instance_type,
333334
scope=JumpStartScriptScope.INFERENCE,
334335
region=kwargs.region,
335336
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,

src/sagemaker/jumpstart/types.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -492,42 +492,64 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
492492
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
493493
"""Returns image uri from instance type and region.
494494
495+
Returns None if no instance type is available or found.
496+
None is also returned if the metadata is improperly formatted.
497+
"""
498+
return self._get_regional_property(
499+
instance_type=instance_type, region=region, property_name="image_uri"
500+
)
501+
502+
def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]:
503+
"""Returns model package arn from instance type and region.
504+
505+
Returns None if no instance type is available or found.
506+
None is also returned if the metadata is improperly formatted.
507+
"""
508+
return self._get_regional_property(
509+
instance_type=instance_type, region=region, property_name="model_package_arn"
510+
)
511+
512+
def _get_regional_property(
513+
self, instance_type: str, region: str, property_name: str
514+
) -> Optional[str]:
515+
"""Returns regional property from instance type and region.
516+
495517
Returns None if no instance type is available or found.
496518
None is also returned if the metadata is improperly formatted.
497519
"""
498520

499521
if None in [self.regional_aliases, self.variants]:
500522
return None
501523

502-
image_uri_alias: Optional[str] = (
503-
self.variants.get(instance_type, {}).get("regional_properties", {}).get("image_uri")
524+
regional_property_alias: Optional[str] = (
525+
self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name)
504526
)
505-
if image_uri_alias is None:
527+
if regional_property_alias is None:
506528
instance_type_family = get_instance_type_family(instance_type)
507529

508530
if instance_type_family in {"", None}:
509531
return None
510532

511-
image_uri_alias = (
533+
regional_property_alias = (
512534
self.variants.get(instance_type_family, {})
513535
.get("regional_properties", {})
514-
.get("image_uri")
536+
.get(property_name)
515537
)
516538

517-
if image_uri_alias is None or len(image_uri_alias) == 0:
539+
if regional_property_alias is None or len(regional_property_alias) == 0:
518540
return None
519541

520-
if not image_uri_alias.startswith("$"):
542+
if not regional_property_alias.startswith("$"):
521543
# No leading '$' indicates bad metadata.
522544
# There are tests to ensure this never happens.
523545
# However, to allow for fallback options in the unlikely event
524546
# of a regression, we do not raise an exception here.
525-
# We return None, indicating the image uri does not exist.
547+
# We return None, indicating the field does not exist.
526548
return None
527549

528550
if region not in self.regional_aliases:
529551
return None
530-
alias_value = self.regional_aliases[region].get(image_uri_alias[1:], None)
552+
alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None)
531553
return alias_value
532554

533555

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@
181181
"min_sdk_version": "2.49.0",
182182
"training_supported": True,
183183
"incremental_training_supported": True,
184+
"hosting_model_package_arns": {
185+
"us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll"
186+
"ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
187+
},
184188
"hosting_ecr_specs": {
185189
"framework": "pytorch",
186190
"framework_version": "1.5.0",
@@ -192,13 +196,35 @@
192196
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
193197
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
194198
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
199+
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
200+
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
195201
}
196202
},
197203
"variants": {
198-
"p2": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
199-
"p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
200-
"p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
201-
"g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
204+
"p2": {
205+
"regional_properties": {
206+
"image_uri": "$gpu_image_uri",
207+
"model_package_arn": "$gpu_model_package_arn",
208+
}
209+
},
210+
"p3": {
211+
"regional_properties": {
212+
"image_uri": "$gpu_image_uri",
213+
"model_package_arn": "$gpu_model_package_arn",
214+
}
215+
},
216+
"p4": {
217+
"regional_properties": {
218+
"image_uri": "$gpu_image_uri",
219+
"model_package_arn": "$gpu_model_package_arn",
220+
}
221+
},
222+
"g4dn": {
223+
"regional_properties": {
224+
"image_uri": "$gpu_image_uri",
225+
"model_package_arn": "$gpu_model_package_arn",
226+
}
227+
},
202228
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
203229
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
204230
"ml.g5.48xlarge": {
@@ -207,6 +233,8 @@
207233
"ml.g5.12xlarge": {
208234
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
209235
},
236+
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
237+
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
210238
},
211239
},
212240
"training_ecr_specs": {
@@ -224,7 +252,6 @@
224252
"training_model_package_artifact_uris": None,
225253
"deprecate_warn_message": None,
226254
"deprecated_message": None,
227-
"hosting_model_package_arns": None,
228255
"hosting_eula_key": None,
229256
"hyperparameters": [
230257
{

tests/unit/sagemaker/jumpstart/test_artifacts.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import unittest
15+
from unittest.mock import Mock
1516

1617

1718
from mock.mock import patch
19+
import pytest
1820

1921
import copy
2022
from sagemaker.jumpstart import artifacts
@@ -28,8 +30,11 @@
2830
BASE_SPEC,
2931
)
3032

33+
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn
34+
from sagemaker.jumpstart.enums import JumpStartScriptScope
3135

32-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
36+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
37+
from tests.unit.sagemaker.workflow.conftest import mock_client
3338

3439

3540
class ModelArtifactVariantsTest(unittest.TestCase):
@@ -319,3 +324,109 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs):
319324
)
320325

321326
assert kwargs == {"some-estimator-fit-key": "some-estimator-fit-value"}
327+
328+
329+
class RetrieveModelPackageArnTest(unittest.TestCase):
330+
331+
mock_session = Mock(s3_client=mock_client)
332+
333+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
334+
def test_retrieve_model_package_arn(self, patched_get_model_specs):
335+
patched_get_model_specs.side_effect = get_special_model_spec
336+
337+
model_id = "variant-model"
338+
region = "us-west-2"
339+
340+
assert (
341+
_retrieve_model_package_arn(
342+
region=region,
343+
model_id=model_id,
344+
scope=JumpStartScriptScope.INFERENCE,
345+
model_version="*",
346+
sagemaker_session=self.mock_session,
347+
instance_type="ml.p2.48xlarge",
348+
)
349+
== "us-west-2/blah/blah/blah/gpu"
350+
)
351+
352+
assert (
353+
_retrieve_model_package_arn(
354+
region=region,
355+
model_id=model_id,
356+
scope=JumpStartScriptScope.INFERENCE,
357+
model_version="*",
358+
sagemaker_session=self.mock_session,
359+
instance_type="ml.p4.2xlarge",
360+
)
361+
== "us-west-2/blah/blah/blah/gpu"
362+
)
363+
364+
assert (
365+
_retrieve_model_package_arn(
366+
region=region,
367+
model_id=model_id,
368+
scope=JumpStartScriptScope.INFERENCE,
369+
model_version="*",
370+
sagemaker_session=self.mock_session,
371+
instance_type="ml.inf1.2xlarge",
372+
)
373+
== "us-west-2/blah/blah/blah/inf"
374+
)
375+
376+
assert (
377+
_retrieve_model_package_arn(
378+
region=region,
379+
model_id=model_id,
380+
scope=JumpStartScriptScope.INFERENCE,
381+
model_version="*",
382+
sagemaker_session=self.mock_session,
383+
instance_type="ml.inf2.12xlarge",
384+
)
385+
== "us-west-2/blah/blah/blah/inf"
386+
)
387+
388+
assert (
389+
_retrieve_model_package_arn(
390+
region=region,
391+
model_id=model_id,
392+
scope=JumpStartScriptScope.INFERENCE,
393+
model_version="*",
394+
sagemaker_session=self.mock_session,
395+
instance_type="ml.afasfasf.12xlarge",
396+
)
397+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
398+
)
399+
400+
assert (
401+
_retrieve_model_package_arn(
402+
region=region,
403+
model_id=model_id,
404+
scope=JumpStartScriptScope.INFERENCE,
405+
model_version="*",
406+
sagemaker_session=self.mock_session,
407+
instance_type="ml.m2.12xlarge",
408+
)
409+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
410+
)
411+
412+
assert (
413+
_retrieve_model_package_arn(
414+
region=region,
415+
model_id=model_id,
416+
scope=JumpStartScriptScope.INFERENCE,
417+
model_version="*",
418+
sagemaker_session=self.mock_session,
419+
instance_type="nobodycares",
420+
)
421+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
422+
)
423+
424+
with pytest.raises(ValueError):
425+
_retrieve_model_package_arn(
426+
region="cn-north-1",
427+
model_id=model_id,
428+
scope=JumpStartScriptScope.INFERENCE,
429+
model_version="*",
430+
sagemaker_session=self.mock_session,
431+
instance_type="ml.p2.12xlarge",
432+
)

0 commit comments

Comments
 (0)