Commit 88cc735

BruceZhang@eitug committed: fix black formatting issue
Parent: 4898bc1

File tree: 2 files changed (+42 lines, -71 lines)

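The changes in both files are mechanical: black re-joins expressions that had been wrapped across several lines once they fit within the configured line length (the single-line forms below suggest a limit of roughly 100 characters, which is an inference from the diff, not something the commit states). As a rough illustration, the sketch below runs black programmatically on one of the wrapped statements from the unit-test file; the line_length value is an assumption chosen to reproduce the shapes seen in the '+' lines.

# Illustration only, not part of the commit; requires `pip install black`.
# line_length=100 is an assumed setting inferred from this diff, not a value
# taken from the repository's configuration.
import black

wrapped = (
    'expected_train_args["hyperparameters"][\n'
    "    TrainingCompilerConfig.HP_ENABLE_COMPILER\n"
    "] = json.dumps(True)\n"
)

# black only parses the source (it never executes it), so the undefined names
# in the snippet are fine. With the longer limit the subscript is re-joined and
# the call is split at its parentheses instead, which should match the '+' lines
# in the unit-test diff below.
print(black.format_str(wrapped, mode=black.Mode(line_length=100)))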

tests/integ/test_training_compiler.py

Lines changed: 4 additions & 13 deletions
@@ -76,10 +76,7 @@ def skip_if_incompatible(gpu_instance_type, request):
     region = integ.test_region()
     if region not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS:
         pytest.skip("SageMaker Training Compiler is not supported in this region")
-    if (
-        gpu_instance_type == "ml.p3.16xlarge"
-        and region not in integ.DATA_PARALLEL_TESTING_REGIONS
-    ):
+    if gpu_instance_type == "ml.p3.16xlarge" and region not in integ.DATA_PARALLEL_TESTING_REGIONS:
         pytest.skip("Data parallel testing is not allowed in this region")
     if gpu_instance_type == "ml.p3.2xlarge" and region in integ.TRAINING_NO_P3_REGIONS:
         pytest.skip("no ml.p3 instances in this region")
@@ -127,9 +124,7 @@ def test_huggingface_pytorch(
             sagemaker_session=sagemaker_session,
             disable_profiler=True,
             compiler_config=HFTrainingCompilerConfig(),
-            distribution={"pytorchxla": {"enabled": True}}
-            if instance_count > 1
-            else None,
+            distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
         )

         hf.fit(huggingface_dummy_dataset)
@@ -175,9 +170,7 @@ def test_pytorch(
             sagemaker_session=sagemaker_session,
             disable_profiler=True,
             compiler_config=PTTrainingCompilerConfig(),
-            distribution={"pytorchxla": {"enabled": True}}
-            if instance_count > 1
-            else None,
+            distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
         )

         hf.fit(huggingface_dummy_dataset)
@@ -257,9 +250,7 @@ def test_tensorflow(
             py_version="py39",
             git_config={
                 "repo": "https://github.com/tensorflow/models.git",
-                "branch": "v"
-                + ".".join(tensorflow_training_latest_version.split(".")[:2])
-                + ".0",
+                "branch": "v" + ".".join(tensorflow_training_latest_version.split(".")[:2]) + ".0",
             },
             source_dir=".",
             entry_point="official/vision/train.py",

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

Lines changed: 38 additions & 58 deletions
@@ -45,9 +45,7 @@
 REGION = "us-east-1"
 GPU = "ml.p3.2xlarge"
 SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"}
-UNSUPPORTED_GPU_INSTANCE_CLASSES = (
-    EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
-)
+UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES

 LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
@@ -98,13 +96,9 @@ def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
     )


-def _create_train_job(
-    version, instance_type, training_compiler_config, instance_count=1
-):
+def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
     return {
-        "image_uri": _get_full_gpu_image_uri(
-            version, instance_type, training_compiler_config
-        ),
+        "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
         "input_mode": "File",
         "input_config": [
             {
@@ -189,9 +183,7 @@ def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_v
     ).fit()


-@pytest.mark.parametrize(
-    "unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES
-)
+@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
 def test_unsupported_gpu_instance(
     unsupported_gpu_instance_class, pytorch_training_compiler_version
 ):
@@ -351,19 +343,15 @@ def test_pytorchxla_distribution(
         compiler_config,
         instance_count=2,
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
-        "S3Uri"
-    ] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_COMPILER
-    ] = json.dumps(True)
-    expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
         True
     )
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_DEBUG
-    ] = json.dumps(False)
+    expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True)
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -411,16 +399,14 @@ def test_default_compiler_config(
     expected_train_args = _create_train_job(
         pytorch_training_compiler_version, instance_type, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
-        "S3Uri"
-    ] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_COMPILER
-    ] = json.dumps(True)
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_DEBUG
-    ] = json.dumps(False)
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        True
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -465,16 +451,14 @@ def test_debug_compiler_config(
     expected_train_args = _create_train_job(
         pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
-        "S3Uri"
-    ] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_COMPILER
-    ] = json.dumps(True)
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_DEBUG
-    ] = json.dumps(True)
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        True
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        True
+    )

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -519,16 +503,14 @@ def test_disable_compiler_config(
    expected_train_args = _create_train_job(
         pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
     )
-    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
-        "S3Uri"
-    ] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
     expected_train_args["enable_sagemaker_metrics"] = False
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_COMPILER
-    ] = json.dumps(False)
-    expected_train_args["hyperparameters"][
-        TrainingCompilerConfig.HP_ENABLE_DEBUG
-    ] = json.dumps(False)
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+        False
+    )
+    expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+        False
+    )

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert (
@@ -582,9 +564,7 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
         name="describe_training_job", return_value=returned_job_description
     )

-    estimator = PyTorch.attach(
-        training_job_name="trcomp", sagemaker_session=sagemaker_session
-    )
+    estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session)
     assert estimator.latest_training_job.job_name == "trcomp"
     assert estimator.py_version == "py38"
     assert estimator.framework_version == "1.12.0"
@@ -596,12 +576,12 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
     assert estimator.output_path == "s3://place/output/trcomp"
     assert estimator.output_kms_key == ""
     assert estimator.hyperparameters()["training_steps"] == "100"
-    assert estimator.hyperparameters()[
-        TrainingCompilerConfig.HP_ENABLE_COMPILER
-    ] == json.dumps(compiler_enabled)
-    assert estimator.hyperparameters()[
-        TrainingCompilerConfig.HP_ENABLE_DEBUG
-    ] == json.dumps(debug_enabled)
+    assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps(
+        compiler_enabled
+    )
+    assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps(
+        debug_enabled
+    )
     assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
     assert estimator.entry_point == "iris-dnn-classifier.py"
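
As a quick sanity check, the formatting of the two touched files can be verified locally with black's --check and --diff flags, which report what would change without rewriting anything. This is a hedged sketch, assuming black is installed and run from the repository root so the project's own black configuration is picked up.

# Hedged sketch: confirm the two files from this commit satisfy black.
# Exit code 0 means nothing would change; 1 means black would reformat something.
import subprocess

files = [
    "tests/integ/test_training_compiler.py",
    "tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py",
]
result = subprocess.run(["black", "--check", "--diff", *files])
print("formatting clean" if result.returncode == 0 else "black would reformat these files")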