Skip to content

Commit bb683d0

Browse files
author
Ashwin Krishna
committed
fix: SDK defaults formatting changes
Making changes wrt formatting to fix tox issues
1 parent 1ed2d18 commit bb683d0

File tree

5 files changed

+441
-208
lines changed

5 files changed

+441
-208
lines changed

src/sagemaker/config/config_schema.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,20 @@ def _simple_path(*args: str):
171171
)
172172
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
173173
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
174-
SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION
174+
SAGEMAKER,
175+
MONITORING_SCHEDULE,
176+
MONITORING_SCHEDULE_CONFIG,
177+
MONITORING_JOB_DEFINITION,
175178
)
176179
MONITORING_JOB_ENVIRONMENT_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ENVIRONMENT)
177180
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path(
178181
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID
179182
)
180183
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path(
181-
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID
184+
MONITORING_JOB_DEFINITION_PREFIX,
185+
MONITORING_RESOURCES,
186+
CLUSTER_CONFIG,
187+
VOLUME_KMS_KEY_ID,
182188
)
183189
MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG)
184190
MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(
@@ -292,7 +298,11 @@ def _simple_path(*args: str):
292298
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
293299
)
294300
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
295-
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
301+
SAGEMAKER,
302+
PYTHON_SDK,
303+
MODULES,
304+
REMOTE_FUNCTION,
305+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
296306
)
297307
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
298308
SAGEMAKER,
@@ -472,7 +482,11 @@ def _simple_path(*args: str):
472482
"maxProperties": 48,
473483
},
474484
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
475-
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
485+
"s3Uri": {
486+
TYPE: "string",
487+
"pattern": "^(https|s3)://([^/]+)/?(.*)$",
488+
"maxLength": 1024,
489+
},
476490
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
477491
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
478492
# Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
@@ -755,9 +769,7 @@ def _simple_path(*args: str):
755769
ENDPOINT: {
756770
TYPE: OBJECT,
757771
ADDITIONAL_PROPERTIES: False,
758-
PROPERTIES: {
759-
TAGS: {"$ref": "#/definitions/tags"}
760-
}
772+
PROPERTIES: {TAGS: {"$ref": "#/definitions/tags"}},
761773
},
762774
# Endpoint Config
763775
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html

src/sagemaker/model.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@
5050
from sagemaker.predictor import PredictorBase
5151
from sagemaker.serverless import ServerlessInferenceConfig
5252
from sagemaker.transformer import Transformer
53-
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
53+
from sagemaker.jumpstart.utils import (
54+
add_jumpstart_tags,
55+
get_jumpstart_base_name_if_jumpstart_model,
56+
)
5457
from sagemaker.utils import (
5558
unique_name_from_base,
5659
update_container_with_inference_params,
@@ -63,15 +66,25 @@
6366
from sagemaker.workflow import is_pipeline_variable
6467
from sagemaker.workflow.entities import PipelineVariable
6568
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
66-
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin
69+
from sagemaker.inference_recommender.inference_recommender_mixin import (
70+
InferenceRecommenderMixin,
71+
)
6772

6873
LOGGER = logging.getLogger("sagemaker")
6974

7075
NEO_ALLOWED_FRAMEWORKS = set(
7176
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
7277
)
7378

74-
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
79+
NEO_IOC_TARGET_DEVICES = [
80+
"ml_c4",
81+
"ml_c5",
82+
"ml_m4",
83+
"ml_m5",
84+
"ml_p2",
85+
"ml_p3",
86+
"ml_g4dn",
87+
]
7588

7689
NEO_MULTIVERSION_UNSUPPORTED = [
7790
"imx8mplus",
@@ -300,7 +313,9 @@ def __init__(
300313
self._base_name = None
301314
self.sagemaker_session = sagemaker_session
302315
self.role = resolve_value_from_config(
303-
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
316+
role,
317+
MODEL_EXECUTION_ROLE_ARN_PATH,
318+
sagemaker_session=self.sagemaker_session,
304319
)
305320
self.vpc_config = resolve_value_from_config(
306321
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
@@ -585,7 +600,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
585600
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
586601

587602
bucket, key_prefix = s3.determine_bucket_and_prefix(
588-
bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session
603+
bucket=self.bucket,
604+
key_prefix=key_prefix,
605+
sagemaker_session=self.sagemaker_session,
589606
)
590607

591608
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
@@ -633,7 +650,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
633650
else:
634651
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
635652
self.uploaded_code = fw_utils.UploadedCode(
636-
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
653+
s3_prefix=repacked_model_data,
654+
script_name=os.path.basename(self.entry_point),
637655
)
638656

639657
LOGGER.info(
@@ -693,7 +711,11 @@ def enable_network_isolation(self):
693711
return False if not self._enable_network_isolation else self._enable_network_isolation
694712

695713
def _create_sagemaker_model(
696-
self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None
714+
self,
715+
instance_type=None,
716+
accelerator_type=None,
717+
tags=None,
718+
serverless_inference_config=None,
697719
):
698720
"""Create a SageMaker Model Entity
699721
@@ -734,10 +756,14 @@ def _create_sagemaker_model(
734756
self._init_sagemaker_session_if_does_not_exist(instance_type)
735757
# Depending on the instance type, a local session (or) a session is initialized.
736758
self.role = resolve_value_from_config(
737-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
759+
self.role,
760+
MODEL_EXECUTION_ROLE_ARN_PATH,
761+
sagemaker_session=self.sagemaker_session,
738762
)
739763
self.vpc_config = resolve_value_from_config(
740-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
764+
self.vpc_config,
765+
MODEL_VPC_CONFIG_PATH,
766+
sagemaker_session=self.sagemaker_session,
741767
)
742768
self._enable_network_isolation = resolve_value_from_config(
743769
self._enable_network_isolation,
@@ -955,12 +981,16 @@ def package_for_edge(
955981
job_name = f"packaging{self._compilation_job_name[11:]}"
956982
self._init_sagemaker_session_if_does_not_exist(None)
957983
s3_kms_key = resolve_value_from_config(
958-
s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
984+
s3_kms_key,
985+
EDGE_PACKAGING_KMS_KEY_ID_PATH,
986+
sagemaker_session=self.sagemaker_session,
959987
)
960988
role = resolve_value_from_config(
961989
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
962990
)
963-
resource_key = resolve_value_from_config(resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self)
991+
resource_key = resolve_value_from_config(
992+
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self
993+
)
964994
if role is not None:
965995
role = self.sagemaker_session.expand_role(role)
966996
config = self._edge_packaging_job_config(
@@ -1066,7 +1096,9 @@ def compile(
10661096

10671097
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
10681098
role = resolve_value_from_config(
1069-
role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1099+
role,
1100+
COMPILATION_JOB_ROLE_ARN_PATH,
1101+
sagemaker_session=self.sagemaker_session,
10701102
)
10711103
if not role:
10721104
# Originally IAM role was a required parameter.
@@ -1231,10 +1263,14 @@ def deploy(
12311263
self._init_sagemaker_session_if_does_not_exist(instance_type)
12321264
# Depending on the instance type, a local session (or) a session is initialized.
12331265
self.role = resolve_value_from_config(
1234-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1266+
self.role,
1267+
MODEL_EXECUTION_ROLE_ARN_PATH,
1268+
sagemaker_session=self.sagemaker_session,
12351269
)
12361270
self.vpc_config = resolve_value_from_config(
1237-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
1271+
self.vpc_config,
1272+
MODEL_VPC_CONFIG_PATH,
1273+
sagemaker_session=self.sagemaker_session,
12381274
)
12391275
self._enable_network_isolation = resolve_value_from_config(
12401276
self._enable_network_isolation,
@@ -1243,7 +1279,9 @@ def deploy(
12431279
)
12441280

12451281
tags = add_jumpstart_tags(
1246-
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
1282+
tags=tags,
1283+
inference_model_uri=self.model_data,
1284+
inference_script_uri=self.source_dir,
12471285
)
12481286

12491287
if self.role is None:
@@ -1291,7 +1329,9 @@ def deploy(
12911329
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
12921330
if self._is_compiled_model and not is_serverless:
12931331
self._ensure_base_name_if_needed(
1294-
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
1332+
image_uri=self.image_uri,
1333+
script_uri=self.source_dir,
1334+
model_uri=self.model_data,
12951335
)
12961336
if self._base_name is not None:
12971337
self._base_name = "-".join((self._base_name, compiled_model_suffix))
@@ -1668,7 +1708,12 @@ class ModelPackage(Model):
16681708
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
16691709

16701710
def __init__(
1671-
self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs
1711+
self,
1712+
role=None,
1713+
model_data=None,
1714+
algorithm_arn=None,
1715+
model_package_arn=None,
1716+
**kwargs,
16721717
):
16731718
"""Initialize a SageMaker ModelPackage.
16741719

0 commit comments

Comments
 (0)