Skip to content

Commit 7c9f870

Browse files
Add optional typecheck for nullable parameters (#4767)
* Add optional typecheck for nullable parameters * Add optional typecheck for nullable parameters
1 parent 330e47a commit 7c9f870

13 files changed

+91
-92
lines changed

src/sagemaker/workflow/automl_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def __init__(
3434
self,
3535
name: str,
3636
step_args: _JobStepArguments,
37-
display_name: str = None,
38-
description: str = None,
39-
cache_config: CacheConfig = None,
37+
display_name: Optional[str] = None,
38+
description: Optional[str] = None,
39+
cache_config: Optional[CacheConfig] = None,
4040
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
41-
retry_policies: List[RetryPolicy] = None,
41+
retry_policies: Optional[List[RetryPolicy]] = None,
4242
):
4343
"""Construct a `AutoMLStep`, given a `AutoML` instance.
4444

src/sagemaker/workflow/callback_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def __init__(
8484
sqs_queue_url: str,
8585
inputs: dict,
8686
outputs: List[CallbackOutput],
87-
display_name: str = None,
88-
description: str = None,
89-
cache_config: CacheConfig = None,
87+
display_name: Optional[str] = None,
88+
description: Optional[str] = None,
89+
cache_config: Optional[CacheConfig] = None,
9090
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9191
):
9292
"""Constructs a CallbackStep.

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def __init__(
159159
skip_check: Union[bool, PipelineVariable] = False,
160160
fail_on_violation: Union[bool, PipelineVariable] = True,
161161
register_new_baseline: Union[bool, PipelineVariable] = False,
162-
model_package_group_name: Union[str, PipelineVariable] = None,
163-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
164-
display_name: str = None,
165-
description: str = None,
166-
cache_config: CacheConfig = None,
162+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
163+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
164+
display_name: Optional[str] = None,
165+
description: Optional[str] = None,
166+
cache_config: Optional[CacheConfig] = None,
167167
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
168168
):
169169
"""Constructs a ClarifyCheckStep.

src/sagemaker/workflow/condition_step.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from sagemaker.workflow.step_collections import StepCollection
2222
from sagemaker.workflow.functions import JsonGet as NewJsonGet
23-
from sagemaker.workflow.step_outputs import StepOutput
2423
from sagemaker.workflow.steps import (
2524
Step,
2625
StepTypeEnum,
@@ -41,11 +40,11 @@ def __init__(
4140
self,
4241
name: str,
4342
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
44-
display_name: str = None,
45-
description: str = None,
46-
conditions: List[Condition] = None,
47-
if_steps: List[Union[Step, StepCollection, StepOutput]] = None,
48-
else_steps: List[Union[Step, StepCollection, StepOutput]] = None,
43+
display_name: Optional[str] = None,
44+
description: Optional[str] = None,
45+
conditions: Optional[List[Condition]] = None,
46+
if_steps: Optional[List[Union[Step, StepCollection]]] = None,
47+
else_steps: Optional[List[Union[Step, StepCollection]]] = None,
4948
):
5049
"""Construct a ConditionStep for pipelines to support conditional branching.
5150

src/sagemaker/workflow/emr_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ def __init__(
161161
cluster_id: str,
162162
step_config: EMRStepConfig,
163163
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
164-
cache_config: CacheConfig = None,
165-
cluster_config: Dict[str, Any] = None,
166-
execution_role_arn: str = None,
164+
cache_config: Optional[CacheConfig] = None,
165+
cluster_config: Optional[Dict[str, Any]] = None,
166+
execution_role_arn: Optional[str] = None,
167167
):
168168
"""Constructs an `EMRStep`.
169169

src/sagemaker/workflow/fail_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ class FailStep(Step):
2929
def __init__(
3030
self,
3131
name: str,
32-
error_message: Union[str, PipelineVariable] = None,
33-
display_name: str = None,
34-
description: str = None,
32+
error_message: Optional[Union[str, PipelineVariable]] = None,
33+
display_name: Optional[str] = None,
34+
description: Optional[str] = None,
3535
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
3636
):
3737
"""Constructs a `FailStep`.

src/sagemaker/workflow/lambda_step.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def __init__(
8484
self,
8585
name: str,
8686
lambda_func: Lambda,
87-
display_name: str = None,
88-
description: str = None,
89-
inputs: dict = None,
90-
outputs: List[LambdaOutput] = None,
91-
cache_config: CacheConfig = None,
87+
display_name: Optional[str] = None,
88+
description: Optional[str] = None,
89+
inputs: Optional[dict] = None,
90+
outputs: Optional[List[LambdaOutput]] = None,
91+
cache_config: Optional[CacheConfig] = None,
9292
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9393
):
9494
"""Constructs a LambdaStep.

src/sagemaker/workflow/monitor_batch_transform_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(
4848
check_job_configuration: CheckJobConfig,
4949
monitor_before_transform: bool = False,
5050
fail_on_violation: Union[bool, PipelineVariable] = True,
51-
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
52-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
51+
supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None,
52+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
5353
display_name: Optional[str] = None,
5454
description: Optional[str] = None,
5555
):

src/sagemaker/workflow/quality_check_step.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ def __init__(
125125
skip_check: Union[bool, PipelineVariable] = False,
126126
fail_on_violation: Union[bool, PipelineVariable] = True,
127127
register_new_baseline: Union[bool, PipelineVariable] = False,
128-
model_package_group_name: Union[str, PipelineVariable] = None,
129-
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
130-
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
131-
display_name: str = None,
132-
description: str = None,
133-
cache_config: CacheConfig = None,
128+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
129+
supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None,
130+
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
131+
display_name: Optional[str] = None,
132+
description: Optional[str] = None,
133+
cache_config: Optional[CacheConfig] = None,
134134
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
135135
):
136136
"""Constructs a QualityCheckStep.

src/sagemaker/workflow/retry.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from enum import Enum
17-
from typing import List
17+
from typing import List, Optional
1818
import attr
1919

2020
from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType
@@ -133,8 +133,8 @@ def __init__(
133133
exception_types: List[StepExceptionTypeEnum],
134134
backoff_rate: float = 2.0,
135135
interval_seconds: int = 1,
136-
max_attempts: int = None,
137-
expire_after_mins: int = None,
136+
max_attempts: Optional[int] = None,
137+
expire_after_mins: Optional[int] = None,
138138
):
139139
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
140140
for exception_type in exception_types:
@@ -177,12 +177,12 @@ class SageMakerJobStepRetryPolicy(RetryPolicy):
177177

178178
def __init__(
179179
self,
180-
exception_types: List[SageMakerJobExceptionTypeEnum] = None,
181-
failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None,
180+
exception_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None,
181+
failure_reason_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None,
182182
backoff_rate: float = 2.0,
183183
interval_seconds: int = 1,
184-
max_attempts: int = None,
185-
expire_after_mins: int = None,
184+
max_attempts: Optional[int] = None,
185+
expire_after_mins: Optional[int] = None,
186186
):
187187
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
188188

src/sagemaker/workflow/selective_execution_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Pipeline Parallelism Configuration"""
1414
from __future__ import absolute_import
15-
from typing import List
15+
from typing import List, Optional
1616
from sagemaker.workflow.entities import RequestType
1717

1818

@@ -25,8 +25,8 @@ class SelectiveExecutionConfig:
2525
def __init__(
2626
self,
2727
selected_steps: List[str],
28-
source_pipeline_execution_arn: str = None,
2928
reference_latest_execution: bool = True,
29+
source_pipeline_execution_arn: Optional[str] = None,
3030
):
3131
"""Create a `SelectiveExecutionConfig`.
3232

src/sagemaker/workflow/step_collections.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def __init__(
7272
response_types,
7373
inference_instances=None,
7474
transform_instances=None,
75-
estimator: EstimatorBase = None,
75+
estimator: Optional[EstimatorBase] = None,
7676
model_data=None,
7777
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
78-
repack_model_step_retry_policies: List[RetryPolicy] = None,
79-
register_model_step_retry_policies: List[RetryPolicy] = None,
78+
repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None,
79+
register_model_step_retry_policies: Optional[List[RetryPolicy]] = None,
8080
model_package_group_name=None,
8181
model_metrics=None,
8282
approval_status=None,
@@ -85,7 +85,7 @@ def __init__(
8585
display_name=None,
8686
description=None,
8787
tags=None,
88-
model: Union[Model, PipelineModel] = None,
88+
model: Optional[Union[Model, PipelineModel]] = None,
8989
drift_check_baselines=None,
9090
customer_metadata_properties=None,
9191
domain=None,
@@ -328,8 +328,8 @@ def __init__(
328328
instance_count,
329329
instance_type,
330330
transform_inputs,
331-
description: str = None,
332-
display_name: str = None,
331+
description: Optional[str] = None,
332+
display_name: Optional[str] = None,
333333
# model arguments
334334
image_uri=None,
335335
predictor_cls=None,
@@ -346,9 +346,9 @@ def __init__(
346346
volume_kms_key=None,
347347
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
348348
# step retry policies
349-
repack_model_step_retry_policies: List[RetryPolicy] = None,
350-
model_step_retry_policies: List[RetryPolicy] = None,
351-
transform_step_retry_policies: List[RetryPolicy] = None,
349+
repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None,
350+
model_step_retry_policies: Optional[List[RetryPolicy]] = None,
351+
transform_step_retry_policies: Optional[List[RetryPolicy]] = None,
352352
**kwargs,
353353
):
354354
"""Construct steps required for a Transformer step collection:

src/sagemaker/workflow/steps.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,10 @@ def __init__(
362362
self,
363363
name: str,
364364
step_type: StepTypeEnum,
365-
display_name: str = None,
366-
description: str = None,
367-
depends_on: Optional[List[Union[str, Step, "StepCollection", StepOutput]]] = None,
368-
retry_policies: List[RetryPolicy] = None,
365+
display_name: Optional[str] = None,
366+
description: Optional[str] = None,
367+
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
368+
retry_policies: Optional[List[RetryPolicy]] = None,
369369
):
370370
super().__init__(
371371
name=name,
@@ -404,14 +404,14 @@ class TrainingStep(ConfigurableRetryStep):
404404
def __init__(
405405
self,
406406
name: str,
407-
step_args: _JobStepArguments = None,
408-
estimator: EstimatorBase = None,
409-
display_name: str = None,
410-
description: str = None,
411-
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
412-
cache_config: CacheConfig = None,
407+
step_args: Optional[_JobStepArguments] = None,
408+
estimator: Optional[EstimatorBase] = None,
409+
display_name: Optional[str] = None,
410+
description: Optional[str] = None,
411+
inputs: Optional[Union[TrainingInput, dict, str, FileSystemInput]] = None,
412+
cache_config: Optional[CacheConfig] = None,
413413
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
414-
retry_policies: List[RetryPolicy] = None,
414+
retry_policies: Optional[List[RetryPolicy]] = None,
415415
):
416416
"""Construct a `TrainingStep`, given an `EstimatorBase` instance.
417417
@@ -681,14 +681,14 @@ class TransformStep(ConfigurableRetryStep):
681681
def __init__(
682682
self,
683683
name: str,
684-
step_args: _JobStepArguments = None,
685-
transformer: Transformer = None,
686-
inputs: TransformInput = None,
687-
display_name: str = None,
688-
description: str = None,
689-
cache_config: CacheConfig = None,
684+
step_args: Optional[_JobStepArguments] = None,
685+
transformer: Optional[Transformer] = None,
686+
inputs: Optional[TransformInput] = None,
687+
display_name: Optional[str] = None,
688+
description: Optional[str] = None,
689+
cache_config: Optional[CacheConfig] = None,
690690
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
691-
retry_policies: List[RetryPolicy] = None,
691+
retry_policies: Optional[List[RetryPolicy]] = None,
692692
):
693693
"""Constructs a `TransformStep`, given a `Transformer` instance.
694694
@@ -808,19 +808,19 @@ class ProcessingStep(ConfigurableRetryStep):
808808
def __init__(
809809
self,
810810
name: str,
811-
step_args: _JobStepArguments = None,
812-
processor: Processor = None,
813-
display_name: str = None,
814-
description: str = None,
815-
inputs: List[ProcessingInput] = None,
816-
outputs: List[ProcessingOutput] = None,
817-
job_arguments: List[str] = None,
818-
code: str = None,
819-
property_files: List[PropertyFile] = None,
820-
cache_config: CacheConfig = None,
811+
step_args: Optional[_JobStepArguments] = None,
812+
processor: Optional[Processor] = None,
813+
display_name: Optional[str] = None,
814+
description: Optional[str] = None,
815+
inputs: Optional[List[ProcessingInput]] = None,
816+
outputs: Optional[List[ProcessingOutput]] = None,
817+
job_arguments: Optional[List[str]] = None,
818+
code: Optional[str] = None,
819+
property_files: Optional[List[PropertyFile]] = None,
820+
cache_config: Optional[CacheConfig] = None,
821821
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
822-
retry_policies: List[RetryPolicy] = None,
823-
kms_key=None,
822+
retry_policies: Optional[List[RetryPolicy]] = None,
823+
kms_key: Optional[str] = None,
824824
):
825825
"""Construct a `ProcessingStep`, given a `Processor` instance.
826826
@@ -980,15 +980,15 @@ class TuningStep(ConfigurableRetryStep):
980980
def __init__(
981981
self,
982982
name: str,
983-
step_args: _JobStepArguments = None,
984-
tuner: HyperparameterTuner = None,
985-
display_name: str = None,
986-
description: str = None,
983+
step_args: Optional[_JobStepArguments] = None,
984+
tuner: Optional[HyperparameterTuner] = None,
985+
display_name: Optional[str] = None,
986+
description: Optional[str] = None,
987987
inputs=None,
988-
job_arguments: List[str] = None,
989-
cache_config: CacheConfig = None,
988+
job_arguments: Optional[List[str]] = None,
989+
cache_config: Optional[CacheConfig] = None,
990990
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
991-
retry_policies: List[RetryPolicy] = None,
991+
retry_policies: Optional[List[RetryPolicy]] = None,
992992
):
993993
"""Construct a `TuningStep`, given a `HyperparameterTuner` instance.
994994

0 commit comments

Comments
 (0)