Skip to content

Commit 219ad24

Browse files
authored
chore: jumpstart deprecation messages (#3992)
1 parent fdc0ac1 commit 219ad24

File tree

5 files changed

+67
-5
lines changed

5 files changed

+67
-5
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
334334
"training_dependencies",
335335
"training_vulnerabilities",
336336
"deprecated",
337+
"deprecated_message",
338+
"deprecate_warn_message",
337339
"default_inference_instance_type",
338340
"supported_inference_instance_types",
339341
"default_training_instance_type",
@@ -389,6 +391,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
389391
self.training_dependencies: List[str] = json_obj["training_dependencies"]
390392
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
391393
self.deprecated: bool = bool(json_obj["deprecated"])
394+
self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
395+
self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
392396
self.default_inference_instance_type: Optional[str] = json_obj.get(
393397
"default_inference_instance_type"
394398
)

src/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,14 @@ def verify_model_region_and_return_specs(
415415

416416
if model_specs.deprecated:
417417
if not tolerate_deprecated_model:
418-
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
418+
raise DeprecatedJumpStartModelError(
419+
model_id=model_id, version=version, message=model_specs.deprecated_message
420+
)
419421
LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)
420422

423+
if model_specs.deprecate_warn_message:
424+
LOGGER.warning(model_specs.deprecate_warn_message)
425+
421426
if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
422427
if not tolerate_vulnerable_model:
423428
raise VulnerableJumpStartModelError(

src/sagemaker/jumpstart/validators.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
from typing import Any, Dict, List, Optional
1616
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1717

18-
from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
19-
from sagemaker.jumpstart import accessors as jumpstart_accessors
18+
from sagemaker.jumpstart.enums import (
19+
HyperparameterValidationMode,
20+
JumpStartScriptScope,
21+
VariableScope,
22+
VariableTypes,
23+
)
2024
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
2125
from sagemaker.jumpstart.types import JumpStartHyperparameter
26+
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
2227

2328

2429
def _validate_hyperparameter(
@@ -190,8 +195,11 @@ def validate_hyperparameters(
190195
if region is None:
191196
region = JUMPSTART_DEFAULT_REGION_NAME
192197

193-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
194-
region=region, model_id=model_id, version=model_version
198+
model_specs = verify_model_region_and_return_specs(
199+
model_id=model_id,
200+
version=model_version,
201+
region=region,
202+
scope=JumpStartScriptScope.TRAINING,
195203
)
196204
hyperparameters_specs = model_specs.hyperparameters
197205

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,8 @@
23562356
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23572357
"training_prepacked_script_key": None,
23582358
"hosting_prepacked_artifact_key": None,
2359+
"deprecate_warn_message": None,
2360+
"deprecated_message": None,
23592361
"hosting_model_package_arns": None,
23602362
"hosting_eula_key": None,
23612363
"hyperparameters": [

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,49 @@ def make_deprecated_spec(*largs, **kwargs):
887887
"*",
888888
)
889889

890+
deprecated_message = "this model is deprecated"
891+
892+
def make_deprecated_message_spec(*largs, **kwargs):
893+
spec = get_spec_from_base_spec(*largs, **kwargs)
894+
spec.deprecated_message = deprecated_message
895+
spec.deprecated = True
896+
return spec
897+
898+
patched_get_model_specs.side_effect = make_deprecated_message_spec
899+
900+
with pytest.raises(DeprecatedJumpStartModelError) as e:
901+
utils.verify_model_region_and_return_specs(
902+
model_id="pytorch-eqa-bert-base-cased",
903+
version="*",
904+
scope=JumpStartScriptScope.INFERENCE.value,
905+
region="us-west-2",
906+
)
907+
assert deprecated_message == str(e.value.message)
908+
909+
deprecate_warn_message = "warn-msg"
910+
911+
def make_deprecated_warning_message_spec(*largs, **kwargs):
912+
spec = get_spec_from_base_spec(*largs, **kwargs)
913+
spec.deprecate_warn_message = deprecate_warn_message
914+
return spec
915+
916+
patched_get_model_specs.side_effect = make_deprecated_warning_message_spec
917+
918+
with patch("logging.Logger.warning") as mocked_warning_log:
919+
assert (
920+
utils.verify_model_region_and_return_specs(
921+
model_id="pytorch-eqa-bert-base-cased",
922+
version="*",
923+
scope=JumpStartScriptScope.INFERENCE.value,
924+
region="us-west-2",
925+
tolerate_deprecated_model=True,
926+
)
927+
is not None
928+
)
929+
mocked_warning_log.assert_called_once_with(
930+
deprecate_warn_message,
931+
)
932+
890933

891934
def test_get_jumpstart_base_name_if_jumpstart_model():
892935
uris = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]

0 commit comments

Comments
 (0)