Skip to content

Commit 220929a

Browse files
committed
chore: jumpstart deprecation messages
1 parent baf96cd commit 220929a

File tree

5 files changed

+67
-6
lines changed

5 files changed

+67
-6
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
334334
"training_dependencies",
335335
"training_vulnerabilities",
336336
"deprecated",
337+
"deprecated_message",
338+
"deprecate_warn_message",
337339
"default_inference_instance_type",
338340
"supported_inference_instance_types",
339341
"default_training_instance_type",
@@ -387,6 +389,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
387389
self.training_dependencies: List[str] = json_obj["training_dependencies"]
388390
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
389391
self.deprecated: bool = bool(json_obj["deprecated"])
392+
self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
393+
self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
390394
self.default_inference_instance_type: Optional[str] = json_obj.get(
391395
"default_inference_instance_type"
392396
)

src/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,16 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405-
if model_specs.deprecated:
405+
if model_specs.deprecated or model_specs.deprecated_message:
406406
if not tolerate_deprecated_model:
407-
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
407+
raise DeprecatedJumpStartModelError(
408+
model_id=model_id, version=version, message=model_specs.deprecated_message
409+
)
408410
LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)
409411

412+
if model_specs.deprecate_warn_message:
413+
LOGGER.warning(model_specs.deprecate_warn_message)
414+
410415
if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
411416
if not tolerate_vulnerable_model:
412417
raise VulnerableJumpStartModelError(

src/sagemaker/jumpstart/validators.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
from typing import Any, Dict, List, Optional
1616
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1717

18-
from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
19-
from sagemaker.jumpstart import accessors as jumpstart_accessors
18+
from sagemaker.jumpstart.enums import (
19+
HyperparameterValidationMode,
20+
JumpStartScriptScope,
21+
VariableScope,
22+
VariableTypes,
23+
)
2024
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
2125
from sagemaker.jumpstart.types import JumpStartHyperparameter
26+
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
2227

2328

2429
def _validate_hyperparameter(
@@ -190,8 +195,11 @@ def validate_hyperparameters(
190195
if region is None:
191196
region = JUMPSTART_DEFAULT_REGION_NAME
192197

193-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
194-
region=region, model_id=model_id, version=model_version
198+
model_specs = verify_model_region_and_return_specs(
199+
model_id=model_id,
200+
model_version=model_version,
201+
region=region,
202+
scope=JumpStartScriptScope.TRAINING,
195203
)
196204
hyperparameters_specs = model_specs.hyperparameters
197205

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@
22992299
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
2302+
"deprecate_warn_message": None,
2303+
"deprecated_message": None,
23022304
"hyperparameters": [
23032305
{
23042306
"name": "epochs",

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,48 @@ def make_deprecated_spec(*largs, **kwargs):
887887
"*",
888888
)
889889

890+
deprecated_message = "this model is deprecated"
891+
892+
def make_deprecated_message_spec(*largs, **kwargs):
893+
spec = get_spec_from_base_spec(*largs, **kwargs)
894+
spec.deprecated_message = deprecated_message
895+
return spec
896+
897+
patched_get_model_specs.side_effect = make_deprecated_message_spec
898+
899+
with pytest.raises(DeprecatedJumpStartModelError) as e:
900+
utils.verify_model_region_and_return_specs(
901+
model_id="pytorch-eqa-bert-base-cased",
902+
version="*",
903+
scope=JumpStartScriptScope.INFERENCE.value,
904+
region="us-west-2",
905+
)
906+
assert deprecated_message == str(e.value.message)
907+
908+
deprecate_warn_message = "warn-msg"
909+
910+
def make_deprecated_warning_message_spec(*largs, **kwargs):
911+
spec = get_spec_from_base_spec(*largs, **kwargs)
912+
spec.deprecate_warn_message = deprecate_warn_message
913+
return spec
914+
915+
patched_get_model_specs.side_effect = make_deprecated_warning_message_spec
916+
917+
with patch("logging.Logger.warning") as mocked_warning_log:
918+
assert (
919+
utils.verify_model_region_and_return_specs(
920+
model_id="pytorch-eqa-bert-base-cased",
921+
version="*",
922+
scope=JumpStartScriptScope.INFERENCE.value,
923+
region="us-west-2",
924+
tolerate_deprecated_model=True,
925+
)
926+
is not None
927+
)
928+
mocked_warning_log.assert_called_once_with(
929+
deprecate_warn_message,
930+
)
931+
890932

891933
def test_get_jumpstart_base_name_if_jumpstart_model():
892934
uris = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]

0 commit comments

Comments
 (0)