|
24 | 24 | TensorBoardOutputConfig,
|
25 | 25 | )
|
26 | 26 | from sagemaker.mxnet.estimator import MXNet
|
| 27 | +from sagemaker.pytorch.estimator import PyTorch |
| 28 | +from sagemaker.tensorflow.estimator import TensorFlow |
| 29 | +from sagemaker.xgboost.estimator import XGBoost |
27 | 30 | from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
|
28 | 31 | from tests.integ.retry import retries
|
29 | 32 | from tests.integ.timeout import timeout
|
@@ -351,6 +354,115 @@ def test_mxnet_with_debugger_hook_config(
|
351 | 354 | _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
|
352 | 355 |
|
353 | 356 |
|
def test_debug_hook_disabled_with_checkpointing(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Verify the Debugger hook is auto-disabled only for checkpointing combined
    with distributed training.

    Covered scenarios (all via ``_prepare_for_training``; no job is launched):
      * checkpointing + single instance            -> hook stays enabled
      * checkpointing + instance_count > 1         -> hook disabled
      * checkpointing + SMDataParallel enabled     -> hook disabled
      * checkpointing + SMModelParallel enabled    -> hook disabled
      * XGBoost estimator without those conditions -> hook stays enabled
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        s3_output_path = os.path.join(
            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4())
        )
        debugger_hook_config = DebuggerHookConfig(
            s3_output_path=os.path.join(s3_output_path, "tensors")
        )
        # Shared checkpoint destination for every estimator below; hoisted so
        # all scenarios exercise the exact same checkpoint configuration.
        checkpoint_local_path = "/opt/ml/checkpoints"
        checkpoint_s3_uri = os.path.join(s3_output_path, "checkpoints")

        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

        # Estimator with checkpointing enabled on a single instance
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path=checkpoint_local_path,
            checkpoint_s3_uri=checkpoint_s3_uri,
        )
        mx._prepare_for_training()

        # Debug Hook should be enabled
        assert mx.debugger_hook_config is not None

        # Estimator with checkpointing enabled and Instance Count > 1
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=2,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path=checkpoint_local_path,
            checkpoint_s3_uri=checkpoint_s3_uri,
        )
        mx._prepare_for_training()
        # Debug Hook should be disabled
        assert mx.debugger_hook_config is False

        # Estimator with checkpointing enabled and SMDataParallel enabled
        pt = PyTorch(
            base_job_name="pytorch-smdataparallel-mnist",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="1.8.0",
            py_version="py36",
            instance_count=1,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # Training using the SMDataParallel distributed training framework
            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
            checkpoint_local_path=checkpoint_local_path,
            checkpoint_s3_uri=checkpoint_s3_uri,
        )
        pt._prepare_for_training()
        # Debug Hook should be disabled
        assert pt.debugger_hook_config is False

        # Estimator with checkpointing enabled and SMModelParallel enabled
        tf = TensorFlow(
            base_job_name="tf-smdataparallel-mnist",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="2.4.1",
            py_version="py36",
            instance_count=1,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # Training using the SMModelParallel distributed training framework
            # (the original comment here said "SMDataParallel" -- copy-paste error)
            distribution={"smdistributed": {"modelparallel": {"enabled": True}}},
            checkpoint_local_path=checkpoint_local_path,
            checkpoint_s3_uri=checkpoint_s3_uri,
        )
        tf._prepare_for_training()
        # Debug Hook should be disabled
        assert tf.debugger_hook_config is False

        # XGBoost estimator: no checkpoint_s3_uri and no SMDistributed
        # distribution is configured, so the hook must remain enabled.
        # NOTE(review): the original comment claimed "checkpointing enabled",
        # but no checkpoint arguments are passed -- confirm intended scenario.
        xg = XGBoost(
            base_job_name="test_xgboost",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="1.2-1",
            py_version="py3",
            instance_count=2,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
        )
        xg._prepare_for_training()
        # Debug Hook should be enabled
        assert xg.debugger_hook_config is not None

354 | 466 | def test_mxnet_with_rules_and_debugger_hook_config(
|
355 | 467 | sagemaker_session,
|
356 | 468 | mxnet_training_latest_version,
|
|
0 commit comments