Skip to content

Commit 525e9ae

Browse files
Captainialiujiaorr
authored and committed
feat: support JumpStart proprietary models (aws#4467)
* feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models
* remove unused imports and fix docstyle
* fix: remove unused args
* fix: remove unused args
* fix: more unused vars
* fix: slow tests
* fix: unittests
* added more tests to cover some lines
* remove estimator warn check
* chore: address comments re performance
* fix: address comments
* complete list experience and other fixes
* fix: pylint
* add doc utils and fix pylint
* fix: docstyle
* fix: doc
* fix: default payloads
* fix: doc and tags and enums
* fix: jumpstart doc
* rename to open_weights and fix filtering
* update filter name
* doc update
* fix: black
* rename to proprietary model and fix unittests
* address comments
* fix: docstyle and flake8
* address more comments and fix doc
* put back doc utils for future refactoring
* add prop model title in doc
* doc update

---------

Co-authored-by: liujiaor <[email protected]>
1 parent e95ed65 commit 525e9ae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2014
-500
lines changed

doc/doc_utils/jumpstart_doc_utils.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ class Frameworks(str, Enum):
7474

7575
JUMPSTART_REGION = "eu-west-2"
7676
SDK_MANIFEST_FILE = "models_manifest.json"
77+
PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json"
7778
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
7879
JUMPSTART_REGION, JUMPSTART_REGION
7980
)
81+
PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com"
82+
8083
TASK_MAP = {
8184
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
8285
Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
@@ -152,18 +155,26 @@ class Frameworks(str, Enum):
152155
}
153156

154157

155-
def get_jumpstart_sdk_manifest():
156-
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
158+
def get_public_s3_json_object(url):
157159
with request.urlopen(url) as f:
158160
models_manifest = f.read().decode("utf-8")
159161
return json.loads(models_manifest)
160162

161163

162-
def get_jumpstart_sdk_spec(key):
163-
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key)
164-
with request.urlopen(url) as f:
165-
model_spec = f.read().decode("utf-8")
166-
return json.loads(model_spec)
164+
def get_jumpstart_sdk_manifest():
165+
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}")
166+
167+
168+
def get_proprietary_sdk_manifest():
169+
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}")
170+
171+
172+
def get_jumpstart_sdk_spec(s3_key: str):
173+
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}")
174+
175+
176+
def get_proprietary_sdk_spec(s3_key: str):
177+
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}")
167178

168179

169180
def get_model_task(id):
@@ -196,6 +207,45 @@ def get_model_source(url):
196207
return "Source"
197208

198209

210+
def create_proprietary_model_table():
211+
proprietary_content_intro = []
212+
proprietary_content_intro.append("\n")
213+
proprietary_content_intro.append(".. list-table:: Available Proprietary Models\n")
214+
proprietary_content_intro.append(" :widths: 50 20 20 20 20\n")
215+
proprietary_content_intro.append(" :header-rows: 1\n")
216+
proprietary_content_intro.append(" :class: datatable\n")
217+
proprietary_content_intro.append("\n")
218+
proprietary_content_intro.append(" * - Model ID\n")
219+
proprietary_content_intro.append(" - Fine Tunable?\n")
220+
proprietary_content_intro.append(" - Supported Version\n")
221+
proprietary_content_intro.append(" - Min SDK Version\n")
222+
proprietary_content_intro.append(" - Source\n")
223+
224+
sdk_manifest = get_proprietary_sdk_manifest()
225+
sdk_manifest_top_versions_for_models = {}
226+
227+
for model in sdk_manifest:
228+
if model["model_id"] not in sdk_manifest_top_versions_for_models:
229+
sdk_manifest_top_versions_for_models[model["model_id"]] = model
230+
else:
231+
if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str(
232+
model["version"]
233+
):
234+
sdk_manifest_top_versions_for_models[model["model_id"]] = model
235+
236+
proprietary_content_entries = []
237+
for model in sdk_manifest_top_versions_for_models.values():
238+
model_spec = get_proprietary_sdk_spec(model["spec_key"])
239+
proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
240+
proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training
241+
proprietary_content_entries.append(" - {}\n".format(model["version"]))
242+
proprietary_content_entries.append(" - {}\n".format(model["min_version"]))
243+
proprietary_content_entries.append(
244+
" - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url"))
245+
)
246+
return proprietary_content_intro + proprietary_content_entries + ["\n"]
247+
248+
199249
def create_jumpstart_model_table():
200250
sdk_manifest = get_jumpstart_sdk_manifest()
201251
sdk_manifest_top_versions_for_models = {}
@@ -249,19 +299,19 @@ def create_jumpstart_model_table():
249299
file_content_intro.append(" - Source\n")
250300

251301
dynamic_table_files = []
252-
file_content_entries = []
302+
open_weight_content_entries = []
253303

254304
for model in sdk_manifest_top_versions_for_models.values():
255305
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
256306
model_task = get_model_task(model_spec["model_id"])
257307
string_model_task = get_string_model_task(model_spec["model_id"])
258308
model_source = get_model_source(model_spec["url"])
259-
file_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
260-
file_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
261-
file_content_entries.append(" - {}\n".format(model["version"]))
262-
file_content_entries.append(" - {}\n".format(model["min_version"]))
263-
file_content_entries.append(" - {}\n".format(model_task))
264-
file_content_entries.append(
309+
open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
310+
open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
311+
open_weight_content_entries.append(" - {}\n".format(model["version"]))
312+
open_weight_content_entries.append(" - {}\n".format(model["min_version"]))
313+
open_weight_content_entries.append(" - {}\n".format(model_task))
314+
open_weight_content_entries.append(
265315
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
266316
)
267317

@@ -299,7 +349,10 @@ def create_jumpstart_model_table():
299349
f.writelines(file_content_single_entry)
300350
f.close()
301351

352+
proprietary_content_entries = create_proprietary_model_table()
353+
302354
f = open("doc_utils/pretrainedmodels.rst", "a")
303355
f.writelines(file_content_intro)
304-
f.writelines(file_content_entries)
356+
f.writelines(open_weight_content_entries)
357+
f.writelines(proprietary_content_entries)
305358
f.close()

src/sagemaker/accept_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -80,6 +81,7 @@ def retrieve_default(
8081
tolerate_vulnerable_model: bool = False,
8182
tolerate_deprecated_model: bool = False,
8283
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
84+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
8385
) -> str:
8486
"""Retrieves the default accept type for the model matching the given arguments.
8587
@@ -122,4 +124,5 @@ def retrieve_default(
122124
tolerate_vulnerable_model=tolerate_vulnerable_model,
123125
tolerate_deprecated_model=tolerate_deprecated_model,
124126
sagemaker_session=sagemaker_session,
127+
model_type=model_type,
125128
)

src/sagemaker/base_predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@
5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

6060
from sagemaker.lineage.context import EndpointContext
61-
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
61+
from sagemaker.compute_resource_requirements.resource_requirements import (
62+
ResourceRequirements,
63+
)
6264

6365
LOGGER = logging.getLogger("sagemaker")
6466

src/sagemaker/content_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -80,6 +81,7 @@ def retrieve_default(
8081
tolerate_vulnerable_model: bool = False,
8182
tolerate_deprecated_model: bool = False,
8283
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
84+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
8385
) -> str:
8486
"""Retrieves the default content type for the model matching the given arguments.
8587
@@ -122,6 +124,7 @@ def retrieve_default(
122124
tolerate_vulnerable_model=tolerate_vulnerable_model,
123125
tolerate_deprecated_model=tolerate_deprecated_model,
124126
sagemaker_session=sagemaker_session,
127+
model_type=model_type,
125128
)
126129

127130

src/sagemaker/deserializers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
3737
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
38+
from sagemaker.jumpstart.enums import JumpStartModelType
3839
from sagemaker.session import Session
3940

4041

@@ -100,6 +101,7 @@ def retrieve_default(
100101
tolerate_vulnerable_model: bool = False,
101102
tolerate_deprecated_model: bool = False,
102103
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
104+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
103105
) -> BaseDeserializer:
104106
"""Retrieves the default deserializer for the model matching the given arguments.
105107
@@ -143,4 +145,5 @@ def retrieve_default(
143145
tolerate_vulnerable_model=tolerate_vulnerable_model,
144146
tolerate_deprecated_model=tolerate_deprecated_model,
145147
sagemaker_session=sagemaker_session,
148+
model_type=model_type,
146149
)

src/sagemaker/instance_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.jumpstart.enums import JumpStartModelType
2324
from sagemaker.session import Session
2425

2526
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ def retrieve_default(
3536
tolerate_deprecated_model: bool = False,
3637
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
training_instance_type: Optional[str] = None,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3840
) -> str:
3941
"""Retrieves the default instance type for the model matching the given arguments.
4042
@@ -89,6 +91,7 @@ def retrieve_default(
8991
tolerate_deprecated_model,
9092
sagemaker_session=sagemaker_session,
9193
training_instance_type=training_instance_type,
94+
model_type=model_type,
9295
)
9396

9497

src/sagemaker/jumpstart/accessors.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
21+
from sagemaker.jumpstart.enums import JumpStartModelType
2122
from sagemaker.jumpstart import cache
2223
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
2324
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
@@ -198,7 +199,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
198199

199200
@staticmethod
200201
def _get_manifest(
201-
region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None
202+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
203+
s3_client: Optional[boto3.client] = None,
204+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
202205
) -> List[JumpStartModelHeader]:
203206
"""Return entire JumpStart models manifest.
204207
@@ -216,13 +219,19 @@ def _get_manifest(
216219
additional_kwargs.update({"s3_client": s3_client})
217220

218221
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
219-
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region
222+
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs},
223+
region,
220224
)
221225
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
222-
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
226+
return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore
223227

224228
@staticmethod
225-
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
229+
def get_model_header(
230+
region: str,
231+
model_id: str,
232+
version: str,
233+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
234+
) -> JumpStartModelHeader:
226235
"""Returns model header from JumpStart models cache.
227236
228237
Args:
@@ -235,7 +244,9 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
235244
)
236245
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
237246
return JumpStartModelsAccessor._cache.get_header( # type: ignore
238-
model_id=model_id, semantic_version_str=version
247+
model_id=model_id,
248+
semantic_version_str=version,
249+
model_type=model_type,
239250
)
240251

241252
@staticmethod
@@ -245,6 +256,7 @@ def get_model_specs(
245256
version: str,
246257
hub_arn: Optional[str] = None,
247258
s3_client: Optional[boto3.client] = None,
259+
model_type=JumpStartModelType.OPEN_WEIGHTS,
248260
) -> JumpStartModelSpecs:
249261
"""Returns model specs from JumpStart models cache.
250262
@@ -272,7 +284,7 @@ def get_model_specs(
272284
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
273285

274286
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
275-
model_id=model_id, semantic_version_str=version
287+
model_id=model_id, version_str=version, model_type=model_type
276288
)
277289

278290
@staticmethod

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.jumpstart.utils import (
2728
verify_model_region_and_return_specs,
@@ -39,6 +40,7 @@ def _retrieve_default_instance_type(
3940
tolerate_deprecated_model: bool = False,
4041
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4142
training_instance_type: Optional[str] = None,
43+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4244
) -> str:
4345
"""Retrieves the default instance type for the model.
4446
@@ -88,6 +90,7 @@ def _retrieve_default_instance_type(
8890
region=region,
8991
tolerate_vulnerable_model=tolerate_vulnerable_model,
9092
tolerate_deprecated_model=tolerate_deprecated_model,
93+
model_type=model_type,
9194
sagemaker_session=sagemaker_session,
9295
)
9396

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.jumpstart.utils import (
2728
verify_model_region_and_return_specs,
@@ -36,6 +37,7 @@ def _retrieve_model_init_kwargs(
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
3839
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3941
) -> dict:
4042
"""Retrieves kwargs for `Model`.
4143
@@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs(
7577
tolerate_vulnerable_model=tolerate_vulnerable_model,
7678
tolerate_deprecated_model=tolerate_deprecated_model,
7779
sagemaker_session=sagemaker_session,
80+
model_type=model_type,
7881
)
7982

8083
kwargs = deepcopy(model_specs.model_kwargs)
@@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs(
9497
tolerate_vulnerable_model: bool = False,
9598
tolerate_deprecated_model: bool = False,
9699
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
100+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
97101
) -> dict:
98102
"""Retrieves kwargs for `Model.deploy`.
99103
@@ -136,6 +140,7 @@ def _retrieve_model_deploy_kwargs(
136140
tolerate_vulnerable_model=tolerate_vulnerable_model,
137141
tolerate_deprecated_model=tolerate_deprecated_model,
138142
sagemaker_session=sagemaker_session,
143+
model_type=model_type,
139144
)
140145

141146
if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.session import Session
2728

@@ -36,6 +37,7 @@ def _retrieve_model_package_arn(
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
3839
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3941
) -> Optional[str]:
4042
"""Retrieves associated model pacakge arn for the model.
4143
@@ -78,6 +80,7 @@ def _retrieve_model_package_arn(
7880
tolerate_vulnerable_model=tolerate_vulnerable_model,
7981
tolerate_deprecated_model=tolerate_deprecated_model,
8082
sagemaker_session=sagemaker_session,
83+
model_type=model_type,
8184
)
8285

8386
if scope == JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)