Skip to content

Commit 6471a43

Browse files
committed
fix: Adding more fields
1 parent 19c3a36 commit 6471a43

File tree

5 files changed

+25
-25
lines changed

5 files changed

+25
-25
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
489489
)
490490
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
491491

492-
if hasattr(kwargs.specs, "capabilities"):
492+
if kwargs.specs.capabilities is not None:
493493
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
494494
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
495495

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -551,17 +551,17 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
551551
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
552552
"""
553553
self.url: str = json_obj["Url"]
554-
self.min_sdk_version: str = json_obj["MinSdkVersion"]
555-
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
556-
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
557-
self.hosting_script_uri = json_obj["HostingScriptUri"]
558-
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
554+
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
555+
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
556+
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
557+
self.hosting_script_uri = json_obj.get("HostingScriptUri")
558+
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
559559
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
560560
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
561-
for env_variable in json_obj["InferenceEnvironmentVariables"]
561+
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
562562
]
563-
self.training_supported: bool = bool(json_obj["TrainingSupported"])
564-
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
563+
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
564+
self.incremental_training_supported: bool = bool(json_obj.get("IncrementalTrainingSupported"))
565565
self.dynamic_container_deployment_supported: Optional[bool] = (
566566
bool(json_obj.get("DynamicContainerDeploymentSupported"))
567567
if json_obj.get("DynamicContainerDeploymentSupported")

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@
2020

2121
def camel_to_snake(camel_case_string: str) -> str:
2222
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24-
if "-" in snake_case_string:
25-
# remove any hyphen from the string for accurate conversion.
26-
snake_case_string = snake_case_string.replace("-", "")
27-
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
23+
return re.sub(r'(?<!^)(?=[A-Z])', '_', camel_case_string).lower()
2824

2925

3026
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,19 @@ def make_model_specs_from_describe_hub_content_response(
146146
specs["inference_config_components"] = hub_model_document.inference_config_components
147147
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
148148

149-
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
150-
hub_model_document.hosting_artifact_uri
151-
)
152-
specs["hosting_artifact_key"] = hosting_artifact_key
153-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
154-
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
155-
hub_model_document.hosting_script_uri
156-
)
157-
specs["hosting_script_key"] = hosting_script_key
149+
if hub_model_document.hosting_artifact_uri:
150+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
151+
hub_model_document.hosting_artifact_uri
152+
)
153+
specs["hosting_artifact_key"] = hosting_artifact_key
154+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
155+
156+
if hub_model_document.hosting_script_uri:
157+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
158+
hub_model_document.hosting_script_uri
159+
)
160+
specs["hosting_script_key"] = hosting_script_key
161+
158162
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
159163
specs["inference_vulnerable"] = False
160164
specs["inference_dependencies"] = hub_model_document.inference_dependencies

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def generate_hub_arn_for_init_kwargs(
117117

118118
hub_arn = None
119119
if hub_name:
120-
if hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
121-
return None
120+
# if hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
121+
# return None
122122
match = re.match(constants.HUB_ARN_REGEX, hub_name)
123123
if match:
124124
hub_arn = hub_name

0 commit comments

Comments
 (0)