Skip to content

Commit 72b00b6

Browse files
author
Ashwin Krishna
committed
fix: SDK defaults formatting changes
Making changes wrt formatting to fix tox issues
1 parent ebd7872 commit 72b00b6

File tree

5 files changed

+441
-208
lines changed

5 files changed

+441
-208
lines changed

src/sagemaker/config/config_schema.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,20 @@ def _simple_path(*args: str):
171171
)
172172
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
173173
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
174-
SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION
174+
SAGEMAKER,
175+
MONITORING_SCHEDULE,
176+
MONITORING_SCHEDULE_CONFIG,
177+
MONITORING_JOB_DEFINITION,
175178
)
176179
MONITORING_JOB_ENVIRONMENT_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ENVIRONMENT)
177180
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path(
178181
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID
179182
)
180183
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path(
181-
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID
184+
MONITORING_JOB_DEFINITION_PREFIX,
185+
MONITORING_RESOURCES,
186+
CLUSTER_CONFIG,
187+
VOLUME_KMS_KEY_ID,
182188
)
183189
MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG)
184190
MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(
@@ -292,7 +298,11 @@ def _simple_path(*args: str):
292298
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
293299
)
294300
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
295-
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
301+
SAGEMAKER,
302+
PYTHON_SDK,
303+
MODULES,
304+
REMOTE_FUNCTION,
305+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
296306
)
297307
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
298308
SAGEMAKER,
@@ -472,7 +482,11 @@ def _simple_path(*args: str):
472482
"maxProperties": 48,
473483
},
474484
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
475-
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
485+
"s3Uri": {
486+
TYPE: "string",
487+
"pattern": "^(https|s3)://([^/]+)/?(.*)$",
488+
"maxLength": 1024,
489+
},
476490
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
477491
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
478492
# Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
@@ -755,9 +769,7 @@ def _simple_path(*args: str):
755769
ENDPOINT: {
756770
TYPE: OBJECT,
757771
ADDITIONAL_PROPERTIES: False,
758-
PROPERTIES: {
759-
TAGS: {"$ref": "#/definitions/tags"}
760-
}
772+
PROPERTIES: {TAGS: {"$ref": "#/definitions/tags"}},
761773
},
762774
# Endpoint Config
763775
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html

src/sagemaker/model.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@
5151
from sagemaker.predictor import PredictorBase
5252
from sagemaker.serverless import ServerlessInferenceConfig
5353
from sagemaker.transformer import Transformer
54-
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
54+
from sagemaker.jumpstart.utils import (
55+
add_jumpstart_tags,
56+
get_jumpstart_base_name_if_jumpstart_model,
57+
)
5558
from sagemaker.utils import (
5659
unique_name_from_base,
5760
update_container_with_inference_params,
@@ -64,15 +67,25 @@
6467
from sagemaker.workflow import is_pipeline_variable
6568
from sagemaker.workflow.entities import PipelineVariable
6669
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
67-
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin
70+
from sagemaker.inference_recommender.inference_recommender_mixin import (
71+
InferenceRecommenderMixin,
72+
)
6873

6974
LOGGER = logging.getLogger("sagemaker")
7075

7176
NEO_ALLOWED_FRAMEWORKS = set(
7277
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
7378
)
7479

75-
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
80+
NEO_IOC_TARGET_DEVICES = [
81+
"ml_c4",
82+
"ml_c5",
83+
"ml_m4",
84+
"ml_m5",
85+
"ml_p2",
86+
"ml_p3",
87+
"ml_g4dn",
88+
]
7689

7790
NEO_MULTIVERSION_UNSUPPORTED = [
7891
"imx8mplus",
@@ -301,7 +314,9 @@ def __init__(
301314
self._base_name = None
302315
self.sagemaker_session = sagemaker_session
303316
self.role = resolve_value_from_config(
304-
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
317+
role,
318+
MODEL_EXECUTION_ROLE_ARN_PATH,
319+
sagemaker_session=self.sagemaker_session,
305320
)
306321
self.vpc_config = resolve_value_from_config(
307322
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
@@ -586,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
586601
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
587602

588603
bucket, key_prefix = s3.determine_bucket_and_prefix(
589-
bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session
604+
bucket=self.bucket,
605+
key_prefix=key_prefix,
606+
sagemaker_session=self.sagemaker_session,
590607
)
591608

592609
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
@@ -634,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
634651
else:
635652
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
636653
self.uploaded_code = fw_utils.UploadedCode(
637-
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
654+
s3_prefix=repacked_model_data,
655+
script_name=os.path.basename(self.entry_point),
638656
)
639657

640658
LOGGER.info(
@@ -694,7 +712,11 @@ def enable_network_isolation(self):
694712
return False if not self._enable_network_isolation else self._enable_network_isolation
695713

696714
def _create_sagemaker_model(
697-
self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None
715+
self,
716+
instance_type=None,
717+
accelerator_type=None,
718+
tags=None,
719+
serverless_inference_config=None,
698720
):
699721
"""Create a SageMaker Model Entity
700722
@@ -735,10 +757,14 @@ def _create_sagemaker_model(
735757
self._init_sagemaker_session_if_does_not_exist(instance_type)
736758
# Depending on the instance type, a local session (or) a session is initialized.
737759
self.role = resolve_value_from_config(
738-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
760+
self.role,
761+
MODEL_EXECUTION_ROLE_ARN_PATH,
762+
sagemaker_session=self.sagemaker_session,
739763
)
740764
self.vpc_config = resolve_value_from_config(
741-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
765+
self.vpc_config,
766+
MODEL_VPC_CONFIG_PATH,
767+
sagemaker_session=self.sagemaker_session,
742768
)
743769
self._enable_network_isolation = resolve_value_from_config(
744770
self._enable_network_isolation,
@@ -956,12 +982,16 @@ def package_for_edge(
956982
job_name = f"packaging{self._compilation_job_name[11:]}"
957983
self._init_sagemaker_session_if_does_not_exist(None)
958984
s3_kms_key = resolve_value_from_config(
959-
s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
985+
s3_kms_key,
986+
EDGE_PACKAGING_KMS_KEY_ID_PATH,
987+
sagemaker_session=self.sagemaker_session,
960988
)
961989
role = resolve_value_from_config(
962990
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
963991
)
964-
resource_key = resolve_value_from_config(resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self)
992+
resource_key = resolve_value_from_config(
993+
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self
994+
)
965995
if role is not None:
966996
role = self.sagemaker_session.expand_role(role)
967997
config = self._edge_packaging_job_config(
@@ -1067,7 +1097,9 @@ def compile(
10671097

10681098
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
10691099
role = resolve_value_from_config(
1070-
role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1100+
role,
1101+
COMPILATION_JOB_ROLE_ARN_PATH,
1102+
sagemaker_session=self.sagemaker_session,
10711103
)
10721104
if not role:
10731105
# Originally IAM role was a required parameter.
@@ -1234,10 +1266,14 @@ def deploy(
12341266
self._init_sagemaker_session_if_does_not_exist(instance_type)
12351267
# Depending on the instance type, a local session (or) a session is initialized.
12361268
self.role = resolve_value_from_config(
1237-
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
1269+
self.role,
1270+
MODEL_EXECUTION_ROLE_ARN_PATH,
1271+
sagemaker_session=self.sagemaker_session,
12381272
)
12391273
self.vpc_config = resolve_value_from_config(
1240-
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
1274+
self.vpc_config,
1275+
MODEL_VPC_CONFIG_PATH,
1276+
sagemaker_session=self.sagemaker_session,
12411277
)
12421278
self._enable_network_isolation = resolve_value_from_config(
12431279
self._enable_network_isolation,
@@ -1246,7 +1282,9 @@ def deploy(
12461282
)
12471283

12481284
tags = add_jumpstart_tags(
1249-
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
1285+
tags=tags,
1286+
inference_model_uri=self.model_data,
1287+
inference_script_uri=self.source_dir,
12501288
)
12511289

12521290
if self.role is None:
@@ -1294,7 +1332,9 @@ def deploy(
12941332
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
12951333
if self._is_compiled_model and not is_serverless:
12961334
self._ensure_base_name_if_needed(
1297-
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
1335+
image_uri=self.image_uri,
1336+
script_uri=self.source_dir,
1337+
model_uri=self.model_data,
12981338
)
12991339
if self._base_name is not None:
13001340
self._base_name = "-".join((self._base_name, compiled_model_suffix))
@@ -1675,7 +1715,12 @@ class ModelPackage(Model):
16751715
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
16761716

16771717
def __init__(
1678-
self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs
1718+
self,
1719+
role=None,
1720+
model_data=None,
1721+
algorithm_arn=None,
1722+
model_package_arn=None,
1723+
**kwargs,
16791724
):
16801725
"""Initialize a SageMaker ModelPackage.
16811726

0 commit comments

Comments
 (0)