Skip to content

Commit 0d674f0

Browse files
author
Chia-Eng
committed
Fix pylint, flake8 and inte test failures
1 parent af75e99 commit 0d674f0

File tree

8 files changed

+38
-19
lines changed

8 files changed

+38
-19
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(
123123
enable_network_isolation=False,
124124
profiler_config=None,
125125
disable_profiler=False,
126-
environment = None,
126+
environment=None,
127127
**kwargs,
128128
):
129129
"""Initialize an ``EstimatorBase`` instance.
@@ -267,7 +267,8 @@ def __init__(
267267
``disable_profiler`` parameter to ``True``.
268268
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
269269
will be disabled (default: ``False``).
270-
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
270+
environment (dict[str, str]) : Environment variables to be set for
271+
use during training job (default: ``None``)
271272
272273
"""
273274
instance_count = renamed_kwargs(
@@ -1664,7 +1665,7 @@ def __init__(
16641665
enable_sagemaker_metrics=None,
16651666
profiler_config=None,
16661667
disable_profiler=False,
1667-
environment = None,
1668+
environment=None,
16681669
**kwargs,
16691670
):
16701671
"""Initialize an ``Estimator`` instance.
@@ -1813,7 +1814,8 @@ def __init__(
18131814
``disable_profiler`` parameter to ``True``.
18141815
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
18151816
will be disabled (default: ``False``).
1816-
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
1817+
environment (dict[str, str]) : Environment variables to be set for
1818+
use during training job (default: ``None``)
18171819
"""
18181820
self.image_uri = image_uri
18191821
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}

src/sagemaker/session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,12 @@ def train( # noqa: C901
523523
Series. For more information see:
524524
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
525525
(default: ``None``).
526-
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
526+
profiler_rule_configs (list[dict]): A list of profiler rule
527+
configurations.src/sagemaker/lineage/artifact.py:285
527528
profiler_config (dict): Configuration for how profiling information is emitted
528529
with SageMaker Profiler. (default: ``None``).
529-
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
530+
environment (dict[str, str]) : Environment variables to be set for
531+
use during training job (default: ``None``)
530532
531533
Returns:
532534
str: ARN of the training job, if it is created.
@@ -661,7 +663,8 @@ def _get_train_request( # noqa: C901
661663
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
662664
profiler_config(dict): Configuration for how profiling information is emitted with
663665
SageMaker Profiler. (default: ``None``).
664-
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
666+
environment (dict[str, str]) : Environment variables to be set for
667+
use during training job (default: ``None``)
665668
666669
Returns:
667670
Dict: a training request dict

tests/integ/test_tf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
3737
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
3838
TAGS = [{"Key": "some-key", "Value": "some-value"}]
39-
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
39+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
4040

4141

4242
def test_mnist_with_checkpoint_config(
@@ -84,8 +84,11 @@ def test_mnist_with_checkpoint_config(
8484
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
8585
TrainingJobName=training_job_name
8686
)["CheckpointConfig"]
87+
actual_training_environment_variable_config = sagemaker_session.sagemaker_client.describe_training_job(
88+
TrainingJobName=training_job_name
89+
)["Environment"]
8790
assert actual_training_checkpoint_config == expected_training_checkpoint_config
88-
assert actual_training_checkpoint_config['Environment'] == ENV_INPUT
91+
assert actual_training_environment_variable_config == ENV_INPUT
8992

9093

9194
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):

tests/unit/sagemaker/tensorflow/test_estimator_init.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121
REGION = "us-west-2"
2222

23-
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
23+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
24+
2425

2526
@pytest.fixture()
2627
def sagemaker_session():
@@ -68,6 +69,7 @@ def test_framework_name(sagemaker_session):
6869
tf = _build_tf(sagemaker_session, framework_version="1.15.2", py_version="py3")
6970
assert tf._framework_name == "tensorflow"
7071

72+
7173
def test_tf_add_environment_variables(sagemaker_session):
7274
tf = _build_tf(
7375
sagemaker_session,
@@ -77,6 +79,7 @@ def test_tf_add_environment_variables(sagemaker_session):
7779
)
7880
assert tf.environment == ENV_INPUT
7981

82+
8083
def test_tf_miss_environment_variables(sagemaker_session):
8184
tf = _build_tf(
8285
sagemaker_session,
@@ -86,6 +89,7 @@ def test_tf_miss_environment_variables(sagemaker_session):
8689
)
8790
assert not tf.environment
8891

92+
8993
def test_enable_sm_metrics(sagemaker_session):
9094
tf = _build_tf(
9195
sagemaker_session,

tests/unit/test_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
7272
CODECOMMIT_BRANCH = "master"
7373
REPO_DIR = "/tmp/repo_dir"
74-
ENV_INPUT = {'env_key1':'env_val1', 'env_key2':'env_val2', 'env_key3':'env_val3'}
74+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
7575

7676
DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}
7777

@@ -2682,6 +2682,7 @@ def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):
26822682
args = sagemaker_session.train.call_args[1]
26832683
assert "enable_sagemaker_metrics" not in args
26842684

2685+
26852686
def test_add_environment_variables_to_train_args(sagemaker_session):
26862687
e = Estimator(
26872688
IMAGE_URI,
@@ -2699,6 +2700,7 @@ def test_add_environment_variables_to_train_args(sagemaker_session):
26992700
args = sagemaker_session.train.call_args[1]
27002701
assert args["environment"] == ENV_INPUT
27012702

2703+
27022704
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
27032705
e = Estimator(
27042706
IMAGE_URI,

tests/unit/test_mxnet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666
MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"}
6767

68-
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
68+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
6969

7070

7171
@pytest.fixture()
@@ -961,6 +961,7 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
961961

962962
assert model.image_uri == custom_hosting_image
963963

964+
964965
def test_mx_add_environment_variables(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
965966
mx = MXNet(
966967
entry_point=SCRIPT_PATH,
@@ -974,6 +975,7 @@ def test_mx_add_environment_variables(sagemaker_session, mxnet_training_version,
974975
)
975976
assert mx.environment == ENV_INPUT
976977

978+
977979
def test_mx_missing_environment_variables(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
978980
mx = MXNet(
979981
entry_point=SCRIPT_PATH,
@@ -987,6 +989,7 @@ def test_mx_missing_environment_variables(sagemaker_session, mxnet_training_vers
987989
)
988990
assert not mx.environment
989991

992+
990993
def test_mx_enable_sm_metrics(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
991994
mx = MXNet(
992995
entry_point=SCRIPT_PATH,

tests/unit/test_pytorch.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
4848

49-
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
49+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
5050

5151
LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
5252

@@ -629,7 +629,7 @@ def test_pt_enable_sm_metrics(
629629

630630

631631
def test_pt_disable_sm_metrics(
632-
sagemaker_session, pytorch_training_version, pytorch_training_py_version
632+
sagemaker_session, pytorch_training_version, pytorch_training_py_version
633633
):
634634
pytorch = _pytorch_estimator(
635635
sagemaker_session,
@@ -639,9 +639,10 @@ def test_pt_disable_sm_metrics(
639639
)
640640
assert not pytorch.enable_sagemaker_metrics
641641

642+
642643
def test_pt_add_environment_variables(
643-
sagemaker_session, pytorch_training_version, pytorch_training_py_version
644-
):
644+
sagemaker_session, pytorch_training_version, pytorch_training_py_version
645+
):
645646
pytorch = _pytorch_estimator(
646647
sagemaker_session,
647648
framework_version=pytorch_training_version,
@@ -650,9 +651,10 @@ def test_pt_add_environment_variables(
650651
)
651652
assert pytorch.environment
652653

654+
653655
def test_pt_miss_environment_variables(
654-
sagemaker_session, pytorch_training_version, pytorch_training_py_version
655-
):
656+
sagemaker_session, pytorch_training_version, pytorch_training_py_version
657+
):
656658
pytorch = _pytorch_estimator(
657659
sagemaker_session,
658660
framework_version=pytorch_training_version,

tests/unit/test_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}]
3838

39-
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
39+
ENV_INPUT = {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
4040

4141
REGION = "us-west-2"
4242
STS_ENDPOINT = "sts.us-west-2.amazonaws.com"

0 commit comments

Comments
 (0)