|
24 | 24 | TensorBoardOutputConfig,
|
25 | 25 | )
|
26 | 26 | from sagemaker.mxnet.estimator import MXNet
|
| 27 | +from sagemaker.pytorch.estimator import PyTorch |
27 | 28 | from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
|
28 | 29 | from tests.integ.retry import retries
|
29 | 30 | from tests.integ.timeout import timeout
|
@@ -351,6 +352,77 @@ def test_mxnet_with_debugger_hook_config(
|
351 | 352 | _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
|
352 | 353 |
|
353 | 354 |
|
def test_debug_hook_disabled_with_checkpointing(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Verify debugger-hook auto-disable behavior when checkpointing is configured.

    Covers three cases:
      1. Checkpointing with a single instance: the hook stays enabled.
      2. Checkpointing with instance_count > 1: the SDK disables the hook,
         since SageMaker Debugger is incompatible with multi-instance
         checkpointing.
      3. SMDataParallel distributed training with the hook explicitly
         disabled by the caller: the hook remains disabled after
         _prepare_for_training().
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        s3_output_path = os.path.join(
            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
        )
        debugger_hook_config = DebuggerHookConfig(
            s3_output_path=os.path.join(s3_output_path, "tensors")
        )

        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

        # Case 1: checkpointing enabled, single instance.
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        mx._prepare_for_training()

        # Debug hook should remain enabled.
        assert mx.debugger_hook_config is not None

        # Case 2: checkpointing enabled with instance_count > 1.
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=2,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        mx._prepare_for_training()
        # Debug hook should be disabled (multi-instance checkpointing conflict).
        assert mx.debugger_hook_config is False

        # Case 3: SMDataParallel distributed training with the hook
        # explicitly disabled by the caller.
        pt = PyTorch(
            base_job_name="pytorch-smdataparallel-mnist",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="1.8.0",
            py_version="py36",
            instance_count=1,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # Training using SMDataParallel Distributed Training Framework
            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
            debugger_hook_config=False,
        )
        pt._prepare_for_training()
        # Debug hook should stay disabled.
        assert pt.debugger_hook_config is False
| 425 | + |
354 | 426 | def test_mxnet_with_rules_and_debugger_hook_config(
|
355 | 427 | sagemaker_session,
|
356 | 428 | mxnet_training_latest_version,
|
|
0 commit comments