Skip to content

Commit cf8e5cb

Browse files
committed
feat: Adding Bedrock Store model support for HubService (aws#1539)
* feat: Adding BRS support
1 parent 4727161 commit cf8e5cb

File tree

13 files changed

+474
-87
lines changed

13 files changed

+474
-87
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ class VariableTypes(str, Enum):
8383

8484

8585
class HubContentCapability(str, Enum):
86-
"""Enum class for HubContent capabilities."""
86+
"""Enum class for HubContent capabilities."""
87+
88+
BEDROCK_CONSOLE = "BEDROCK_CONSOLE"
8789

88-
BEDROCK_CONSOLE = "BEDROCK_CONSOLE"
8990

9091
class JumpStartTag(str, Enum):
9192
"""Enum class for tag keys to apply to JumpStart models."""

src/sagemaker/jumpstart/factory/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
add_hub_content_arn_tags,
5454
add_bedrock_store_tags,
5555
add_jumpstart_model_info_tags,
56+
add_bedrock_store_tags,
5657
get_default_jumpstart_session_with_user_agent_suffix,
5758
get_neo_content_bucket,
5859
get_top_ranked_config_name,
@@ -489,9 +490,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
489490
)
490491
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
491492

492-
if kwargs.specs.capabilities is not None:
493-
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
494-
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
493+
if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
494+
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
495+
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
495496

496497
return kwargs
497498

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ class HubModelDocument(HubDataHolderType):
457457
"url",
458458
"min_sdk_version",
459459
"training_supported",
460+
"model_types",
460461
"capabilities",
461462
"incremental_training_supported",
462463
"dynamic_container_deployment_supported",
@@ -551,7 +552,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
551552
Args:
552553
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
553554
"""
554-
self.url: str = json_obj["Url"]
555+
self.url: str = json_obj.get("Url")
555556
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
556557
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
557558
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
@@ -561,9 +562,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
561562
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
562563
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
563564
]
565+
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
564566
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
565567
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
566-
self.incremental_training_supported: bool = bool(json_obj.get("IncrementalTrainingSupported"))
568+
self.incremental_training_supported: bool = bool(
569+
json_obj.get("IncrementalTrainingSupported")
570+
)
567571
self.dynamic_container_deployment_supported: Optional[bool] = (
568572
bool(json_obj.get("DynamicContainerDeploymentSupported"))
569573
if json_obj.get("DynamicContainerDeploymentSupported")

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
from typing import Any, Dict, List, Optional
1919

2020

21-
def camel_to_snake(camel_case_string: str) -> str:
22-
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
return re.sub(r'(?<!^)(?=[A-Z])', '_', camel_case_string).lower()
21+
def pascal_to_snake(camel_case_string: str) -> str:
    """Convert a PascalCase (or camelCase) string to snake_case.

    An underscore is inserted before every ASCII capital letter that is not
    the first character of the string, and the result is lower-cased.

    The regex does not treat whitespace specially, so multi-word input such
    as "PascalString TwoWords" is not handled.

    Args:
        camel_case_string (str): The string to convert.

    Returns:
        str: The snake_case form of the input.
    """
    # Zero-width match: any position (except the start) followed by a capital.
    boundary_before_capital = r"(?<!^)(?=[A-Z])"
    with_underscores = re.sub(boundary_before_capital, "_", camel_case_string)
    return with_underscores.lower()
2427

2528

2629
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
HubModelDocument,
2828
)
2929
from sagemaker.jumpstart.hub.parser_utils import (
30-
camel_to_snake,
30+
pascal_to_snake,
3131
snake_to_upper_camel,
3232
walk_and_apply_json,
3333
)
@@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
8686
arg_keys = []
8787

8888
if naming_convention == NamingConventionType.SNAKE_CASE:
89-
arg_keys = [camel_to_snake(key) for key in arg_keys]
89+
arg_keys = [pascal_to_snake(key) for key in arg_keys]
9090
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
9191
return arg_keys
9292
else:
@@ -137,6 +137,7 @@ def make_model_specs_from_describe_hub_content_response(
137137
hub_model_document: HubModelDocument = response.hub_content_document
138138
specs["url"] = hub_model_document.url
139139
specs["min_sdk_version"] = hub_model_document.min_sdk_version
140+
specs["model_types"] = hub_model_document.model_types
140141
specs["capabilities"] = hub_model_document.capabilities
141142
specs["training_supported"] = bool(hub_model_document.training_supported)
142143
specs["incremental_training_supported"] = bool(
@@ -148,18 +149,18 @@ def make_model_specs_from_describe_hub_content_response(
148149
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
149150

150151
if hub_model_document.hosting_artifact_uri:
151-
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
152-
hub_model_document.hosting_artifact_uri
153-
)
154-
specs["hosting_artifact_key"] = hosting_artifact_key
155-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
152+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
153+
hub_model_document.hosting_artifact_uri
154+
)
155+
specs["hosting_artifact_key"] = hosting_artifact_key
156+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
156157

157158
if hub_model_document.hosting_script_uri:
158-
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
159-
hub_model_document.hosting_script_uri
160-
)
161-
specs["hosting_script_key"] = hosting_script_key
162-
159+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
160+
hub_model_document.hosting_script_uri
161+
)
162+
specs["hosting_script_key"] = hosting_script_key
163+
163164
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
164165
specs["inference_vulnerable"] = False
165166
specs["inference_dependencies"] = hub_model_document.inference_dependencies
@@ -206,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
206207
default_payloads: Dict[str, Any] = {}
207208
if hub_model_document.default_payloads is not None:
208209
for alias, payload in hub_model_document.default_payloads.items():
209-
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
210+
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
210211
specs["default_payloads"] = default_payloads
211212
specs["gated_bucket"] = hub_model_document.gated_bucket
212213
specs["inference_volume_size"] = hub_model_document.inference_volume_size
@@ -227,6 +228,8 @@ def make_model_specs_from_describe_hub_content_response(
227228

228229
specs["model_subscription_link"] = hub_model_document.model_subscription_link
229230

231+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
232+
230233
specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
231234

232235
specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""This module contains utilities related to SageMaker JumpStart Hub."""
1515
from __future__ import absolute_import
1616
import re
17-
from typing import Optional
17+
from typing import Optional, List, Any
1818
from sagemaker.jumpstart.hub.types import S3ObjectLocation
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
@@ -23,6 +23,14 @@
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2525

26+
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
28+
29+
def _convert_str_to_optional(string: str) -> Optional[str]:
30+
if string == "None":
31+
string = None
32+
return string
33+
2634

2735
def get_info_from_hub_resource_arn(
2836
arn: str,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745
hub_name = match.group(4)
3846
hub_content_type = match.group(5)
3947
hub_content_name = match.group(6)
40-
hub_content_version = match.group(7)
48+
hub_content_version = _convert_str_to_optional(match.group(7))
4149

4250
return HubArnExtractedInfo(
4351
partition=partition,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202
hub_model_version: Optional[str] = None,
195203
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
196204
) -> str:
197-
"""Returns available Jumpstart hub model version
205+
"""Returns available Jumpstart hub model version.
206+
207+
It will attempt both a semantic HubContent version search and Marketplace version search.
208+
If the Marketplace version is also semantic, this function will default to HubContent version.
198209
199210
Raises:
200211
ClientError: If the specified model is not found in the hub.
212+
KeyError: If the specified model version is not found.
201213
"""
202214

203215
try:
@@ -207,6 +219,23 @@ def get_hub_model_version(
207219
except Exception as ex:
208220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
209221

222+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
223+
hub_content_summaries, hub_model_version
224+
)
225+
226+
try:
227+
return _get_hub_model_version_for_open_weight_version(
228+
hub_content_summaries, hub_model_version
229+
)
230+
except KeyError as e:
231+
if marketplace_hub_content_version:
232+
return marketplace_hub_content_version
233+
raise e
234+
235+
236+
def _get_hub_model_version_for_open_weight_version(
237+
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
238+
) -> str:
210239
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
211240

212241
if hub_model_version == "*" or hub_model_version is None:
@@ -222,3 +251,37 @@ def get_hub_model_version(
222251
hub_model_version = str(max(available_versions_filtered))
223252

224253
return hub_model_version
254+
255+
256+
def _get_hub_model_version_for_marketplace_version(
    hub_content_summaries: List[Any], marketplace_version: str
) -> Optional[str]:
    """Find the HubContent version that corresponds to a Marketplace version.

    Summaries are scanned in order; the first one whose
    ``HubContentSearchKeywords`` carry the given Marketplace (proprietary)
    version wins.

    Args:
        hub_content_summaries (List[Any]): Summaries to search; each is read
            with ``.get("HubContentSearchKeywords")`` and
            ``.get("HubContentVersion")``.
        marketplace_version (str): The Marketplace version to look for.

    Returns:
        Optional[str]: The ``HubContentVersion`` of the first matching summary,
        or ``None`` when no summary matches.
    """
    matching_summary = next(
        (
            summary
            for summary in hub_content_summaries
            if _hub_search_keywords_contains_marketplace_version(
                summary.get("HubContentSearchKeywords", []), marketplace_version
            )
        ),
        None,
    )

    if matching_summary is None:
        return None

    return matching_summary.get("HubContentVersion")
271+
272+
273+
def _hub_search_keywords_contains_marketplace_version(
    model_search_keywords: List[str], marketplace_version: str
) -> bool:
    """Check whether the search keywords carry the given Marketplace version.

    Only the first keyword that starts with ``PROPRIETARY_VERSION_KEYWORD`` is
    considered; the text after that prefix is compared to
    ``marketplace_version`` for exact equality.

    Args:
        model_search_keywords (List[str]): Search keywords of one hub content
            summary. A missing (``None``) or empty list yields ``False``.
        marketplace_version (str): The Marketplace version to look for.

    Returns:
        bool: ``True`` if the first proprietary-version keyword holds exactly
        ``marketplace_version``, ``False`` otherwise.
    """
    if not model_search_keywords:
        return False

    proprietary_version_keyword = next(
        (
            keyword
            for keyword in model_search_keywords
            if keyword.startswith(PROPRIETARY_VERSION_KEYWORD)
        ),
        None,
    )

    if not proprietary_version_keyword:
        return False

    # Slice the prefix off by length. str.lstrip() takes a *set of characters*,
    # so lstrip(PROPRIETARY_VERSION_KEYWORD) would also eat leading characters
    # of the version itself (e.g. "v1.0" -> "1.0", "release-1" -> "1").
    proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD) :]

    return proprietary_version == marketplace_version

src/sagemaker/jumpstart/types.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
4040
from sagemaker.enums import EndpointType
4141
from sagemaker.jumpstart.hub.parser_utils import (
42-
camel_to_snake,
42+
pascal_to_snake,
4343
walk_and_apply_json,
4444
)
4545

@@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
239239
return
240240

241241
if self._is_hub_content:
242-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
242+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
243243

244244
self.framework = json_obj.get("framework")
245245
self.framework_version = json_obj.get("framework_version")
@@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
293293
"""
294294

295295
if self._is_hub_content:
296-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
296+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
297297
self.name = json_obj["name"]
298298
self.type = json_obj["type"]
299299
self.default = json_obj["default"]
@@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
361361
Args:
362362
json_obj (Dict[str, Any]): Dictionary representation of environment variable.
363363
"""
364-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
364+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
365365
self.name = json_obj["name"]
366366
self.type = json_obj["type"]
367367
self.default = json_obj["default"]
@@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
411411
return
412412

413413
if self._is_hub_content:
414-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
414+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
415415
self.default_content_type = json_obj["default_content_type"]
416416
self.supported_content_types = json_obj["supported_content_types"]
417417
self.default_accept_type = json_obj["default_accept_type"]
@@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
465465
return
466466

467467
if self._is_hub_content:
468-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
468+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
469469
self.raw_payload = json_obj
470470
self.content_type = json_obj["content_type"]
471471
self.body = json_obj.get("body")
@@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]])
538538
if response is None:
539539
return
540540

541-
response = walk_and_apply_json(response, camel_to_snake)
541+
response = walk_and_apply_json(response, pascal_to_snake)
542542
self.aliases: Optional[dict] = response.get("aliases")
543543
self.regional_aliases = None
544544
self.variants: Optional[dict] = response.get("variants")
@@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
11741174
spec (Dict[str, Any]): Dictionary representation of training config ranking.
11751175
"""
11761176
if is_hub_content:
1177-
spec = walk_and_apply_json(spec, camel_to_snake)
1177+
spec = walk_and_apply_json(spec, pascal_to_snake)
11781178
self.from_json(spec)
11791179

11801180
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1200,6 +1200,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12001200
"url",
12011201
"version",
12021202
"min_sdk_version",
1203+
"model_types",
12031204
"capabilities",
12041205
"incremental_training_supported",
12051206
"hosting_ecr_specs",
@@ -1279,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12791280
json_obj (Dict[str, Any]): Dictionary representation of spec.
12801281
"""
12811282
if self._is_hub_content:
1282-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
1283+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
12831284
self.model_id: str = json_obj.get("model_id")
12841285
self.url: str = json_obj.get("url")
12851286
self.version: str = json_obj.get("version")
@@ -1288,7 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12881289
json_obj.get("incremental_training_supported", False)
12891290
)
12901291
if self._is_hub_content:
1291-
self.capabilities: Optional[str] = json_obj.get("capabilities")
1292+
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
1293+
self.model_types: Optional[List[str]] = json_obj.get("model_types")
12921294
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
12931295
self._non_serializable_slots.append("hosting_ecr_specs")
12941296
else:
@@ -1507,7 +1509,7 @@ def __init__(
15071509
ValueError: If the component field is invalid.
15081510
"""
15091511
if is_hub_content:
1510-
component = walk_and_apply_json(component, camel_to_snake)
1512+
component = walk_and_apply_json(component, pascal_to_snake)
15111513
self.component_name = component_name
15121514
super().__init__(component, is_hub_content)
15131515
self.from_json(component)
@@ -1560,8 +1562,8 @@ def __init__(
15601562
The list of components that are used to construct the resolved config.
15611563
"""
15621564
if is_hub_content:
1563-
config = walk_and_apply_json(config, camel_to_snake)
1564-
base_fields = walk_and_apply_json(base_fields, camel_to_snake)
1565+
config = walk_and_apply_json(config, pascal_to_snake)
1566+
base_fields = walk_and_apply_json(base_fields, pascal_to_snake)
15651567
self.base_fields = base_fields
15661568
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
15671569
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
@@ -1727,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17271729
"""
17281730
super().from_json(json_obj)
17291731
if self._is_hub_content:
1730-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
1732+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
17311733
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
17321734
{
17331735
component_name: JumpStartConfigComponent(component_name, component)

0 commit comments

Comments
 (0)