25
25
)
26
26
from sagemaker .mxnet .estimator import MXNet
27
27
from sagemaker .pytorch .estimator import PyTorch
28
+ from sagemaker .xgboost .estimator import XGBoost
28
29
from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
29
30
from tests .integ .retry import retries
30
31
from tests .integ .timeout import timeout
@@ -400,7 +401,7 @@ def test_debug_hook_disabled_with_checkpointing(
400
401
checkpoint_s3_uri = os .path .join (s3_output_path , "checkpoints" ),
401
402
)
402
403
mx ._prepare_for_training ()
403
- # Debug Hook should be enabled
404
+ # Debug Hook should be disabled
404
405
assert mx .debugger_hook_config is False
405
406
406
407
# Estimator with checkpointing enabled and Model Parallel Enabled
@@ -416,12 +417,30 @@ def test_debug_hook_disabled_with_checkpointing(
416
417
sagemaker_session = sagemaker_session ,
417
418
# Training using SMDataParallel Distributed Training Framework
418
419
distribution = {"smdistributed" : {"dataparallel" : {"enabled" : True }}},
419
- debugger_hook_config = False ,
420
+ checkpoint_local_path = "/opt/ml/checkpoints" ,
421
+ checkpoint_s3_uri = os .path .join (s3_output_path , "checkpoints" ),
420
422
)
421
423
pt ._prepare_for_training ()
422
- # Debug Hook should be enabled
424
+ # Debug Hook should be disabled
423
425
assert pt .debugger_hook_config is False
424
426
427
+ # XGBoost Estimator created without any checkpoint configuration
428
+ xg = XGBoost (
429
+ base_job_name = "test_xgboost" ,
430
+ entry_point = script_path ,
431
+ role = "SageMakerRole" ,
432
+ framework_version = "1.2-1" ,
433
+ py_version = "py3" ,
434
+ instance_count = 2 ,
435
+ # To train on p3dn instances use ml.p3dn.24xlarge; to train on p4d instances use ml.p4d.24xlarge
436
+ instance_type = "ml.p3.16xlarge" ,
437
+ sagemaker_session = sagemaker_session ,
438
+ # No distribution (e.g. SMDataParallel) is configured for this estimator
439
+ )
440
+ xg ._prepare_for_training ()
441
+ # Debug Hook should be enabled
442
+ assert xg .debugger_hook_config is not None
443
+
425
444
426
445
def test_mxnet_with_rules_and_debugger_hook_config (
427
446
sagemaker_session ,
0 commit comments