Skip to content

Commit 43e0174

Browse files
beniericknikure
authored andcommitted
black-format
1 parent 821fa78 commit 43e0174

File tree

94 files changed

+608
-525
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

94 files changed

+608
-525
lines changed

src/sagemaker/amazon/record_pb2.py

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/sagemaker/automl/automl.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,9 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
930930

931931
auto_ml_model_deploy_config = {}
932932
if auto_ml.auto_generate_endpoint_name is not None:
933-
auto_ml_model_deploy_config[
934-
"AutoGenerateEndpointName"
935-
] = auto_ml.auto_generate_endpoint_name
933+
auto_ml_model_deploy_config["AutoGenerateEndpointName"] = (
934+
auto_ml.auto_generate_endpoint_name
935+
)
936936
if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None:
937937
auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name
938938

@@ -1034,9 +1034,9 @@ def _prepare_auto_ml_stop_condition(
10341034
if max_candidates is not None:
10351035
stopping_condition["MaxCandidates"] = max_candidates
10361036
if max_runtime_per_training_job_in_seconds is not None:
1037-
stopping_condition[
1038-
"MaxRuntimePerTrainingJobInSeconds"
1039-
] = max_runtime_per_training_job_in_seconds
1037+
stopping_condition["MaxRuntimePerTrainingJobInSeconds"] = (
1038+
max_runtime_per_training_job_in_seconds
1039+
)
10401040
if total_job_runtime_in_seconds is not None:
10411041
stopping_condition["MaxAutoMLJobRuntimeInSeconds"] = total_job_runtime_in_seconds
10421042

src/sagemaker/automl/automlv2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,9 +1446,9 @@ def _load_config(cls, inputs, auto_ml, expand_role=True):
14461446

14471447
auto_ml_model_deploy_config = {}
14481448
if auto_ml.auto_generate_endpoint_name is not None:
1449-
auto_ml_model_deploy_config[
1450-
"AutoGenerateEndpointName"
1451-
] = auto_ml.auto_generate_endpoint_name
1449+
auto_ml_model_deploy_config["AutoGenerateEndpointName"] = (
1450+
auto_ml.auto_generate_endpoint_name
1451+
)
14521452
if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None:
14531453
auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name
14541454

src/sagemaker/collection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,11 @@ def _convert_group_resource_response(
377377
{
378378
"Name": collection_name,
379379
"Arn": collection_arn,
380-
"Type": resource_group["Identifier"]["ResourceType"]
381-
if is_model_group
382-
else "Collection",
380+
"Type": (
381+
resource_group["Identifier"]["ResourceType"]
382+
if is_model_group
383+
else "Collection"
384+
),
383385
}
384386
)
385387
return collection_details

src/sagemaker/debugger/profiler_config.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,19 @@ def _to_request_dict(self):
162162
profiler_config_request["DisableProfiler"] = self.disable_profiler
163163

164164
if self.system_monitor_interval_millis is not None:
165-
profiler_config_request[
166-
"ProfilingIntervalInMilliseconds"
167-
] = self.system_monitor_interval_millis
165+
profiler_config_request["ProfilingIntervalInMilliseconds"] = (
166+
self.system_monitor_interval_millis
167+
)
168168

169169
if self.framework_profile_params is not None:
170-
profiler_config_request[
171-
"ProfilingParameters"
172-
] = self.framework_profile_params.profiling_parameters
170+
profiler_config_request["ProfilingParameters"] = (
171+
self.framework_profile_params.profiling_parameters
172+
)
173173

174174
if self.profile_params is not None:
175-
profiler_config_request[
176-
"ProfilingParameters"
177-
] = self.profile_params.profiling_parameters
175+
profiler_config_request["ProfilingParameters"] = (
176+
self.profile_params.profiling_parameters
177+
)
178178

179179
return profiler_config_request
180180

src/sagemaker/djl_inference/model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,7 @@ def _create_estimator(
213213
vpc_config: Optional[
214214
Dict[
215215
str,
216-
List[
217-
str,
218-
],
216+
List[str,],
219217
]
220218
] = None,
221219
volume_kms_key=None,
@@ -820,9 +818,9 @@ def _get_container_env(self):
820818
logger.warning("Ignoring invalid container log level: %s", self.container_log_level)
821819
return self.env
822820

823-
self.env[
824-
"SERVING_OPTS"
825-
] = f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"'
821+
self.env["SERVING_OPTS"] = (
822+
f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"'
823+
)
826824
return self.env
827825

828826

src/sagemaker/estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,9 +2539,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25392539
# which is parsed in execution time
25402540
# This does not check config because the EstimatorBase constuctor already did that check
25412541
if estimator.encrypt_inter_container_traffic:
2542-
train_args[
2543-
"encrypt_inter_container_traffic"
2544-
] = estimator.encrypt_inter_container_traffic
2542+
train_args["encrypt_inter_container_traffic"] = (
2543+
estimator.encrypt_inter_container_traffic
2544+
)
25452545

25462546
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
25472547
train_args["algorithm_arn"] = estimator.algorithm_arn
@@ -2556,9 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25562556
train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()
25572557

25582558
if estimator.tensorboard_output_config:
2559-
train_args[
2560-
"tensorboard_output_config"
2561-
] = estimator.tensorboard_output_config._to_request_dict()
2559+
train_args["tensorboard_output_config"] = (
2560+
estimator.tensorboard_output_config._to_request_dict()
2561+
)
25622562

25632563
cls._add_spot_checkpoint_args(local_mode, estimator, train_args)
25642564

src/sagemaker/experiments/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ def __init__(
220220
trial_component_name=self._trial_component.trial_component_name,
221221
sagemaker_session=sagemaker_session,
222222
artifact_bucket=artifact_bucket,
223-
artifact_prefix=_DEFAULT_ARTIFACT_PREFIX
224-
if artifact_prefix is None
225-
else artifact_prefix,
223+
artifact_prefix=(
224+
_DEFAULT_ARTIFACT_PREFIX if artifact_prefix is None else artifact_prefix
225+
),
226226
)
227227
self._lineage_artifact_tracker = _LineageArtifactTracker(
228228
trial_component_arn=self._trial_component.trial_component_arn,

src/sagemaker/explainer/explainer_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def _to_request_dict(self):
3737
request_dict = {}
3838

3939
if self.clarify_explainer_config:
40-
request_dict[
41-
"ClarifyExplainerConfig"
42-
] = self.clarify_explainer_config._to_request_dict()
40+
request_dict["ClarifyExplainerConfig"] = (
41+
self.clarify_explainer_config._to_request_dict()
42+
)
4343

4444
return request_dict

src/sagemaker/feature_store/feature_processor/lineage/_feature_processor_lineage.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,19 @@ class FeatureProcessorLineageHandler:
115115

116116
def create_lineage(self, tags: Optional[List[Dict[str, str]]] = None) -> None:
117117
"""Create and Update Feature Processor Lineage"""
118-
input_feature_group_contexts: List[
119-
FeatureGroupContexts
120-
] = self._retrieve_input_feature_group_contexts()
118+
input_feature_group_contexts: List[FeatureGroupContexts] = (
119+
self._retrieve_input_feature_group_contexts()
120+
)
121121
output_feature_group_contexts: FeatureGroupContexts = (
122122
self._retrieve_output_feature_group_contexts()
123123
)
124124
input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts()
125-
transformation_code_artifact: Optional[
126-
Artifact
127-
] = S3LineageEntityHandler.create_transformation_code_artifact(
128-
transformation_code=self.transformation_code,
129-
pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"),
130-
sagemaker_session=self.sagemaker_session,
125+
transformation_code_artifact: Optional[Artifact] = (
126+
S3LineageEntityHandler.create_transformation_code_artifact(
127+
transformation_code=self.transformation_code,
128+
pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"),
129+
sagemaker_session=self.sagemaker_session,
130+
)
131131
)
132132
if transformation_code_artifact is not None:
133133
logger.info("Created Transformation Code Artifact: %s", transformation_code_artifact)
@@ -362,40 +362,40 @@ def _update_pipeline_lineage(
362362
current_pipeline_version_context: Context = self._get_pipeline_version_context(
363363
last_update_time=pipeline_context.properties[LAST_UPDATE_TIME]
364364
)
365-
upstream_feature_group_associations: Iterator[
366-
AssociationSummary
367-
] = LineageAssociationHandler.list_upstream_associations(
368-
# pylint: disable=no-member
369-
entity_arn=current_pipeline_version_context.context_arn,
370-
source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE,
371-
sagemaker_session=self.sagemaker_session,
365+
upstream_feature_group_associations: Iterator[AssociationSummary] = (
366+
LineageAssociationHandler.list_upstream_associations(
367+
# pylint: disable=no-member
368+
entity_arn=current_pipeline_version_context.context_arn,
369+
source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE,
370+
sagemaker_session=self.sagemaker_session,
371+
)
372372
)
373373

374-
upstream_raw_data_associations: Iterator[
375-
AssociationSummary
376-
] = LineageAssociationHandler.list_upstream_associations(
377-
# pylint: disable=no-member
378-
entity_arn=current_pipeline_version_context.context_arn,
379-
source_type=DATA_SET,
380-
sagemaker_session=self.sagemaker_session,
374+
upstream_raw_data_associations: Iterator[AssociationSummary] = (
375+
LineageAssociationHandler.list_upstream_associations(
376+
# pylint: disable=no-member
377+
entity_arn=current_pipeline_version_context.context_arn,
378+
source_type=DATA_SET,
379+
sagemaker_session=self.sagemaker_session,
380+
)
381381
)
382382

383-
upstream_transformation_code: Iterator[
384-
AssociationSummary
385-
] = LineageAssociationHandler.list_upstream_associations(
386-
# pylint: disable=no-member
387-
entity_arn=current_pipeline_version_context.context_arn,
388-
source_type=TRANSFORMATION_CODE,
389-
sagemaker_session=self.sagemaker_session,
383+
upstream_transformation_code: Iterator[AssociationSummary] = (
384+
LineageAssociationHandler.list_upstream_associations(
385+
# pylint: disable=no-member
386+
entity_arn=current_pipeline_version_context.context_arn,
387+
source_type=TRANSFORMATION_CODE,
388+
sagemaker_session=self.sagemaker_session,
389+
)
390390
)
391391

392-
downstream_feature_group_associations: Iterator[
393-
AssociationSummary
394-
] = LineageAssociationHandler.list_downstream_associations(
395-
# pylint: disable=no-member
396-
entity_arn=current_pipeline_version_context.context_arn,
397-
destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE,
398-
sagemaker_session=self.sagemaker_session,
392+
downstream_feature_group_associations: Iterator[AssociationSummary] = (
393+
LineageAssociationHandler.list_downstream_associations(
394+
# pylint: disable=no-member
395+
entity_arn=current_pipeline_version_context.context_arn,
396+
destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE,
397+
sagemaker_session=self.sagemaker_session,
398+
)
399399
)
400400

401401
is_upstream_feature_group_equal: bool = self._compare_upstream_feature_groups(
@@ -598,9 +598,9 @@ def _update_last_transformation_code(
598598
last_transformation_code_artifact.properties["state"]
599599
== TRANSFORMATION_CODE_STATUS_ACTIVE
600600
):
601-
last_transformation_code_artifact.properties[
602-
"state"
603-
] = TRANSFORMATION_CODE_STATUS_INACTIVE
601+
last_transformation_code_artifact.properties["state"] = (
602+
TRANSFORMATION_CODE_STATUS_INACTIVE
603+
)
604604
last_transformation_code_artifact.properties["exclusive_end_date"] = self.pipeline[
605605
LAST_MODIFIED_TIME
606606
].strftime("%s")

src/sagemaker/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ def retrieve_pipeline_schedule_artifact(
172172
sagemaker_session=sagemaker_session,
173173
)
174174
pipeline_schedule_artifact.properties["pipeline_name"] = pipeline_schedule.pipeline_name
175-
pipeline_schedule_artifact.properties[
176-
"schedule_expression"
177-
] = pipeline_schedule.schedule_expression
175+
pipeline_schedule_artifact.properties["schedule_expression"] = (
176+
pipeline_schedule.schedule_expression
177+
)
178178
pipeline_schedule_artifact.properties["state"] = pipeline_schedule.state
179179
pipeline_schedule_artifact.properties["start_date"] = pipeline_schedule.start_date
180180
pipeline_schedule_artifact.save()

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,16 @@ def _retrieve_default_environment_variables(
112112

113113
default_environment_variables.update(instance_specific_environment_variables)
114114

115-
retrieve_gated_env_var_for_instance_type: Callable[
116-
[str], Optional[str]
117-
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
118-
model_id=model_id,
119-
model_version=model_version,
120-
region=region,
121-
tolerate_vulnerable_model=tolerate_vulnerable_model,
122-
tolerate_deprecated_model=tolerate_deprecated_model,
123-
sagemaker_session=sagemaker_session,
124-
instance_type=instance_type,
115+
retrieve_gated_env_var_for_instance_type: Callable[[str], Optional[str]] = (
116+
lambda instance_type: _retrieve_gated_model_uri_env_var_value(
117+
model_id=model_id,
118+
model_version=model_version,
119+
region=region,
120+
tolerate_vulnerable_model=tolerate_vulnerable_model,
121+
tolerate_deprecated_model=tolerate_deprecated_model,
122+
sagemaker_session=sagemaker_session,
123+
instance_type=instance_type,
124+
)
125125
)
126126

127127
gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
@@ -213,10 +213,10 @@ def _retrieve_gated_model_uri_env_var_value(
213213
sagemaker_session=sagemaker_session,
214214
)
215215

216-
s3_key: Optional[
217-
str
218-
] = model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301
219-
instance_type
216+
s3_key: Optional[str] = (
217+
model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301
218+
instance_type
219+
)
220220
)
221221
if s3_key is None:
222222
return None

0 commit comments

Comments
 (0)