Skip to content

chore: jumpstart deprecation messages #3992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
"training_dependencies",
"training_vulnerabilities",
"deprecated",
"deprecated_message",
"deprecate_warn_message",
"default_inference_instance_type",
"supported_inference_instance_types",
"default_training_instance_type",
Expand Down Expand Up @@ -389,6 +391,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.training_dependencies: List[str] = json_obj["training_dependencies"]
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
self.deprecated: bool = bool(json_obj["deprecated"])
self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
self.default_inference_instance_type: Optional[str] = json_obj.get(
"default_inference_instance_type"
)
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,14 @@ def verify_model_region_and_return_specs(

if model_specs.deprecated:
if not tolerate_deprecated_model:
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
raise DeprecatedJumpStartModelError(
model_id=model_id, version=version, message=model_specs.deprecated_message
)
LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)

if model_specs.deprecate_warn_message:
LOGGER.warning(model_specs.deprecate_warn_message)

if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
if not tolerate_vulnerable_model:
raise VulnerableJumpStartModelError(
Expand Down
16 changes: 12 additions & 4 deletions src/sagemaker/jumpstart/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
from typing import Any, Dict, List, Optional
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME

from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
from sagemaker.jumpstart import accessors as jumpstart_accessors
from sagemaker.jumpstart.enums import (
HyperparameterValidationMode,
JumpStartScriptScope,
VariableScope,
VariableTypes,
)
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
from sagemaker.jumpstart.types import JumpStartHyperparameter
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs


def _validate_hyperparameter(
Expand Down Expand Up @@ -190,8 +195,11 @@ def validate_hyperparameters(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
region=region,
scope=JumpStartScriptScope.TRAINING,
)
hyperparameters_specs = model_specs.hyperparameters

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,8 @@
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
"training_prepacked_script_key": None,
"hosting_prepacked_artifact_key": None,
"deprecate_warn_message": None,
"deprecated_message": None,
"hosting_model_package_arns": None,
"hosting_eula_key": None,
"hyperparameters": [
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,49 @@ def make_deprecated_spec(*largs, **kwargs):
"*",
)

deprecated_message = "this model is deprecated"

def make_deprecated_message_spec(*largs, **kwargs):
spec = get_spec_from_base_spec(*largs, **kwargs)
spec.deprecated_message = deprecated_message
spec.deprecated = True
return spec

patched_get_model_specs.side_effect = make_deprecated_message_spec

with pytest.raises(DeprecatedJumpStartModelError) as e:
utils.verify_model_region_and_return_specs(
model_id="pytorch-eqa-bert-base-cased",
version="*",
scope=JumpStartScriptScope.INFERENCE.value,
region="us-west-2",
)
assert deprecated_message == str(e.value.message)

deprecate_warn_message = "warn-msg"

def make_deprecated_warning_message_spec(*largs, **kwargs):
spec = get_spec_from_base_spec(*largs, **kwargs)
spec.deprecate_warn_message = deprecate_warn_message
return spec

patched_get_model_specs.side_effect = make_deprecated_warning_message_spec

with patch("logging.Logger.warning") as mocked_warning_log:
assert (
utils.verify_model_region_and_return_specs(
model_id="pytorch-eqa-bert-base-cased",
version="*",
scope=JumpStartScriptScope.INFERENCE.value,
region="us-west-2",
tolerate_deprecated_model=True,
)
is not None
)
mocked_warning_log.assert_called_once_with(
deprecate_warn_message,
)


def test_get_jumpstart_base_name_if_jumpstart_model():
uris = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]
Expand Down