@@ -351,6 +351,50 @@ def test_mxnet_with_debugger_hook_config(
351
351
_wait_and_assert_that_no_rule_jobs_errored (training_job = mx .latest_training_job )
352
352
353
353
354
def test_debug_hook_disabled_with_checkpointing(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Train MXNet with a debugger hook AND checkpointing; expect no hook on the job.

    The estimator is given both a ``DebuggerHookConfig`` and checkpoint
    settings. After ``fit`` completes, the described training job must not
    contain a ``DebugHookConfig`` entry.

    NOTE(review): presumably the SDK drops the debugger hook when checkpointing
    is configured -- confirm against the estimator implementation.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        # S3 URIs always use "/" as the separator, so build them explicitly
        # rather than with os.path.join, which follows the host OS convention
        # (and would emit backslashes on Windows).
        s3_output_path = "s3://{}/{}".format(
            sagemaker_session.default_bucket(), uuid.uuid4()
        )
        debugger_hook_config = DebuggerHookConfig(
            s3_output_path="{}/tensors".format(s3_output_path)
        )

        # Local filesystem paths: os.path.join is the right tool here.
        data_path = os.path.join(DATA_DIR, "mxnet_mnist")
        script_path = os.path.join(data_path, "mnist_gluon.py")

        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri="{}/checkpoints".format(s3_output_path),
        )

        train_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        mx.fit({"train": train_input, "test": test_input})

        # The hook must be absent from the job that was actually launched.
        job_description = mx.latest_training_job.describe()
        assert "DebugHookConfig" not in job_description
396
+
397
+
354
398
def test_mxnet_with_rules_and_debugger_hook_config (
355
399
sagemaker_session ,
356
400
mxnet_training_latest_version ,
0 commit comments