Skip to content

Commit 0c15bc7

Browse files
committed
test: adding integration tests targetting MWMS in TF
1 parent c77f874 commit 0c15bc7

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

tests/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,24 @@ def inf_instance_family(inf_instance_type):
468468
return "_".join(inf_instance_type.split(".")[0:2])
469469

470470

471+
@pytest.fixture(scope="session")
472+
def imagenet_train_subset(request, sagemaker_session, tmpdir_factory):
473+
"""
474+
Copies the Imagenet dataset from the bucket it's hosted in to the local bucket in the test region
475+
"""
476+
local_path = tmpdir_factory.mktemp("imagenet_tfrecords_train_subset")
477+
sagemaker_session.download_data(
478+
path=local_path,
479+
bucket="collection-of-ml-datasets",
480+
key_prefix="Imagenet/TFRecords/train_1_of_10",
481+
)
482+
train_input = sagemaker_session.upload_data(
483+
path=local_path,
484+
key_prefix="integ-test-data/imagenet/TFRecords/train",
485+
)
486+
return train_input
487+
488+
471489
def pytest_generate_tests(metafunc):
472490
if "instance_type" in metafunc.fixturenames:
473491
boto_config = metafunc.config.getoption("--boto-config")

tests/integ/test_tf.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SCRIPT = "mnist.py"
3939
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
4040
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
41+
MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}}
4142
TAGS = [{"Key": "some-key", "Value": "some-value"}]
4243
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
4344

@@ -181,6 +182,79 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
181182
)
182183

183184

185+
@pytest.mark.release
186+
@pytest.mark.skipif(
187+
tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
188+
and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
189+
reason="no ml.p2 or ml.p3 instances in this region",
190+
)
191+
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
192+
def test_mwms_gpu(
193+
sagemaker_session,
194+
tensorflow_training_latest_version,
195+
tensorflow_training_latest_py_version,
196+
capsys,
197+
imagenet_train_subset,
198+
**kwargs,
199+
):
200+
epochs = 1
201+
global_batch_size = 64
202+
train_steps = int(10**4 * epochs / global_batch_size)
203+
steps_per_loop = train_steps // 10
204+
overrides = (
205+
f"runtime.enable_xla=False,"
206+
f"runtime.num_gpus=1,"
207+
f"runtime.distribution_strategy=multi_worker_mirrored,"
208+
f"runtime.mixed_precision_dtype=float16,"
209+
f"task.train_data.global_batch_size={global_batch_size},"
210+
f"task.train_data.input_path=/opt/ml/input/data/training/train-000*,"
211+
f"task.train_data.cache=True,"
212+
f"trainer.train_steps={train_steps},"
213+
f"trainer.steps_per_loop={steps_per_loop},"
214+
f"trainer.summary_interval={steps_per_loop},"
215+
f"trainer.checkpoint_interval={train_steps},"
216+
f"task.model.backbone.type=resnet,"
217+
f"task.model.backbone.resnet.model_id=50"
218+
)
219+
estimator = TensorFlow(
220+
git_config={
221+
"repo": "https://github.com/tensorflow/models.git",
222+
"branch": "v2.9.2",
223+
},
224+
source_dir=".",
225+
entry_point="official/vision/train.py",
226+
model_dir=False,
227+
instance_type=kwargs["instance_type"],
228+
instance_count=2,
229+
framework_version=tensorflow_training_latest_version,
230+
py_version=tensorflow_training_latest_py_version,
231+
distribution=MWMS_DISTRIBUTION,
232+
hyperparameters={
233+
"experiment": "resnet_imagenet",
234+
"config_file": "official/vision/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml",
235+
"mode": "train",
236+
"model_dir": "/opt/ml/model",
237+
"params_override": overrides,
238+
},
239+
environment={
240+
"NCCL_DEBUG": "INFO",
241+
},
242+
max_run=60 * 60 * 1, # 1 hour
243+
role=ROLE,
244+
volume_size=400,
245+
sagemaker_session=sagemaker_session,
246+
disable_profiler=True,
247+
)
248+
249+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
250+
estimator.fit(inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms"))
251+
252+
captured = capsys.readouterr()
253+
logs = captured.out + captured.err
254+
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
255+
raise NotImplementedError("Check model saving")
256+
257+
184258
@pytest.mark.release
185259
def test_mnist_distributed_cpu(
186260
sagemaker_session,

0 commit comments

Comments
 (0)