Skip to content

Commit 8bdf840

Browse files
CaptainiaJonathan Makunga
authored and
Jonathan Makunga
committed
Merge master
1 parent eebd610 commit 8bdf840

File tree

5 files changed

+21
-4
lines changed

5 files changed

+21
-4
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
disable_output_compression: Optional[bool] = None,
113113
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
114114
config_name: Optional[str] = None,
115+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
115116
):
116117
"""Initializes a ``JumpStartEstimator``.
117118
@@ -505,6 +506,8 @@ def __init__(
505506
Specifies whether RemoteDebug is enabled for the training job
506507
config_name (Optional[str]):
507508
Name of the training configuration to apply to the Estimator. (Default: None).
509+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
510+
Specifies whether SessionTagChaining is enabled for the training job
508511
509512
Raises:
510513
ValueError: If the model ID is not recognized by JumpStart.
@@ -584,6 +587,7 @@ def _validate_model_id_and_get_type_hook():
584587
enable_infra_check=enable_infra_check,
585588
enable_remote_debug=enable_remote_debug,
586589
config_name=config_name,
590+
enable_session_tag_chaining=enable_session_tag_chaining,
587591
)
588592

589593
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def get_init_kwargs(
131131
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
132132
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
133133
config_name: Optional[str] = None,
134+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
134135
) -> JumpStartEstimatorInitKwargs:
135136
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
136137

@@ -190,6 +191,7 @@ def get_init_kwargs(
190191
enable_infra_check=enable_infra_check,
191192
enable_remote_debug=enable_remote_debug,
192193
config_name=config_name,
194+
enable_session_tag_chaining=enable_session_tag_chaining,
193195
)
194196

195197
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/session_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_model_info_from_training_job(
219219
model_id,
220220
inferred_model_version,
221221
inference_config_name,
222-
trainig_config_name,
222+
training_config_name,
223223
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
224224

225225
model_version = inferred_model_version or None
@@ -231,4 +231,4 @@ def get_model_info_from_training_job(
231231
"for this training job."
232232
)
233233

234-
return model_id, model_version, inference_config_name, trainig_config_name
234+
return model_id, model_version, inference_config_name, training_config_name

src/sagemaker/jumpstart/types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10781078
"resolved_metadata_config",
10791079
"config_name",
10801080
"default_inference_config",
1081-
"default_incremental_trainig_config",
1081+
"default_incremental_training_config",
10821082
"supported_inference_configs",
10831083
"supported_incremental_training_configs",
10841084
]
@@ -1114,7 +1114,7 @@ def __init__(
11141114
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
11151115
self.config_name: Optional[str] = config_name
11161116
self.default_inference_config: Optional[str] = config.get("default_inference_config")
1117-
self.default_incremental_trainig_config: Optional[str] = config.get(
1117+
self.default_incremental_training_config: Optional[str] = config.get(
11181118
"default_incremental_training_config"
11191119
)
11201120
self.supported_inference_configs: Optional[List[str]] = config.get(
@@ -1775,6 +1775,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
17751775
"enable_infra_check",
17761776
"enable_remote_debug",
17771777
"config_name",
1778+
"enable_session_tag_chaining",
17781779
]
17791780

17801781
SERIALIZATION_EXCLUSION_SET = {
@@ -1844,6 +1845,7 @@ def __init__(
18441845
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
18451846
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
18461847
config_name: Optional[str] = None,
1848+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
18471849
) -> None:
18481850
"""Instantiates JumpStartEstimatorInitKwargs object."""
18491851

@@ -1904,6 +1906,7 @@ def __init__(
19041906
self.enable_infra_check = enable_infra_check
19051907
self.enable_remote_debug = enable_remote_debug
19061908
self.config_name = config_name
1909+
self.enable_session_tag_chaining = enable_session_tag_chaining
19071910

19081911

19091912
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,14 @@ def test_inference_configs_parsing():
10551055
)
10561056
assert list(config.config_components.keys()) == ["neuron-inference"]
10571057

1058+
spec = {
1059+
**BASE_SPEC,
1060+
**INFERENCE_CONFIGS,
1061+
**INFERENCE_CONFIG_RANKINGS,
1062+
"unrecognized-field": "blah", # New fields in base metadata fields should be ignored
1063+
}
1064+
specs1 = JumpStartModelSpecs(spec)
1065+
10581066

10591067
def test_set_inference_configs():
10601068
spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}

0 commit comments

Comments
 (0)