Skip to content

Commit 521f0dc

Browse files
authored
Merge branch 'master' into s3_overwrite
2 parents 9de7c80 + 1b4dc7c commit 521f0dc

File tree

5 files changed

+99
-43
lines changed

5 files changed

+99
-43
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def _add_config_name_to_kwargs(
912912
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913913
sagemaker_session=kwargs.sagemaker_session,
914914
config_name=kwargs.config_name,
915+
hub_arn=kwargs.hub_arn,
915916
)
916917

917918
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():

src/sagemaker/jumpstart/factory/model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
672672
sagemaker_session=kwargs.sagemaker_session,
673673
model_type=kwargs.model_type,
674674
config_name=kwargs.config_name,
675+
hub_arn=kwargs.hub_arn,
675676
)
676677
if specs.inference_configs:
677-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
678+
default_config_name = (
679+
specs.inference_configs.get_top_config_from_ranking().config_name
680+
if specs.inference_configs.get_top_config_from_ranking()
681+
else None
682+
)
678683
kwargs.config_name = kwargs.config_name or default_config_name
679684

680685
if not kwargs.config_name:
@@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs(
707712
sagemaker_session=kwargs.sagemaker_session,
708713
model_type=kwargs.model_type,
709714
config_name=kwargs.config_name,
715+
hub_arn=kwargs.hub_arn,
710716
)
711717
# Append speculative decoding data source from metadata
712718
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
@@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs(
750756
sagemaker_session=kwargs.sagemaker_session,
751757
model_type=kwargs.model_type,
752758
config_name=kwargs.config_name,
759+
hub_arn=kwargs.hub_arn,
753760
)
754761

755762
if training_config_name:
@@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs(
759766
return kwargs
760767

761768
if specs.inference_configs:
762-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
769+
default_config_name = (
770+
specs.inference_configs.get_top_config_from_ranking().config_name
771+
if specs.inference_configs.get_top_config_from_ranking()
772+
else None
773+
)
763774
kwargs.config_name = kwargs.config_name or default_config_name
764775

765776
return kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def _validate_model_id_and_type():
393393
model_version=self.model_version,
394394
sagemaker_session=self.sagemaker_session,
395395
model_type=self.model_type,
396+
hub_arn=self.hub_arn,
396397
)
397398

398399
def log_subscription_warning(self) -> None:

src/sagemaker/jumpstart/types.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12771277
Args:
12781278
json_obj (Dict[str, Any]): Dictionary representation of spec.
12791279
"""
1280+
if self._is_hub_content:
1281+
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
12801282
self.model_id: str = json_obj.get("model_id")
12811283
self.url: str = json_obj.get("url")
12821284
self.version: str = json_obj.get("version")
@@ -1722,6 +1724,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17221724
json_obj (Dict[str, Any]): Dictionary representation of spec.
17231725
"""
17241726
super().from_json(json_obj)
1727+
if self._is_hub_content:
1728+
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
17251729
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
17261730
{
17271731
component_name: JumpStartConfigComponent(component_name, component)
@@ -1732,32 +1736,50 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17321736
)
17331737
self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
17341738
{
1735-
alias: JumpStartConfigRanking(ranking)
1739+
alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
17361740
for alias, ranking in json_obj["inference_config_rankings"].items()
17371741
}
17381742
if json_obj.get("inference_config_rankings")
17391743
else None
17401744
)
1741-
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1742-
{
1743-
alias: JumpStartMetadataConfig(
1744-
alias,
1745-
config,
1746-
json_obj,
1747-
(
1748-
{
1749-
component_name: self.inference_config_components.get(component_name)
1750-
for component_name in config.get("component_names")
1751-
}
1752-
if config and config.get("component_names")
1753-
else None
1754-
),
1755-
)
1756-
for alias, config in json_obj["inference_configs"].items()
1757-
}
1758-
if json_obj.get("inference_configs")
1759-
else None
1760-
)
1745+
1746+
if self._is_hub_content:
1747+
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1748+
{
1749+
alias: JumpStartMetadataConfig(
1750+
alias,
1751+
config,
1752+
json_obj,
1753+
config.config_components,
1754+
is_hub_content=self._is_hub_content,
1755+
)
1756+
for alias, config in json_obj["inference_configs"]["configs"].items()
1757+
}
1758+
if json_obj.get("inference_configs")
1759+
else None
1760+
)
1761+
else:
1762+
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1763+
{
1764+
alias: JumpStartMetadataConfig(
1765+
alias,
1766+
config,
1767+
json_obj,
1768+
(
1769+
{
1770+
component_name: self.inference_config_components.get(component_name)
1771+
for component_name in config.get("component_names")
1772+
}
1773+
if config and config.get("component_names")
1774+
else None
1775+
),
1776+
)
1777+
for alias, config in json_obj["inference_configs"].items()
1778+
}
1779+
if json_obj.get("inference_configs")
1780+
else None
1781+
)
1782+
17611783
self.inference_configs: Optional[JumpStartMetadataConfigs] = (
17621784
JumpStartMetadataConfigs(
17631785
inference_configs_dict,
@@ -1784,26 +1806,45 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17841806
if json_obj.get("training_config_rankings")
17851807
else None
17861808
)
1787-
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1788-
{
1789-
alias: JumpStartMetadataConfig(
1790-
alias,
1791-
config,
1792-
json_obj,
1793-
(
1794-
{
1795-
component_name: self.training_config_components.get(component_name)
1796-
for component_name in config.get("component_names")
1797-
}
1798-
if config and config.get("component_names")
1799-
else None
1800-
),
1801-
)
1802-
for alias, config in json_obj["training_configs"].items()
1803-
}
1804-
if json_obj.get("training_configs")
1805-
else None
1806-
)
1809+
1810+
if self._is_hub_content:
1811+
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1812+
{
1813+
alias: JumpStartMetadataConfig(
1814+
alias,
1815+
config,
1816+
json_obj,
1817+
config.config_components,
1818+
is_hub_content=self._is_hub_content,
1819+
)
1820+
for alias, config in json_obj["training_configs"]["configs"].items()
1821+
}
1822+
if json_obj.get("training_configs")
1823+
else None
1824+
)
1825+
else:
1826+
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1827+
{
1828+
alias: JumpStartMetadataConfig(
1829+
alias,
1830+
config,
1831+
json_obj,
1832+
(
1833+
{
1834+
component_name: self.training_config_components.get(
1835+
component_name
1836+
)
1837+
for component_name in config.get("component_names")
1838+
}
1839+
if config and config.get("component_names")
1840+
else None
1841+
),
1842+
)
1843+
for alias, config in json_obj["training_configs"].items()
1844+
}
1845+
if json_obj.get("training_configs")
1846+
else None
1847+
)
18071848

18081849
self.training_configs: Optional[JumpStartMetadataConfigs] = (
18091850
JumpStartMetadataConfigs(

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ def get_jumpstart_configs(
10741074
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
10751075
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
10761076
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
1077+
hub_arn: Optional[str] = None,
10771078
) -> Dict[str, JumpStartMetadataConfig]:
10781079
"""Returns metadata configs for the given model ID and region.
10791080
@@ -1087,6 +1088,7 @@ def get_jumpstart_configs(
10871088
sagemaker_session=sagemaker_session,
10881089
scope=scope,
10891090
model_type=model_type,
1091+
hub_arn=hub_arn,
10901092
)
10911093

10921094
if scope == enums.JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)