Skip to content

Commit b857ead

Browse files
evakravibencrabtree
authored andcommitted
feat: instance specific jumpstart host requirements (aws#4397)
* feat: instance specific jumpstart host requirements * chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class * fix: typing * fix: pylint
1 parent ac4e861 commit b857ead

File tree

7 files changed

+193
-15
lines changed

7 files changed

+193
-15
lines changed

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains functions for obtaining JumpStart resoure requirements."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional
16+
from typing import Dict, Optional, Tuple
1717

1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -28,6 +28,20 @@
2828
from sagemaker.session import Session
2929
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
3030

31+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
32+
str, Dict[str, Tuple[str, str]]
33+
] = {
34+
"requests": {
35+
"num_accelerators": ("num_accelerators", "num_accelerators"),
36+
"num_cpus": ("num_cpus", "num_cpus"),
37+
"copies": ("copies", "copy_count"),
38+
"min_memory_mb": ("memory", "min_memory"),
39+
},
40+
"limits": {
41+
"max_memory_mb": ("memory", "max_memory"),
42+
},
43+
}
44+
3145

3246
def _retrieve_default_resources(
3347
model_id: str,
@@ -38,6 +52,7 @@ def _retrieve_default_resources(
3852
tolerate_vulnerable_model: bool = False,
3953
tolerate_deprecated_model: bool = False,
4054
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
55+
instance_type: Optional[str] = None,
4156
) -> ResourceRequirements:
4257
"""Retrieves the default resource requirements for the model.
4358
@@ -63,6 +78,8 @@ def _retrieve_default_resources(
6378
object, used for SageMaker interactions. If not
6479
specified, one is created using the default AWS configuration
6580
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
81+
instance_type (str): An instance type to optionally supply in order to get
82+
host requirements specific for the instance type.
6683
Returns:
6784
str: The default resource requirements to use for the model or None.
6885
@@ -91,23 +108,44 @@ def _retrieve_default_resources(
91108
is_dynamic_container_deployment_supported = (
92109
model_specs.dynamic_container_deployment_supported
93110
)
94-
default_resource_requirements = model_specs.hosting_resource_requirements
111+
default_resource_requirements: Dict[str, int] = (
112+
model_specs.hosting_resource_requirements or {}
113+
)
95114
else:
96115
raise NotImplementedError(
97116
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
98117
)
99118

119+
instance_specific_resource_requirements: Dict[str, int] = (
120+
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
121+
instance_type
122+
)
123+
if instance_type
124+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
125+
else {}
126+
)
127+
128+
default_resource_requirements = {
129+
**default_resource_requirements,
130+
**instance_specific_resource_requirements,
131+
}
132+
100133
if is_dynamic_container_deployment_supported:
101-
requests = {}
102-
if "num_accelerators" in default_resource_requirements:
103-
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
104-
if "min_memory_mb" in default_resource_requirements:
105-
requests["memory"] = default_resource_requirements["min_memory_mb"]
106-
if "num_cpus" in default_resource_requirements:
107-
requests["num_cpus"] = default_resource_requirements["num_cpus"]
108-
109-
limits = {}
110-
if "max_memory_mb" in default_resource_requirements:
111-
limits["memory"] = default_resource_requirements["max_memory_mb"]
112-
return ResourceRequirements(requests=requests, limits=limits)
134+
135+
all_resource_requirement_kwargs = {}
136+
137+
for (
138+
requirement_type,
139+
spec_field_to_resource_requirement_map,
140+
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
141+
requirement_kwargs = {}
142+
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
143+
if spec_field in default_resource_requirements:
144+
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
145+
spec_field
146+
]
147+
148+
all_resource_requirement_kwargs[requirement_type] = requirement_kwargs
149+
150+
return ResourceRequirements(**all_resource_requirement_kwargs)
113151
return None

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
503503
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
504504
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
505505
sagemaker_session=kwargs.sagemaker_session,
506+
instance_type=kwargs.instance_type,
506507
)
507508

508509
return kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
512512
instance_type=instance_type, property_name="artifact_key"
513513
)
514514

515+
def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
516+
"""Returns instance specific resource requirements.
517+
518+
If a value exists for both the instance family and instance type, the instance type value
519+
is chosen.
520+
"""
521+
522+
instance_specific_resource_requirements: dict = (
523+
self.variants.get(instance_type, {})
524+
.get("properties", {})
525+
.get("resource_requirements", {})
526+
)
527+
528+
instance_type_family = get_instance_type_family(instance_type)
529+
530+
instance_family_resource_requirements: dict = (
531+
self.variants.get(instance_type_family, {})
532+
.get("properties", {})
533+
.get("resource_requirements", {})
534+
)
535+
536+
return {**instance_family_resource_requirements, **instance_specific_resource_requirements}
537+
515538
def _get_instance_specific_property(
516539
self, instance_type: str, property_name: str
517540
) -> Optional[str]:

src/sagemaker/resource_requirements.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
from typing import Optional
19+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
1920

2021
from sagemaker.jumpstart import utils as jumpstart_utils
2122
from sagemaker.jumpstart import artifacts
@@ -34,7 +35,8 @@ def retrieve_default(
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
3637
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37-
) -> str:
38+
instance_type: Optional[str] = None,
39+
) -> ResourceRequirements:
3840
"""Retrieves the default resource requirements for the model matching the given arguments.
3941
4042
Args:
@@ -59,6 +61,8 @@ def retrieve_default(
5961
object, used for SageMaker interactions. If not
6062
specified, one is created using the default AWS configuration
6163
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get
65+
host requirements specific for the instance type.
6266
Returns:
6367
str: The default resource requirements to use for the model.
6468
@@ -83,4 +87,5 @@ def retrieve_default(
8387
tolerate_vulnerable_model=tolerate_vulnerable_model,
8488
tolerate_deprecated_model=tolerate_deprecated_model,
8589
sagemaker_session=sagemaker_session,
90+
instance_type=instance_type,
8691
)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,22 @@
840840
"model_package_arn": "$gpu_model_package_arn",
841841
}
842842
},
843+
"g5": {
844+
"properties": {
845+
"resource_requirements": {
846+
"num_accelerators": 888810,
847+
"randon-field-2": 2222,
848+
}
849+
}
850+
},
843851
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
844852
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
853+
"ml.g5.xlarge": {
854+
"properties": {
855+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
856+
"resource_requirements": {"num_accelerators": 10},
857+
}
858+
},
845859
"ml.g5.48xlarge": {
846860
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
847861
},
@@ -857,6 +871,12 @@
857871
"framework_version": "1.5.0",
858872
"py_version": "py3",
859873
},
874+
"dynamic_container_deployment_supported": True,
875+
"hosting_resource_requirements": {
876+
"min_memory_mb": 81999,
877+
"num_accelerators": 1,
878+
"random_field_1": 1,
879+
},
860880
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
861881
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
862882
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"variants": {
3535
"ml.p2.12xlarge": {
3636
"properties": {
37+
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
3738
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
3839
"supported_inference_instance_types": ["ml.p5.xlarge"],
3940
"default_inference_instance_type": "ml.p5.xlarge",
@@ -60,6 +61,11 @@
6061
"p2": {
6162
"regional_properties": {"image_uri": "$gpu_image_uri"},
6263
"properties": {
64+
"resource_requirements": {
65+
"req2": {"2": 5, "9": 999},
66+
"req3": 999,
67+
"req4": "blah",
68+
},
6369
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
6470
"default_inference_instance_type": "ml.p2.xlarge",
6571
"metrics": [
@@ -880,3 +886,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
880886
)
881887
is None
882888
)
889+
890+
891+
def test_jumpstart_resource_requirements_instance_variants():
892+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
893+
instance_type="ml.p2.xlarge"
894+
) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"}
895+
896+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
897+
instance_type="ml.p2.12xlarge"
898+
) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"}
899+
900+
assert (
901+
INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
902+
instance_type="ml.p99.12xlarge"
903+
)
904+
== {}
905+
)

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import pytest
1919

2020
from sagemaker import resource_requirements
21+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
22+
from sagemaker.jumpstart.artifacts.resource_requirements import (
23+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP,
24+
)
2125

2226
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2327

@@ -47,6 +51,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs):
4751
patched_get_model_specs.reset_mock()
4852

4953

54+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
55+
def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs):
56+
57+
patched_get_model_specs.side_effect = get_special_model_spec
58+
region = "us-west-2"
59+
mock_client = boto3.client("s3")
60+
mock_session = Mock(s3_client=mock_client)
61+
62+
model_id, model_version = "variant-model", "*"
63+
default_inference_resource_requirements = resource_requirements.retrieve_default(
64+
region=region,
65+
model_id=model_id,
66+
model_version=model_version,
67+
scope="inference",
68+
sagemaker_session=mock_session,
69+
instance_type="ml.g5.xlarge",
70+
)
71+
assert default_inference_resource_requirements.requests == {
72+
"memory": 81999,
73+
"num_accelerators": 10,
74+
}
75+
76+
default_inference_resource_requirements = resource_requirements.retrieve_default(
77+
region=region,
78+
model_id=model_id,
79+
model_version=model_version,
80+
scope="inference",
81+
sagemaker_session=mock_session,
82+
instance_type="ml.g5.555xlarge",
83+
)
84+
assert default_inference_resource_requirements.requests == {
85+
"memory": 81999,
86+
"num_accelerators": 888810,
87+
}
88+
89+
default_inference_resource_requirements = resource_requirements.retrieve_default(
90+
region=region,
91+
model_id=model_id,
92+
model_version=model_version,
93+
scope="inference",
94+
sagemaker_session=mock_session,
95+
instance_type="ml.f9.555xlarge",
96+
)
97+
assert default_inference_resource_requirements.requests == {
98+
"memory": 81999,
99+
"num_accelerators": 1,
100+
}
101+
102+
50103
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
51104
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
52105
patched_get_model_specs.side_effect = get_special_model_spec
@@ -74,3 +127,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
74127
resource_requirements.retrieve_default(
75128
region=region, model_id=model_id, model_version=model_version, scope="training"
76129
)
130+
131+
132+
def test_jumpstart_supports_all_resource_requirement_fields():
133+
134+
all_tracked_resource_requirement_fields = {
135+
field
136+
for requirements in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.values()
137+
for _, field in requirements.values()
138+
}
139+
140+
excluded_resource_requirement_fields = {"requests", "limits"}
141+
assert (
142+
set(ResourceRequirements().__dict__.keys()) - excluded_resource_requirement_fields
143+
== all_tracked_resource_requirement_fields
144+
)

0 commit comments

Comments
 (0)