Skip to content

change: disable debugger/profiler in cgk region #3312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,29 @@ def _prepare_debugger_for_training(self):
self.debugger_hook_config.s3_output_path = self.output_path
self.debugger_rule_configs = self._prepare_debugger_rules()
self._prepare_collection_configs()
self._validate_and_set_debugger_configs()
if not self.debugger_hook_config:
if self.environment is None:
self.environment = {}
self.environment[DEBUGGER_FLAG] = "0"

def _validate_and_set_debugger_configs(self):
    """Resolve ``debugger_hook_config`` according to the session's region.

    When ``_region_supports_debugger`` reports support and the hook config is
    ``None`` or ``{}``, a ``DebuggerHookConfig`` writing to ``self.output_path``
    is installed; any other value is left untouched. When the region is
    reported as unsupported, the hook config is forced to ``False``.

    Raises:
        ValueError: If the region is reported as unsupported while a truthy
            ``debugger_hook_config`` is set.
    """
    debugger_available = _region_supports_debugger(self.sagemaker_session.boto_region_name)

    if debugger_available:
        # Only fill in a default; an explicit config (or ``False``) is kept as-is.
        if self.debugger_hook_config in (None, {}):
            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
        return

    if self.debugger_hook_config:
        # The user set a debugger config in an unsupported region.
        raise ValueError(
            "Current region does not support debugger but debugger hook config is set!"
        )

    # Disable debugger in unsupported regions.
    self.debugger_hook_config = False

def _prepare_debugger_rules(self):
"""Set any necessary values in debugger rules, if they are provided."""
Expand Down Expand Up @@ -1766,6 +1789,8 @@ def enable_default_profiling(self):
Debugger monitoring is disabled.
"""
self._ensure_latest_training_job()
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
raise ValueError("Current region does not support profiler / debugger!")

training_job_details = self.latest_training_job.describe()

Expand Down Expand Up @@ -1799,6 +1824,8 @@ def disable_profiling(self):

"""
self._ensure_latest_training_job()
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
raise ValueError("Current region does not support profiler / debugger!")

training_job_details = self.latest_training_job.describe()

Expand Down Expand Up @@ -1852,6 +1879,8 @@ def update_profiler(

"""
self._ensure_latest_training_job()
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
raise ValueError("Current region does not support profiler / debugger!")

if (
not rules
Expand Down Expand Up @@ -2872,13 +2901,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:

def _validate_and_set_debugger_configs(self):
"""Set defaults for debugging."""
if self.debugger_hook_config is None and _region_supports_debugger(
self.sagemaker_session.boto_region_name
):
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
elif not self.debugger_hook_config:
# set hook config to False if _region_supports_debugger is False
self.debugger_hook_config = False
super(Framework, self)._validate_and_set_debugger_configs()

# Disable debugger if checkpointing is enabled by the customer
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
Expand All @@ -2901,11 +2924,6 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = False

if self.debugger_hook_config is False:
if self.environment is None:
self.environment = {}
self.environment[DEBUGGER_FLAG] = "0"

def _model_source_dir(self):
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.

Expand Down
22 changes: 20 additions & 2 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,26 @@
"only one worker per host regardless of the number of GPUs."
)

DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
PROFILER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
# Region names treated as not supporting the debugger feature.
DEBUGGER_UNSUPPORTED_REGIONS = (
"us-iso-east-1",
"ap-southeast-3",
"ap-southeast-4",
"eu-south-2",
"me-central-1",
"ap-south-2",
"eu-central-2",
"us-gov-east-1",
)
# Region names treated as not supporting the profiler feature.
# NOTE(review): currently lists the same regions as DEBUGGER_UNSUPPORTED_REGIONS;
# keep the two in sync unless they are meant to diverge.
PROFILER_UNSUPPORTED_REGIONS = (
"us-iso-east-1",
"ap-southeast-3",
"ap-southeast-4",
"eu-south-2",
"me-central-1",
"ap-south-2",
"eu-central-2",
"us-gov-east-1",
)

SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def test_fit_ps(time, strftime, sagemaker_session):
expected_train_args = _create_train_job("1.11", ps=True, py_version="py2")
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True)
expected_train_args["environment"] = {"USE_SMDEBUG": "0"}

actual_train_args = sagemaker_session.method_calls[0][2]
assert actual_train_args == expected_train_args
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/workflow/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ def test_training_step_base_estimator(sagemaker_session):
},
"RoleArn": ROLE,
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"DebugHookConfig": {
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
"CollectionConfigurations": [],
},
"ProfilerConfig": {
"ProfilingIntervalInMilliseconds": 500,
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
Expand Down
110 changes: 109 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,110 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region):
assert args.get("profiler_rule_configs") is None


@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_with_debugger_config_set_up_in_unsupported_region(region):
    """``fit`` must raise when a debugger hook config is set in an unsupported region."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        debugger_hook_config=DebuggerHookConfig(s3_output_path="s3://output"),
    )

    # Keep only the call under test inside ``pytest.raises`` so that an
    # unrelated ValueError raised during setup cannot make the test pass.
    with pytest.raises(ValueError) as error:
        f.fit("s3://mydata")

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support debugger but debugger hook config is set!" in str(
        error.value
    )


@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_enable_profiling_in_unsupported_region(region):
    """``enable_default_profiling`` must raise in a region without profiler support."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # ``fit`` runs outside ``pytest.raises``: a ValueError from training setup
    # must fail the test instead of masquerading as the expected error.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.enable_default_profiling()

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)


@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_update_profiling_in_unsupported_region(region):
    """``update_profiler`` must raise in a region without profiler support."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # ``fit`` runs outside ``pytest.raises``: a ValueError from training setup
    # must fail the test instead of masquerading as the expected error.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.update_profiler(system_monitor_interval_millis=1000)

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)


@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_disable_profiling_in_unsupported_region(region):
    """``disable_profiling`` must raise in a region without profiler support."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # ``fit`` runs outside ``pytest.raises``: a ValueError from training setup
    # must fail the test instead of masquerading as the expected error.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.disable_profiling()

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)


def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session):
with pytest.raises(RuntimeError) as error:
f = DummyFramework(
Expand Down Expand Up @@ -2683,6 +2787,7 @@ def test_generic_to_fit_no_input(time, sagemaker_session):

args.pop("job_name")
args.pop("role")
args.pop("debugger_hook_config")

assert args == NO_INPUT_TRAIN_CALL

Expand All @@ -2707,6 +2812,7 @@ def test_generic_to_fit_no_hps(time, sagemaker_session):

args.pop("job_name")
args.pop("role")
args.pop("debugger_hook_config")

assert args == BASE_TRAIN_CALL

Expand All @@ -2733,6 +2839,7 @@ def test_generic_to_fit_with_hps(time, sagemaker_session):

args.pop("job_name")
args.pop("role")
args.pop("debugger_hook_config")

assert args == HP_TRAIN_CALL

Expand Down Expand Up @@ -2764,6 +2871,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session):

args.pop("job_name")
args.pop("role")
args.pop("debugger_hook_config")

assert args == EXP_TRAIN_CALL

Expand Down Expand Up @@ -2917,6 +3025,7 @@ def test_generic_to_deploy(time, sagemaker_session):

args.pop("job_name")
args.pop("role")
args.pop("debugger_hook_config")

assert args == HP_TRAIN_CALL

Expand Down Expand Up @@ -3727,7 +3836,6 @@ def test_script_mode_estimator_same_calls_as_framework(
source_dir=script_uri,
image_uri=IMAGE_URI,
model_uri=model_uri,
environment={"USE_SMDEBUG": "0"},
dependencies=[],
debugger_hook_config={},
)
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,13 @@ def test_region_supports_debugger_feature_returns_true_for_supported_regions():

def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
    """Each listed region must be reported as not supporting the debugger."""
    regions_without_debugger = (
        "us-iso-east-1",
        "ap-southeast-3",
        "ap-southeast-4",
        "eu-south-2",
        "me-central-1",
        "ap-south-2",
        "eu-central-2",
        "us-gov-east-1",
    )
    for region_name in regions_without_debugger:
        assert fw_utils._region_supports_debugger(region_name) is False

def test_warn_if_parameter_server_with_multi_gpu(caplog):
instance_type = "ml.p2.8xlarge"
Expand Down