Skip to content

Commit 67e26c9

Browse files
authored
feature: Support profiler config in the pipeline training job step (#2183)
1 parent c13fc6d commit 67e26c9

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,6 @@ def arguments(self) -> RequestType:
160160
NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
161161
The TrainingJobName and ExperimentConfig attributes cannot be included.
162162
"""
163-
self.estimator.disable_profiler = True
164-
self.estimator.profiler_config = None
165-
self.estimator.profiler_rules = None
166163

167164
self.estimator._prepare_for_training()
168165
train_args = _TrainingJob._get_train_args(

tests/integ/test_workflow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def test_conditional_pytorch_training_model_registration(
450450
pass
451451

452452

453-
def test_training_job_with_debugger(
453+
def test_training_job_with_debugger_and_profiler(
454454
sagemaker_session,
455455
pipeline_name,
456456
role,
@@ -535,6 +535,9 @@ def test_training_job_with_debugger(
535535
config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"]
536536
)
537537
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
538+
539+
assert job_description["ProfilingStatus"] == "Enabled"
540+
assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
538541
finally:
539542
try:
540543
pipeline.delete()

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
PropertyMock,
2222
)
2323

24+
from sagemaker.debugger import ProfilerConfig
2425
from sagemaker.estimator import Estimator
2526
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
2627
from sagemaker.model import Model
@@ -112,6 +113,8 @@ def test_training_step(sagemaker_session):
112113
role=ROLE,
113114
instance_count=1,
114115
instance_type="c4.4xlarge",
116+
profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
117+
rules=[],
115118
sagemaker_session=sagemaker_session,
116119
)
117120
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
@@ -144,6 +147,10 @@ def test_training_step(sagemaker_session):
144147
},
145148
"RoleArn": ROLE,
146149
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
150+
"ProfilerConfig": {
151+
"ProfilingIntervalInMilliseconds": 500,
152+
"S3OutputPath": f"s3://{BUCKET}/",
153+
},
147154
},
148155
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
149156
}

0 commit comments

Comments
 (0)