Skip to content

Commit eb3f92c

Browse files
chrstfubenieric
authored andcommitted
feat: Adding Bedrock Store model support for HubService (#1539)
* feat: Adding BRS support
1 parent 0373886 commit eb3f92c

File tree

10 files changed

+232
-21
lines changed

10 files changed

+232
-21
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ class VariableTypes(str, Enum):
8282
BOOL = "bool"
8383

8484

85+
class HubContentCapability(str, Enum):
86+
"""Enum class for HubContent capabilities."""
87+
88+
BEDROCK_CONSOLE = "BEDROCK_CONSOLE"
89+
90+
8591
class JumpStartTag(str, Enum):
8692
"""Enum class for tag keys to apply to JumpStart models."""
8793

@@ -99,6 +105,8 @@ class JumpStartTag(str, Enum):
99105

100106
HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"
101107

108+
BEDROCK = "sagemaker-sdk:bedrock"
109+
102110

103111
class SerializerType(str, Enum):
104112
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/factory/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from sagemaker.model_metrics import ModelMetrics
4343
from sagemaker.metadata_properties import MetadataProperties
4444
from sagemaker.drift_check_baselines import DriftCheckBaselines
45-
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
45+
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
4646
from sagemaker.jumpstart.types import (
4747
HubContentType,
4848
JumpStartModelDeployKwargs,
@@ -53,6 +53,7 @@
5353
from sagemaker.jumpstart.utils import (
5454
add_hub_content_arn_tags,
5555
add_jumpstart_model_info_tags,
56+
add_bedrock_store_tags,
5657
get_default_jumpstart_session_with_user_agent_suffix,
5758
get_top_ranked_config_name,
5859
update_dict_if_key_not_present,
@@ -495,6 +496,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
495496
)
496497
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
497498

499+
if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
500+
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
501+
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
502+
498503
return kwargs
499504

500505

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,17 @@
1818
from typing import Any, Dict, List, Optional
1919

2020

21+
<<<<<<< HEAD
2122
def camel_to_snake(camel_case_string: str) -> str:
2223
"""Converts camelCase to snake_case_string using a regex.
2324
2425
This regex cannot handle whitespace ("camelString TwoWords")
26+
=======
27+
def pascal_to_snake(camel_case_string: str) -> str:
28+
"""Converts PascalCase to snake_case_string using a regex.
29+
30+
This regex cannot handle whitespace ("PascalString TwoWords")
31+
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
2532
"""
2633
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()
2734

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
HubModelDocument,
2828
)
2929
from sagemaker.jumpstart.hub.parser_utils import (
30-
camel_to_snake,
30+
pascal_to_snake,
3131
snake_to_upper_camel,
3232
walk_and_apply_json,
3333
)
@@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
8686
arg_keys = []
8787

8888
if naming_convention == NamingConventionType.SNAKE_CASE:
89-
arg_keys = [camel_to_snake(key) for key in arg_keys]
89+
arg_keys = [pascal_to_snake(key) for key in arg_keys]
9090
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
9191
return arg_keys
9292
else:
@@ -207,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
207207
default_payloads: Dict[str, Any] = {}
208208
if hub_model_document.default_payloads is not None:
209209
for alias, payload in hub_model_document.default_payloads.items():
210-
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
210+
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
211211
specs["default_payloads"] = default_payloads
212212
specs["gated_bucket"] = hub_model_document.gated_bucket
213213
specs["inference_volume_size"] = hub_model_document.inference_volume_size

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,31 @@ def get_hub_model_version(
219219
except Exception as ex:
220220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
221221

222+
<<<<<<< HEAD
223+
=======
224+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
225+
hub_content_summaries, hub_model_version
226+
)
227+
228+
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
222229
try:
223230
return _get_hub_model_version_for_open_weight_version(
224231
hub_content_summaries, hub_model_version
225232
)
233+
<<<<<<< HEAD
226234
except KeyError:
227235
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
228236
hub_content_summaries, hub_model_version
229237
)
230238
if marketplace_hub_content_version:
231239
return marketplace_hub_content_version
232240
raise
241+
=======
242+
except KeyError as e:
243+
if marketplace_hub_content_version:
244+
return marketplace_hub_content_version
245+
raise e
246+
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
233247

234248

235249
def _get_hub_model_version_for_open_weight_version(

src/sagemaker/jumpstart/types.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
4141
from sagemaker.enums import EndpointType
4242
from sagemaker.jumpstart.hub.parser_utils import (
43-
camel_to_snake,
43+
pascal_to_snake,
4444
walk_and_apply_json,
4545
)
4646
from sagemaker.model_life_cycle import ModelLifeCycle
@@ -241,7 +241,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
241241
return
242242

243243
if self._is_hub_content:
244-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
244+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
245245

246246
self.framework = json_obj.get("framework")
247247
self.framework_version = json_obj.get("framework_version")
@@ -295,7 +295,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
295295
"""
296296

297297
if self._is_hub_content:
298-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
298+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
299299
self.name = json_obj["name"]
300300
self.type = json_obj["type"]
301301
self.default = json_obj["default"]
@@ -363,7 +363,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
363363
Args:
364364
json_obj (Dict[str, Any]): Dictionary representation of environment variable.
365365
"""
366-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
366+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
367367
self.name = json_obj["name"]
368368
self.type = json_obj["type"]
369369
self.default = json_obj["default"]
@@ -413,7 +413,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
413413
return
414414

415415
if self._is_hub_content:
416-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
416+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
417417
self.default_content_type = json_obj["default_content_type"]
418418
self.supported_content_types = json_obj["supported_content_types"]
419419
self.default_accept_type = json_obj["default_accept_type"]
@@ -467,7 +467,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
467467
return
468468

469469
if self._is_hub_content:
470-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
470+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
471471
self.raw_payload = json_obj
472472
self.content_type = json_obj["content_type"]
473473
self.body = json_obj.get("body")
@@ -540,7 +540,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]])
540540
if response is None:
541541
return
542542

543-
response = walk_and_apply_json(response, camel_to_snake)
543+
response = walk_and_apply_json(response, pascal_to_snake)
544544
self.aliases: Optional[dict] = response.get("aliases")
545545
self.regional_aliases = None
546546
self.variants: Optional[dict] = response.get("variants")
@@ -1180,7 +1180,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
11801180
spec (Dict[str, Any]): Dictionary representation of training config ranking.
11811181
"""
11821182
if is_hub_content:
1183-
spec = walk_and_apply_json(spec, camel_to_snake)
1183+
spec = walk_and_apply_json(spec, pascal_to_snake)
11841184
self.from_json(spec)
11851185

11861186
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1286,7 +1286,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12861286
json_obj (Dict[str, Any]): Dictionary representation of spec.
12871287
"""
12881288
if self._is_hub_content:
1289-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
1289+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
12901290
self.model_id: str = json_obj.get("model_id")
12911291
self.url: str = json_obj.get("url")
12921292
self.version: str = json_obj.get("version")
@@ -1515,7 +1515,7 @@ def __init__(
15151515
ValueError: If the component field is invalid.
15161516
"""
15171517
if is_hub_content:
1518-
component = walk_and_apply_json(component, camel_to_snake)
1518+
component = walk_and_apply_json(component, pascal_to_snake)
15191519
self.component_name = component_name
15201520
super().__init__(component, is_hub_content)
15211521
self.from_json(component)
@@ -1568,8 +1568,8 @@ def __init__(
15681568
The list of components that are used to construct the resolved config.
15691569
"""
15701570
if is_hub_content:
1571-
config = walk_and_apply_json(config, camel_to_snake)
1572-
base_fields = walk_and_apply_json(base_fields, camel_to_snake)
1571+
config = walk_and_apply_json(config, pascal_to_snake)
1572+
base_fields = walk_and_apply_json(base_fields, pascal_to_snake)
15731573
self.base_fields = base_fields
15741574
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
15751575
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
@@ -1735,7 +1735,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17351735
"""
17361736
super().from_json(json_obj)
17371737
if self._is_hub_content:
1738-
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
1738+
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
17391739
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
17401740
{
17411741
component_name: JumpStartConfigComponent(component_name, component)

src/sagemaker/jumpstart/utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from sagemaker.jumpstart import constants, enums
3737
from sagemaker.jumpstart import accessors
38-
from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
38+
from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel
3939
from sagemaker.s3 import parse_s3_url
4040
from sagemaker.jumpstart.exceptions import (
4141
DeprecatedJumpStartModelError,
@@ -455,6 +455,21 @@ def add_hub_content_arn_tags(
455455
return tags
456456

457457

458+
def add_bedrock_store_tags(
459+
tags: Optional[List[TagsDict]],
460+
compatibility: str,
461+
) -> Optional[List[TagsDict]]:
462+
"""Adds custom Hub arn tag to JumpStart related resources."""
463+
464+
tags = add_single_jumpstart_tag(
465+
compatibility,
466+
enums.JumpStartTag.BEDROCK,
467+
tags,
468+
is_uri=False,
469+
)
470+
return tags
471+
472+
458473
def add_jumpstart_uri_tags(
459474
tags: Optional[List[TagsDict]] = None,
460475
inference_model_uri: Optional[Union[str, dict]] = None,
@@ -1163,7 +1178,7 @@ def get_jumpstart_configs(
11631178
return (
11641179
{
11651180
config_name: metadata_configs.configs[
1166-
camel_to_snake(snake_to_upper_camel(config_name))
1181+
pascal_to_snake(snake_to_upper_camel(config_name))
11671182
]
11681183
for config_name in config_names
11691184
}

0 commit comments

Comments
 (0)