@@ -413,6 +413,8 @@ def test_training_job_with_debugger(
413
413
sagemaker_session ,
414
414
pipeline_name ,
415
415
role ,
416
+ pytorch_training_latest_version ,
417
+ pytorch_training_latest_py_version ,
416
418
):
417
419
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
418
420
instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
@@ -424,6 +426,7 @@ def test_training_job_with_debugger(
424
426
]
425
427
debugger_hook_config = DebuggerHookConfig (
426
428
s3_output_path = f"s3://{ sagemaker_session .default_bucket ()} /{ uuid .uuid4 ()} /tensors"
429
+ )
427
430
428
431
base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
429
432
script_path = os .path .join (base_dir , "mnist.py" )
@@ -436,8 +439,8 @@ def test_training_job_with_debugger(
436
439
pytorch_estimator = PyTorch (
437
440
entry_point = script_path ,
438
441
role = "SageMakerRole" ,
439
- framework_version = "1.5.0" ,
440
- py_version = "py3" ,
442
+ framework_version = pytorch_training_latest_version ,
443
+ py_version = pytorch_training_latest_py_version ,
441
444
instance_count = instance_count ,
442
445
instance_type = instance_type ,
443
446
sagemaker_session = sagemaker_session ,
@@ -462,7 +465,7 @@ def test_training_job_with_debugger(
462
465
response = pipeline .create (role )
463
466
create_arn = response ["PipelineArn" ]
464
467
465
- execution = pipeline .start (parameters = {} )
468
+ execution = pipeline .start ()
466
469
response = execution .describe ()
467
470
assert response ["PipelineArn" ] == create_arn
468
471
0 commit comments