Skip to content

Commit e3c1555

Browse files
fix: Set flag when debugger is disabled (#2520)
* Set flag when debugger is disabled * Address edge case * Address edge case * Address edge case * Fix formatting errors * retrigger CI Co-authored-by: Shreya Pandit <[email protected]>
1 parent 7d56242 commit e3c1555

File tree

5 files changed

+12
-1
lines changed

5 files changed

+12
-1
lines changed

src/sagemaker/debugger/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from sagemaker.debugger.debugger import ( # noqa: F401
1717
CollectionConfig,
18+
DEBUGGER_FLAG,
1819
DebuggerHookConfig,
1920
framework_name,
2021
get_default_profiler_rule,

src/sagemaker/debugger/debugger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.utils import build_dict
3333

3434
framework_name = "debugger"
35+
DEBUGGER_FLAG = "USE_SMDEBUG"
3536

3637

3738
def get_rule_container_image_uri(region):

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.analytics import TrainingJobAnalytics
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import (
32+
DEBUGGER_FLAG,
3233
DebuggerHookConfig,
3334
FrameworkProfile,
3435
get_default_profiler_rule,
@@ -2269,6 +2270,11 @@ def _validate_and_set_debugger_configs(self):
22692270
)
22702271
self.debugger_hook_config = False
22712272

2273+
if self.debugger_hook_config is False:
2274+
if self.environment is None:
2275+
self.environment = {}
2276+
self.environment[DEBUGGER_FLAG] = "0"
2277+
22722278
def _stage_user_code_in_s3(self):
22732279
"""Upload the user training script to s3 and return the location.
22742280

tests/integ/test_debugger.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker.debugger.debugger import (
21+
DEBUGGER_FLAG,
2122
DebuggerHookConfig,
2223
Rule,
2324
rule_configs,
@@ -748,6 +749,7 @@ def test_mxnet_with_debugger_hook_config_disabled(
748749
job_description = mx.latest_training_job.describe()
749750

750751
assert job_description.get("DebugHookConfig") is None
752+
assert job_description.get("Environment", {}).get(DEBUGGER_FLAG) == "0"
751753

752754

753755
def _get_rule_evaluation_statuses(job_description):

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
patch,
2424
)
2525

26-
from sagemaker.debugger import ProfilerConfig
26+
from sagemaker.debugger import DEBUGGER_FLAG, ProfilerConfig
2727
from sagemaker.estimator import Estimator
2828
from sagemaker.tensorflow import TensorFlow
2929
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
@@ -275,6 +275,7 @@ def test_training_step_tensorflow(sagemaker_session):
275275
"sagemaker_distributed_dataparallel_custom_mpi_options": '""',
276276
},
277277
"ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
278+
"Environment": {DEBUGGER_FLAG: "0"},
278279
},
279280
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
280281
}

0 commit comments

Comments
 (0)