Skip to content

Commit 37a4038

Browse files
authored
Merge branch 'master' into fix-remove-kwargs
2 parents 109c161 + 1b4dc7c commit 37a4038

18 files changed

+190
-135
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def _add_config_name_to_kwargs(
912912
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913913
sagemaker_session=kwargs.sagemaker_session,
914914
config_name=kwargs.config_name,
915+
hub_arn=kwargs.hub_arn,
915916
)
916917

917918
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():

src/sagemaker/jumpstart/factory/model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,14 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
672672
sagemaker_session=kwargs.sagemaker_session,
673673
model_type=kwargs.model_type,
674674
config_name=kwargs.config_name,
675+
hub_arn=kwargs.hub_arn,
675676
)
676677
if specs.inference_configs:
677-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
678+
default_config_name = (
679+
specs.inference_configs.get_top_config_from_ranking().config_name
680+
if specs.inference_configs.get_top_config_from_ranking()
681+
else None
682+
)
678683
kwargs.config_name = kwargs.config_name or default_config_name
679684

680685
if not kwargs.config_name:
@@ -707,6 +712,7 @@ def _add_additional_model_data_sources_to_kwargs(
707712
sagemaker_session=kwargs.sagemaker_session,
708713
model_type=kwargs.model_type,
709714
config_name=kwargs.config_name,
715+
hub_arn=kwargs.hub_arn,
710716
)
711717
# Append speculative decoding data source from metadata
712718
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
@@ -750,6 +756,7 @@ def _add_config_name_to_deploy_kwargs(
750756
sagemaker_session=kwargs.sagemaker_session,
751757
model_type=kwargs.model_type,
752758
config_name=kwargs.config_name,
759+
hub_arn=kwargs.hub_arn,
753760
)
754761

755762
if training_config_name:
@@ -759,7 +766,11 @@ def _add_config_name_to_deploy_kwargs(
759766
return kwargs
760767

761768
if specs.inference_configs:
762-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
769+
default_config_name = (
770+
specs.inference_configs.get_top_config_from_ranking().config_name
771+
if specs.inference_configs.get_top_config_from_ranking()
772+
else None
773+
)
763774
kwargs.config_name = kwargs.config_name or default_config_name
764775

765776
return kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def _validate_model_id_and_type():
393393
model_version=self.model_version,
394394
sagemaker_session=self.sagemaker_session,
395395
model_type=self.model_type,
396+
hub_arn=self.hub_arn,
396397
)
397398

398399
def log_subscription_warning(self) -> None:

src/sagemaker/jumpstart/types.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12771277
Args:
12781278
json_obj (Dict[str, Any]): Dictionary representation of spec.
12791279
"""
1280+
if self._is_hub_content:
1281+
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
12801282
self.model_id: str = json_obj.get("model_id")
12811283
self.url: str = json_obj.get("url")
12821284
self.version: str = json_obj.get("version")
@@ -1722,6 +1724,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17221724
json_obj (Dict[str, Any]): Dictionary representation of spec.
17231725
"""
17241726
super().from_json(json_obj)
1727+
if self._is_hub_content:
1728+
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
17251729
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
17261730
{
17271731
component_name: JumpStartConfigComponent(component_name, component)
@@ -1732,32 +1736,50 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17321736
)
17331737
self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
17341738
{
1735-
alias: JumpStartConfigRanking(ranking)
1739+
alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
17361740
for alias, ranking in json_obj["inference_config_rankings"].items()
17371741
}
17381742
if json_obj.get("inference_config_rankings")
17391743
else None
17401744
)
1741-
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1742-
{
1743-
alias: JumpStartMetadataConfig(
1744-
alias,
1745-
config,
1746-
json_obj,
1747-
(
1748-
{
1749-
component_name: self.inference_config_components.get(component_name)
1750-
for component_name in config.get("component_names")
1751-
}
1752-
if config and config.get("component_names")
1753-
else None
1754-
),
1755-
)
1756-
for alias, config in json_obj["inference_configs"].items()
1757-
}
1758-
if json_obj.get("inference_configs")
1759-
else None
1760-
)
1745+
1746+
if self._is_hub_content:
1747+
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1748+
{
1749+
alias: JumpStartMetadataConfig(
1750+
alias,
1751+
config,
1752+
json_obj,
1753+
config.config_components,
1754+
is_hub_content=self._is_hub_content,
1755+
)
1756+
for alias, config in json_obj["inference_configs"]["configs"].items()
1757+
}
1758+
if json_obj.get("inference_configs")
1759+
else None
1760+
)
1761+
else:
1762+
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1763+
{
1764+
alias: JumpStartMetadataConfig(
1765+
alias,
1766+
config,
1767+
json_obj,
1768+
(
1769+
{
1770+
component_name: self.inference_config_components.get(component_name)
1771+
for component_name in config.get("component_names")
1772+
}
1773+
if config and config.get("component_names")
1774+
else None
1775+
),
1776+
)
1777+
for alias, config in json_obj["inference_configs"].items()
1778+
}
1779+
if json_obj.get("inference_configs")
1780+
else None
1781+
)
1782+
17611783
self.inference_configs: Optional[JumpStartMetadataConfigs] = (
17621784
JumpStartMetadataConfigs(
17631785
inference_configs_dict,
@@ -1784,26 +1806,45 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
17841806
if json_obj.get("training_config_rankings")
17851807
else None
17861808
)
1787-
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1788-
{
1789-
alias: JumpStartMetadataConfig(
1790-
alias,
1791-
config,
1792-
json_obj,
1793-
(
1794-
{
1795-
component_name: self.training_config_components.get(component_name)
1796-
for component_name in config.get("component_names")
1797-
}
1798-
if config and config.get("component_names")
1799-
else None
1800-
),
1801-
)
1802-
for alias, config in json_obj["training_configs"].items()
1803-
}
1804-
if json_obj.get("training_configs")
1805-
else None
1806-
)
1809+
1810+
if self._is_hub_content:
1811+
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1812+
{
1813+
alias: JumpStartMetadataConfig(
1814+
alias,
1815+
config,
1816+
json_obj,
1817+
config.config_components,
1818+
is_hub_content=self._is_hub_content,
1819+
)
1820+
for alias, config in json_obj["training_configs"]["configs"].items()
1821+
}
1822+
if json_obj.get("training_configs")
1823+
else None
1824+
)
1825+
else:
1826+
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
1827+
{
1828+
alias: JumpStartMetadataConfig(
1829+
alias,
1830+
config,
1831+
json_obj,
1832+
(
1833+
{
1834+
component_name: self.training_config_components.get(
1835+
component_name
1836+
)
1837+
for component_name in config.get("component_names")
1838+
}
1839+
if config and config.get("component_names")
1840+
else None
1841+
),
1842+
)
1843+
for alias, config in json_obj["training_configs"].items()
1844+
}
1845+
if json_obj.get("training_configs")
1846+
else None
1847+
)
18071848

18081849
self.training_configs: Optional[JumpStartMetadataConfigs] = (
18091850
JumpStartMetadataConfigs(

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ def get_jumpstart_configs(
10741074
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
10751075
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
10761076
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
1077+
hub_arn: Optional[str] = None,
10771078
) -> Dict[str, JumpStartMetadataConfig]:
10781079
"""Returns metadata configs for the given model ID and region.
10791080
@@ -1087,6 +1088,7 @@ def get_jumpstart_configs(
10871088
sagemaker_session=sagemaker_session,
10881089
scope=scope,
10891090
model_type=model_type,
1091+
hub_arn=hub_arn,
10901092
)
10911093

10921094
if scope == enums.JumpStartScriptScope.INFERENCE:

src/sagemaker/workflow/automl_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def __init__(
3434
self,
3535
name: str,
3636
step_args: _JobStepArguments,
37-
display_name: str = None,
38-
description: str = None,
39-
cache_config: CacheConfig = None,
37+
display_name: Optional[str] = None,
38+
description: Optional[str] = None,
39+
cache_config: Optional[CacheConfig] = None,
4040
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
41-
retry_policies: List[RetryPolicy] = None,
41+
retry_policies: Optional[List[RetryPolicy]] = None,
4242
):
4343
"""Construct a `AutoMLStep`, given a `AutoML` instance.
4444

src/sagemaker/workflow/callback_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def __init__(
8484
sqs_queue_url: str,
8585
inputs: dict,
8686
outputs: List[CallbackOutput],
87-
display_name: str = None,
88-
description: str = None,
89-
cache_config: CacheConfig = None,
87+
display_name: Optional[str] = None,
88+
description: Optional[str] = None,
89+
cache_config: Optional[CacheConfig] = None,
9090
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9191
):
9292
"""Constructs a CallbackStep.

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def __init__(
159159
skip_check: Union[bool, PipelineVariable] = False,
160160
fail_on_violation: Union[bool, PipelineVariable] = True,
161161
register_new_baseline: Union[bool, PipelineVariable] = False,
162-
model_package_group_name: Union[str, PipelineVariable] = None,
163-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
164-
display_name: str = None,
165-
description: str = None,
166-
cache_config: CacheConfig = None,
162+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
163+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
164+
display_name: Optional[str] = None,
165+
description: Optional[str] = None,
166+
cache_config: Optional[CacheConfig] = None,
167167
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
168168
):
169169
"""Constructs a ClarifyCheckStep.

src/sagemaker/workflow/condition_step.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from sagemaker.workflow.step_collections import StepCollection
2222
from sagemaker.workflow.functions import JsonGet as NewJsonGet
23-
from sagemaker.workflow.step_outputs import StepOutput
2423
from sagemaker.workflow.steps import (
2524
Step,
2625
StepTypeEnum,
@@ -41,11 +40,11 @@ def __init__(
4140
self,
4241
name: str,
4342
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
44-
display_name: str = None,
45-
description: str = None,
46-
conditions: List[Condition] = None,
47-
if_steps: List[Union[Step, StepCollection, StepOutput]] = None,
48-
else_steps: List[Union[Step, StepCollection, StepOutput]] = None,
43+
display_name: Optional[str] = None,
44+
description: Optional[str] = None,
45+
conditions: Optional[List[Condition]] = None,
46+
if_steps: Optional[List[Union[Step, StepCollection]]] = None,
47+
else_steps: Optional[List[Union[Step, StepCollection]]] = None,
4948
):
5049
"""Construct a ConditionStep for pipelines to support conditional branching.
5150

src/sagemaker/workflow/emr_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ def __init__(
161161
cluster_id: str,
162162
step_config: EMRStepConfig,
163163
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
164-
cache_config: CacheConfig = None,
165-
cluster_config: Dict[str, Any] = None,
166-
execution_role_arn: str = None,
164+
cache_config: Optional[CacheConfig] = None,
165+
cluster_config: Optional[Dict[str, Any]] = None,
166+
execution_role_arn: Optional[str] = None,
167167
):
168168
"""Constructs an `EMRStep`.
169169

src/sagemaker/workflow/fail_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ class FailStep(Step):
2929
def __init__(
3030
self,
3131
name: str,
32-
error_message: Union[str, PipelineVariable] = None,
33-
display_name: str = None,
34-
description: str = None,
32+
error_message: Optional[Union[str, PipelineVariable]] = None,
33+
display_name: Optional[str] = None,
34+
description: Optional[str] = None,
3535
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
3636
):
3737
"""Constructs a `FailStep`.

src/sagemaker/workflow/lambda_step.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def __init__(
8484
self,
8585
name: str,
8686
lambda_func: Lambda,
87-
display_name: str = None,
88-
description: str = None,
89-
inputs: dict = None,
90-
outputs: List[LambdaOutput] = None,
91-
cache_config: CacheConfig = None,
87+
display_name: Optional[str] = None,
88+
description: Optional[str] = None,
89+
inputs: Optional[dict] = None,
90+
outputs: Optional[List[LambdaOutput]] = None,
91+
cache_config: Optional[CacheConfig] = None,
9292
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9393
):
9494
"""Constructs a LambdaStep.

src/sagemaker/workflow/monitor_batch_transform_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(
4848
check_job_configuration: CheckJobConfig,
4949
monitor_before_transform: bool = False,
5050
fail_on_violation: Union[bool, PipelineVariable] = True,
51-
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
52-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
51+
supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None,
52+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
5353
display_name: Optional[str] = None,
5454
description: Optional[str] = None,
5555
):

src/sagemaker/workflow/quality_check_step.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ def __init__(
125125
skip_check: Union[bool, PipelineVariable] = False,
126126
fail_on_violation: Union[bool, PipelineVariable] = True,
127127
register_new_baseline: Union[bool, PipelineVariable] = False,
128-
model_package_group_name: Union[str, PipelineVariable] = None,
129-
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
130-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
131-
display_name: str = None,
132-
description: str = None,
133-
cache_config: CacheConfig = None,
128+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
129+
supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None,
130+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
131+
display_name: Optional[str] = None,
132+
description: Optional[str] = None,
133+
cache_config: Optional[CacheConfig] = None,
134134
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
135135
):
136136
"""Constructs a QualityCheckStep.

0 commit comments

Comments
 (0)