Skip to content

Commit 506d8d2

Browse files
committed
Set flag when debugger is disabled
1 parent 9e1fe91 commit 506d8d2

File tree

4 files changed

+9
-0
lines changed

4 files changed

+9
-0
lines changed

src/sagemaker/debugger/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from sagemaker.debugger.debugger import ( # noqa: F401
1717
CollectionConfig,
18+
DEBUGGER_FLAG,
1819
DebuggerHookConfig,
1920
framework_name,
2021
get_default_profiler_rule,

src/sagemaker/debugger/debugger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.utils import build_dict
3333

3434
framework_name = "debugger"
35+
DEBUGGER_FLAG = "USE_SMDEBUG"
3536

3637

3738
def get_rule_container_image_uri(region):

src/sagemaker/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.analytics import TrainingJobAnalytics
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import (
32+
DEBUGGER_FLAG,
3233
DebuggerHookConfig,
3334
FrameworkProfile,
3435
get_default_profiler_rule,
@@ -2269,6 +2270,9 @@ def _validate_and_set_debugger_configs(self):
22692270
)
22702271
self.debugger_hook_config = False
22712272

2273+
if self.debugger_hook_config is False:
2274+
self.environment[DEBUGGER_FLAG] = "0"
2275+
22722276
def _stage_user_code_in_s3(self):
22732277
"""Upload the user training script to s3 and return the location.
22742278

tests/integ/test_debugger.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker.debugger.debugger import (
21+
DEBUGGER_FLAG,
2122
DebuggerHookConfig,
2223
Rule,
2324
rule_configs,
@@ -745,6 +746,8 @@ def test_mxnet_with_debugger_hook_config_disabled(
745746

746747
mx.fit({"train": train_input, "test": test_input})
747748

749+
assert mx.environment.get(DEBUGGER_FLAG) == "0"
750+
748751
job_description = mx.latest_training_job.describe()
749752

750753
assert job_description.get("DebugHookConfig") is None

0 commit comments

Comments
 (0)