Skip to content

Commit d05ee41

Browse files
author
Rohan Gujarathi
committed
change: include workflow integ tests with clarify and debugger enabled
1 parent 791bf0a commit d05ee41

File tree

2 files changed

+392
-0
lines changed

2 files changed

+392
-0
lines changed

tests/integ/test_workflow.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,18 @@
1616
import os
1717
import re
1818
import time
19+
import uuid
1920

2021
import boto3
2122
import pytest
2223

2324
from botocore.config import Config
2425
from botocore.exceptions import WaiterError
26+
from sagemaker.debugger import (
27+
DebuggerHookConfig,
28+
Rule,
29+
rule_configs,
30+
)
2531
from sagemaker.inputs import CreateModelInput, TrainingInput
2632
from sagemaker.model import Model
2733
from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -401,3 +407,101 @@ def test_conditional_pytorch_training_model_registration(
401407
pipeline.delete()
402408
except Exception:
403409
pass
410+
411+
412+
def test_training_job_with_debugger(
413+
sagemaker_session,
414+
pipeline_name,
415+
role,
416+
):
417+
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
418+
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
419+
420+
rules = [
421+
Rule.sagemaker(rule_configs.vanishing_gradient()),
422+
Rule.sagemaker(base_config=rule_configs.all_zero(), rule_parameters={"tensor_regex": ".*"}),
423+
Rule.sagemaker(rule_configs.loss_not_decreasing()),
424+
]
425+
debugger_hook_config = DebuggerHookConfig(
426+
s3_output_path=os.path.join(
427+
"s3://", sagemaker_session.default_bucket(), str(uuid.uuid4()), "tensors"
428+
)
429+
)
430+
431+
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
432+
script_path = os.path.join(base_dir, "mnist.py")
433+
input_path = sagemaker_session.upload_data(
434+
path=os.path.join(base_dir, "training"),
435+
key_prefix="integ-test-data/pytorch_mnist/training",
436+
)
437+
inputs = TrainingInput(s3_data=input_path)
438+
439+
pytorch_estimator = PyTorch(
440+
entry_point=script_path,
441+
role="SageMakerRole",
442+
framework_version="1.5.0",
443+
py_version="py3",
444+
instance_count=instance_count,
445+
instance_type=instance_type,
446+
sagemaker_session=sagemaker_session,
447+
rules=rules,
448+
debugger_hook_config=debugger_hook_config,
449+
)
450+
451+
step_train = TrainingStep(
452+
name="pytorch-train",
453+
estimator=pytorch_estimator,
454+
inputs=inputs,
455+
)
456+
457+
pipeline = Pipeline(
458+
name=pipeline_name,
459+
parameters=[instance_count, instance_type],
460+
steps=[step_train],
461+
sagemaker_session=sagemaker_session,
462+
)
463+
464+
try:
465+
response = pipeline.create(role)
466+
create_arn = response["PipelineArn"]
467+
468+
execution = pipeline.start(parameters={})
469+
response = execution.describe()
470+
assert response["PipelineArn"] == create_arn
471+
472+
try:
473+
execution.wait(delay=10, max_attempts=60)
474+
except WaiterError:
475+
pass
476+
execution_steps = execution.list_steps()
477+
training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
478+
job_description = sagemaker_session.sagemaker_client.describe_training_job(
479+
TrainingJobName=training_job_arn.split("/")[1]
480+
)
481+
482+
assert len(execution_steps) == 1
483+
assert execution_steps[0]["StepName"] == "pytorch-train"
484+
assert execution_steps[0]["StepStatus"] == "Succeeded"
485+
486+
for index, rule in enumerate(rules):
487+
assert (
488+
job_description["DebugRuleConfigurations"][index]["RuleConfigurationName"]
489+
== rule.name
490+
)
491+
assert (
492+
job_description["DebugRuleConfigurations"][index]["RuleEvaluatorImage"]
493+
== rule.image_uri
494+
)
495+
assert job_description["DebugRuleConfigurations"][index]["VolumeSizeInGB"] == 0
496+
assert (
497+
job_description["DebugRuleConfigurations"][index]["RuleParameters"][
498+
"rule_to_invoke"
499+
]
500+
== rule.rule_parameters["rule_to_invoke"]
501+
)
502+
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
503+
finally:
504+
try:
505+
pipeline.delete()
506+
except Exception:
507+
pass

0 commit comments

Comments
 (0)