@@ -358,11 +358,11 @@ def test_debug_hook_disabled_with_checkpointing(
358
358
cpu_instance_type ,
359
359
):
360
360
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
361
- s3_output_path = os .path .join ("s3://" , sagemaker_session .default_bucket (), str (uuid .uuid4 ()))
361
+ s3_output_path = os .path .join (
362
+ "s3://" , sagemaker_session .default_bucket (), str (uuid .uuid4 ())
363
+ )
362
364
debugger_hook_config = DebuggerHookConfig (
363
- s3_output_path = os .path .join (
364
- s3_output_path , "tensors"
365
- )
365
+ s3_output_path = os .path .join (s3_output_path , "tensors" )
366
366
)
367
367
368
368
script_path = os .path .join (DATA_DIR , "mxnet_mnist" , "mnist_gluon.py" )
@@ -378,21 +378,10 @@ def test_debug_hook_disabled_with_checkpointing(
378
378
sagemaker_session = sagemaker_session ,
379
379
debugger_hook_config = debugger_hook_config ,
380
380
checkpoint_local_path = "/opt/ml/checkpoints" ,
381
- checkpoint_s3_uri = os .path .join (s3_output_path , "checkpoints" )
382
-
383
- )
384
-
385
- train_input = mx .sagemaker_session .upload_data (
386
- path = os .path .join (data_path , "train" ), key_prefix = "integ-test-data/mxnet_mnist/train"
381
+ checkpoint_s3_uri = os .path .join (s3_output_path , "checkpoints" ),
387
382
)
388
- test_input = mx .sagemaker_session .upload_data (
389
- path = os .path .join (data_path , "test" ), key_prefix = "integ-test-data/mxnet_mnist/test"
390
- )
391
-
392
- mx .fit ({"train" : train_input , "test" : test_input })
393
-
394
- job_description = mx .latest_training_job .describe ()
395
- assert "DebugHookConfig" not in job_description
383
+ mx ._prepare_for_training ()
384
+ assert mx .debugger_hook_config is False
396
385
397
386
398
387
def test_mxnet_with_rules_and_debugger_hook_config (
0 commit comments