
fix: ModelReference deployment for Alt Configs models #4813

Merged
merged 11 commits on Aug 1, 2024
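For context, here is a minimal sketch of the user-facing path this PR fixes: deploying a curated-hub (ModelReference) model that ships alternative inference configs. The hub name, model ID, and config name below are hypothetical placeholders, and the parameter names reflect my reading of the JumpStartModel interface rather than anything stated in this diff. As the changes below show, the spec lookups behind this call previously omitted hub_arn and assumed the public metadata's casing and layout.

```python
# Hypothetical repro of the scenario this PR addresses: a hub-content
# (ModelReference) model with alternative configs. All values are placeholders.
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(
    model_id="example-alt-config-model",   # hypothetical hub model ID
    hub_name="example-curated-hub",        # hypothetical curated hub
    config_name="tgi-optimized",           # hypothetical alternative config
)
predictor = model.deploy()
```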
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
@@ -912,6 +912,7 @@ def _add_config_name_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)

if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
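The estimator path gets the same treatment: hub_arn now travels with the other kwargs into the spec lookup inside _add_config_name_to_kwargs, so hub-content models resolve their training configs from the hub. A hypothetical fine-tuning counterpart follows, with placeholder names; the hub_name and config_name parameters are my assumption about JumpStartEstimator's surface, not something shown in this diff.

```python
# Hypothetical fine-tuning counterpart; names are placeholders and the
# hub_name/config_name parameters are assumed, not taken from this diff.
from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id="example-alt-config-model",   # hypothetical
    hub_name="example-curated-hub",        # hypothetical curated hub
    config_name="lora-fine-tuning",        # hypothetical training config
)
estimator.fit({"training": "s3://example-bucket/train/"})
```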
15 changes: 13 additions & 2 deletions src/sagemaker/jumpstart/factory/model.py
@@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)
if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
default_config_name = (
specs.inference_configs.get_top_config_from_ranking().config_name
if specs.inference_configs.get_top_config_from_ranking()
else None
)
kwargs.config_name = kwargs.config_name or default_config_name

if not kwargs.config_name:
@@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs(
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)
# Append speculative decoding data source from metadata
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
@@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs(
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
config_name=kwargs.config_name,
hub_arn=kwargs.hub_arn,
)

if training_config_name:
@@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs(
return kwargs

if specs.inference_configs:
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
default_config_name = (
specs.inference_configs.get_top_config_from_ranking().config_name
if specs.inference_configs.get_top_config_from_ranking()
else None
)
kwargs.config_name = kwargs.config_name or default_config_name

return kwargs
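Besides threading hub_arn through, both _add_config_name_to_init_kwargs and _add_config_name_to_deploy_kwargs now guard against get_top_config_from_ranking() returning None. A standalone sketch of that guard is below; the helper itself is illustrative, not SDK code.

```python
# Illustrative helper mirroring the guarded default-config resolution above;
# not part of the SDK.
from typing import Optional


def resolve_config_name(requested: Optional[str], inference_configs) -> Optional[str]:
    """Prefer the caller's config name, else fall back to the top-ranked
    config, tolerating rankings that produce no result."""
    if inference_configs is None:
        return requested
    top = inference_configs.get_top_config_from_ranking()
    default = top.config_name if top else None
    return requested or default
```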
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/model.py
@@ -393,6 +393,7 @@ def _validate_model_id_and_type():
model_version=self.model_version,
sagemaker_session=self.sagemaker_session,
model_type=self.model_type,
hub_arn=self.hub_arn,
)

def log_subscription_warning(self) -> None:
123 changes: 82 additions & 41 deletions src/sagemaker/jumpstart/types.py
@@ -1277,6 +1277,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of spec.
"""
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
self.model_id: str = json_obj.get("model_id")
self.url: str = json_obj.get("url")
self.version: str = json_obj.get("version")
@@ -1722,6 +1724,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj (Dict[str, Any]): Dictionary representation of spec.
"""
super().from_json(json_obj)
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
{
component_name: JumpStartConfigComponent(component_name, component)
@@ -1732,32 +1736,50 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
)
self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
{
alias: JumpStartConfigRanking(ranking)
alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
for alias, ranking in json_obj["inference_config_rankings"].items()
}
if json_obj.get("inference_config_rankings")
else None
)
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
component_name: self.inference_config_components.get(component_name)
for component_name in config.get("component_names")
}
if config and config.get("component_names")
else None
),
)
for alias, config in json_obj["inference_configs"].items()
}
if json_obj.get("inference_configs")
else None
)

if self._is_hub_content:
Collaborator:
non-blocking: since we are here, it would be good if we could update line 1792 for training configs, but doesn't have to be in this PR.

Collaborator (Author):
Thanks for the callout, I added those changes as well.

inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
config.config_components,
is_hub_content=self._is_hub_content,
)
for alias, config in json_obj["inference_configs"]["configs"].items()
}
if json_obj.get("inference_configs")
else None
)
else:
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
component_name: self.inference_config_components.get(component_name)
for component_name in config.get("component_names")
}
if config and config.get("component_names")
else None
),
)
for alias, config in json_obj["inference_configs"].items()
}
if json_obj.get("inference_configs")
else None
)

self.inference_configs: Optional[JumpStartMetadataConfigs] = (
JumpStartMetadataConfigs(
inference_configs_dict,
@@ -1784,26 +1806,45 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if json_obj.get("training_config_rankings")
else None
)
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
component_name: self.training_config_components.get(component_name)
for component_name in config.get("component_names")
}
if config and config.get("component_names")
else None
),
)
for alias, config in json_obj["training_configs"].items()
}
if json_obj.get("training_configs")
else None
)

if self._is_hub_content:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
config.config_components,
is_hub_content=self._is_hub_content,
)
for alias, config in json_obj["training_configs"]["configs"].items()
}
if json_obj.get("training_configs")
else None
)
else:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
config,
json_obj,
(
{
component_name: self.training_config_components.get(
component_name
)
for component_name in config.get("component_names")
}
if config and config.get("component_names")
else None
),
)
for alias, config in json_obj["training_configs"].items()
}
if json_obj.get("training_configs")
else None
)

self.training_configs: Optional[JumpStartMetadataConfigs] = (
JumpStartMetadataConfigs(
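The types.py changes handle two quirks of hub content: keys arrive in camelCase and must be normalized with walk_and_apply_json(camel_to_snake), and the configs sit one level deeper under a "configs" key with their config_components already parsed. Below is a toy illustration of the key normalization; walk_and_apply_json and camel_to_snake are the SDK utilities referenced in the diff, but this simplified re-implementation only shows the idea.

```python
# Toy re-implementation of the hub-content key normalization; the real
# utilities live in sagemaker.jumpstart.utils.
import re
from typing import Any, Callable


def camel_to_snake(name: str) -> str:
    # Insert an underscore before every interior capital, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def walk_and_apply_json(obj: Any, fn: Callable[[str], str]) -> Any:
    # Recursively apply fn to every dictionary key in the payload.
    if isinstance(obj, dict):
        return {fn(key): walk_and_apply_json(value, fn) for key, value in obj.items()}
    if isinstance(obj, list):
        return [walk_and_apply_json(item, fn) for item in obj]
    return obj


hub_payload = {"InferenceConfigs": {"Configs": {"tgi": {"ComponentNames": ["tgi"]}}}}
print(walk_and_apply_json(hub_payload, camel_to_snake))
# {'inference_configs': {'configs': {'tgi': {'component_names': ['tgi']}}}}
```

Note the nested "configs" key in the sample payload, which is why the hub branch reads json_obj["inference_configs"]["configs"] rather than json_obj["inference_configs"].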
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/utils.py
@@ -1074,6 +1074,7 @@ def get_jumpstart_configs(
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
) -> Dict[str, JumpStartMetadataConfig]:
"""Returns metadata configs for the given model ID and region.

@@ -1087,6 +1088,7 @@
sagemaker_session=sagemaker_session,
scope=scope,
model_type=model_type,
hub_arn=hub_arn,
)

if scope == enums.JumpStartScriptScope.INFERENCE:
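Finally, get_jumpstart_configs forwards the new hub_arn argument into the underlying spec lookup. A hypothetical call is sketched below; the argument list is abridged, the values are placeholders, and only the parameters visible in this diff (scope, model_type, hub_arn, sagemaker_session) plus the model ID mentioned in the docstring are relied on.

```python
# Hypothetical call showing the new hub_arn pass-through; values are placeholders.
from sagemaker.jumpstart import enums, utils

configs = utils.get_jumpstart_configs(
    model_id="example-alt-config-model",  # placeholder
    scope=enums.JumpStartScriptScope.INFERENCE,
    hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/example-hub",  # placeholder
)
```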