Skip to content

feat: Adding scan and tagging utility #4499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
73ce401
fix: Adding new utils
chrstfu Mar 12, 2024
fbd233f
Merge branch 'curated_hub_tagris' into curated_hub_tagris_copy
chrstfu Mar 12, 2024
6d43cbe
feat: Adding Curated Hub scanning feature
chrstfu Mar 12, 2024
2ec0ca5
fix: Refactoring
chrstfu Mar 13, 2024
c1519b8
fix: Removing test values
chrstfu Mar 13, 2024
c2d4a17
fix: Refactoring
chrstfu Mar 13, 2024
48d6325
fix: removing lru cache temporarily
chrstfu Mar 13, 2024
0947840
fix: Adding unit tests
chrstfu Mar 13, 2024
5f8119b
fix: renaming
chrstfu Mar 13, 2024
395e99c
fix: Initial refactorings
chrstfu Mar 14, 2024
c957e71
fix: Adding more alterations
chrstfu Mar 15, 2024
453bed4
fix: Addressing unit tests
chrstfu Mar 15, 2024
9e703a9
fix: Adding more unittests
chrstfu Mar 15, 2024
7cda304
fix: Add tests
chrstfu Mar 15, 2024
9dbe3d6
fix: Adding list to scan input
chrstfu Mar 15, 2024
554dd20
fix: typo
chrstfu Mar 15, 2024
9255849
fix: Addressing naming comments
chrstfu Mar 15, 2024
6d5f599
fix: changing from string
chrstfu Mar 15, 2024
77d7b50
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 18, 2024
0e5d5b4
fix: formatter
chrstfu Mar 18, 2024
4b82389
fix: linter
chrstfu Mar 18, 2024
42efbaa
fix: linter
chrstfu Mar 18, 2024
a73b5b2
fix: linters
chrstfu Mar 18, 2024
4252d86
fix: linter
chrstfu Mar 18, 2024
0deb20e
fix: linting
chrstfu Mar 18, 2024
d21fd41
fix: more linting :/
chrstfu Mar 18, 2024
6b107a2
fix: linting
chrstfu Mar 18, 2024
29a0740
fix: more linting :/
chrstfu Mar 19, 2024
054a8cf
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 19, 2024
bddf5e3
fix: linting
chrstfu Mar 19, 2024
bd02739
fix: tests
chrstfu Mar 19, 2024
b33a63b
fix: Adding __init__.py
chrstfu Mar 19, 2024
e1da9fc
fix: linting
chrstfu Mar 19, 2024
4c829c9
fix: linting
chrstfu Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def _retrieve_default_environment_variables(
)

if gated_model_env_var is None and model_specs.is_gated_model():

possible_env_vars: Set[str] = {
retrieve_gated_env_var_for_instance_type(instance_type)
for instance_type in model_specs.supported_training_instance_types
Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def _retrieve_model_package_arn(
)

if scope == JumpStartScriptScope.INFERENCE:

instance_specific_arn: Optional[str] = (
model_specs.hosting_instance_type_variants.get_model_package_arn(
region=region, instance_type=instance_type
Expand Down Expand Up @@ -155,7 +154,6 @@ def _retrieve_model_package_model_artifact_s3_uri(
"""

if scope == JumpStartScriptScope.TRAINING:

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def _retrieve_model_uri(
model_artifact_key: str

if model_scope == JumpStartScriptScope.INFERENCE:

is_prepacked = not model_specs.use_inference_script_uri()

model_artifact_key = (
Expand All @@ -159,7 +158,6 @@ def _retrieve_model_uri(
)

elif model_scope == JumpStartScriptScope.TRAINING:

model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)

default_jumpstart_bucket: str = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def _retrieve_default_resources(
}

if is_dynamic_container_deployment_supported:

all_resource_requirement_kwargs = {}

for (
Expand Down
96 changes: 37 additions & 59 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,7 @@ def set_manifest_file_s3_key(
}
property_name = file_mapping.get(file_type)
if not property_name:
raise ValueError(
self._file_type_error_msg(file_type, manifest_only=True)
)
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
if key != property_name:
setattr(self, property_name, key)
self.clear()
Expand All @@ -192,9 +190,7 @@ def get_manifest_file_s3_key(
return self._manifest_file_s3_key
if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST:
return self._proprietary_manifest_s3_key
raise ValueError(
self._file_type_error_msg(file_type, manifest_only=True)
)
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))

def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
"""Set s3 bucket used for cache."""
Expand Down Expand Up @@ -247,7 +243,8 @@ def _model_id_retrieval_function(
sm_version = utils.get_sagemaker_version()
manifest = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content

versions_compatible_with_sagemaker = [
Expand All @@ -264,7 +261,8 @@ def _model_id_retrieval_function(
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

versions_incompatible_with_sagemaker = [
Version(header.version) for header in manifest.values() # type: ignore
Version(header.version)
for header in manifest.values() # type: ignore
if header.model_id == model_id
]
sm_incompatible_model_version = self._select_version(
Expand Down Expand Up @@ -294,9 +292,7 @@ def _model_id_retrieval_function(
raise KeyError(error_msg)

error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
error_msg += (
f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
)
error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "

other_model_id_version = None
if model_type == JumpStartModelType.OPEN_WEIGHTS:
Expand All @@ -305,19 +301,17 @@ def _model_id_retrieval_function(
) # all versions here are incompatible with sagemaker
elif model_type == JumpStartModelType.PROPRIETARY:
all_possible_model_id_version = [
header.version for header in manifest.values() # type: ignore
header.version
for header in manifest.values() # type: ignore
if header.model_id == model_id
]
other_model_id_version = (
None
if not all_possible_model_id_version
else all_possible_model_id_version[0]
None if not all_possible_model_id_version else all_possible_model_id_version[0]
)

if other_model_id_version is not None:
error_msg += (
f"Consider using model ID '{model_id}' with version "
f"'{other_model_id_version}'."
f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'."
)
else:
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
Expand Down Expand Up @@ -359,15 +353,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],

def _is_local_metadata_mode(self) -> bool:
"""Returns True if the cache should use local metadata mode, based off env variables."""
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))
return (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])
)

def _get_json_file(
self,
key: str,
filetype: JumpStartS3FileType
self, key: str, filetype: JumpStartS3FileType
) -> Tuple[Union[dict, list], Optional[str]]:
"""Returns json file either from s3 or local file system.

Expand All @@ -391,21 +385,19 @@ def _get_json_md5_hash(self, key: str):
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]

def _get_json_file_from_local_override(
self,
key: str,
filetype: JumpStartS3FileType
self, key: str, filetype: JumpStartS3FileType
) -> Union[dict, list]:
"""Reads json file from local filesystem and returns data."""
if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
metadata_local_root = (
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
)
metadata_local_root = os.environ[
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
]
elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
else:
raise ValueError(f"Unsupported file type for local override: {filetype}")
file_path = os.path.join(metadata_local_root, key)
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = json.load(f)
return data

Expand Down Expand Up @@ -450,9 +442,7 @@ def _retrieval_function(
formatted_body, _ = self._get_json_file(id_info, data_type)
model_specs = JumpStartModelSpecs(formatted_body)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedContentValue(
formatted_content=model_specs
)
return JumpStartCachedContentValue(formatted_content=model_specs)

if data_type == HubContentType.MODEL:
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
Expand All @@ -462,21 +452,15 @@ def _retrieval_function(
hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=data_type
hub_content_type=data_type,
)

model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
self.get_region(),
self._s3_client
)
return JumpStartCachedContentValue(
formatted_content=model_specs
)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedContentValue(formatted_content=model_specs)

if data_type == HubType.HUB:
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
Expand All @@ -486,9 +470,7 @@ def _retrieval_function(
formatted_content=DescribeHubResponse(hub_description)
)

raise ValueError(
self._file_type_error_msg(data_type)
)
raise ValueError(self._file_type_error_msg(data_type))

def get_manifest(
self,
Expand All @@ -497,7 +479,8 @@ def get_manifest(
"""Return entire JumpStart models manifest."""
manifest_dict = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content
manifest = list(manifest_dict.values()) # type: ignore
return manifest
Expand Down Expand Up @@ -554,16 +537,14 @@ def _select_version(
except InvalidSpecifier:
raise KeyError(f"Bad semantic version: {version_str}")
available_versions_filtered = list(spec.filter(available_versions))
return (
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
)
return str(max(available_versions_filtered)) if available_versions_filtered != [] else None

def _get_header_impl(
self,
model_id: str,
semantic_version_str: str,
attempt: int = 0,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
"""Lower-level function to return header.

Expand All @@ -586,7 +567,8 @@ def _get_header_impl(

manifest = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content

try:
Expand All @@ -602,7 +584,7 @@ def get_specs(
self,
model_id: str,
version_str: str,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelSpecs:
"""Return specs for a given JumpStart model ID and semantic version.

Expand All @@ -615,16 +597,12 @@ def get_specs(
header = self.get_header(model_id, version_str, model_type)
spec_key = header.spec_key
specs, cache_hit = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key
)
JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
)

if not cache_hit and "*" in version_str:
JUMPSTART_LOGGER.warning(
get_wildcard_model_version_msg(
header.model_id, version_str, header.version
)
get_wildcard_model_version_msg(header.model_id, version_str, header.version)
)
return specs.formatted_content

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
HubContentDependencyType,
S3ObjectLocation,
)
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import (
PublicModelDataAccessor,
)
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
from sagemaker.jumpstart.types import JumpStartModelSpecs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
model_specs: JumpStartModelSpecs,
studio_specs: Dict[str, Dict[str, Any]],
):
"""Creates a PublicModelDataAccessor."""
self._region = region
self._bucket = (
get_jumpstart_gated_content_bucket(region)
Expand Down
Loading