Skip to content

fix: Set flag when debugger is disabled #2520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits on Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/debugger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from sagemaker.debugger.debugger import ( # noqa: F401
CollectionConfig,
DEBUGGER_FLAG,
DebuggerHookConfig,
framework_name,
get_default_profiler_rule,
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/debugger/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sagemaker.utils import build_dict

framework_name = "debugger"
DEBUGGER_FLAG = "USE_SMDEBUG"


def get_rule_container_image_uri(region):
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
from sagemaker.debugger import (
DEBUGGER_FLAG,
DebuggerHookConfig,
FrameworkProfile,
get_default_profiler_rule,
Expand Down Expand Up @@ -2269,6 +2270,11 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = False

if self.debugger_hook_config is False:
if self.environment is None:
self.environment = {}
self.environment[DEBUGGER_FLAG] = "0"

def _stage_user_code_in_s3(self):
"""Upload the user training script to s3 and return the location.

Expand Down
2 changes: 2 additions & 0 deletions tests/integ/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

from sagemaker.debugger.debugger import (
DEBUGGER_FLAG,
DebuggerHookConfig,
Rule,
rule_configs,
Expand Down Expand Up @@ -748,6 +749,7 @@ def test_mxnet_with_debugger_hook_config_disabled(
job_description = mx.latest_training_job.describe()

assert job_description.get("DebugHookConfig") is None
assert job_description.get("Environment", {}).get(DEBUGGER_FLAG) == "0"


def _get_rule_evaluation_statuses(job_description):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/workflow/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
patch,
)

from sagemaker.debugger import ProfilerConfig
from sagemaker.debugger import DEBUGGER_FLAG, ProfilerConfig
from sagemaker.estimator import Estimator
from sagemaker.tensorflow import TensorFlow
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
Expand Down Expand Up @@ -275,6 +275,7 @@ def test_training_step_tensorflow(sagemaker_session):
"sagemaker_distributed_dataparallel_custom_mpi_options": '""',
},
"ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"},
"Environment": {DEBUGGER_FLAG: "0"},
},
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
}
Expand Down