Skip to content

Commit d87ce71

Browse files
committed
chore: address PR comments, fix formatting
1 parent 6ca6627 commit d87ce71

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

src/sagemaker/environment_variables.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ def retrieve_default(
6161
object, used for SageMaker interactions. If not
6262
specified, one is created using the default AWS configuration
6363
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64-
instance_type (str): An instance type to optionally supply in order to get environment variables
65-
specific for the instance type.
66-
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables.
64+
instance_type (str): An instance type to optionally supply in order to get environment
65+
variables specific for the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
67+
variables.
6768
Returns:
6869
dict: The variables to use for the model.
6970

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ def _retrieve_default_environment_variables(
6161
object, used for SageMaker interactions. If not
6262
specified, one is created using the default AWS configuration
6363
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64-
instance_type (str): An instance type to optionally supply in order to get environment variables
65-
specific for the instance type.
66-
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables.
64+
instance_type (str): An instance type to optionally supply in order to get
65+
environment variables specific for the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve
67+
environment variables.
6768
Returns:
6869
dict: the inference environment variables to use for the model.
6970
"""
@@ -94,15 +95,15 @@ def _retrieve_default_environment_variables(
9495
model_specs, "hosting_instance_type_variants", None
9596
):
9697
default_environment_variables.update(
97-
model_specs.hosting_instance_type_variants.get_instance_specific_environment_variables(
98+
model_specs.hosting_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
9899
instance_type
99100
)
100101
)
101102
elif script == JumpStartScriptScope.TRAINING and getattr(
102103
model_specs, "training_instance_type_variants", None
103104
):
104105
default_environment_variables.update(
105-
model_specs.training_instance_type_variants.get_instance_specific_environment_variables(
106+
model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
106107
instance_type
107108
)
108109
)

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,14 @@ def to_json(self) -> Dict[str, Any]:
349349
def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
350350
"""Returns instance specific environment variables.
351351
352-
Not all models and images have instance specific environment variables.
352+
Returns empty dict if a model, instance type tuple does not have specific
353+
environment variables.
353354
"""
354355

355356
if self.variants is None:
356357
return {}
357358

358-
instance_specific_environment_variables: dict = (
359+
instance_specific_environment_variables: Dict[str, str] = (
359360
self.variants.get(instance_type, {})
360361
.get("properties", {})
361362
.get("environment_variables", {})

0 commit comments

Comments
 (0)