Skip to content

Commit c9c4800

Browse files
committed
change: disable debugger/profiler in cgk region
1 parent 6e5d247 commit c9c4800

File tree

6 files changed

+171
-16
lines changed

6 files changed

+171
-16
lines changed

src/sagemaker/estimator.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,29 @@ def _prepare_debugger_for_training(self):
823823
self.debugger_hook_config.s3_output_path = self.output_path
824824
self.debugger_rule_configs = self._prepare_debugger_rules()
825825
self._prepare_collection_configs()
826+
self._validate_and_set_debugger_configs()
827+
if not self.debugger_hook_config:
828+
if self.environment is None:
829+
self.environment = {}
830+
self.environment[DEBUGGER_FLAG] = "0"
831+
832+
def _validate_and_set_debugger_configs(self):
    """Resolve ``debugger_hook_config`` according to region support.

    In a region where ``_region_supports_debugger`` reports support, an
    unset hook config (``None`` or ``{}``) is replaced by a default
    ``DebuggerHookConfig`` writing to ``self.output_path``; any other value
    is left untouched.

    In a region without support, the hook config is forced to ``False``.

    Raises:
        ValueError: If the region does not support debugger and a truthy
            debugger hook config was supplied.
    """
    region_name = self.sagemaker_session.boto_region_name

    if _region_supports_debugger(region_name):
        if self.debugger_hook_config in (None, {}):
            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
        return

    # Unsupported region: an explicitly provided (truthy) hook config is an
    # error. ``False`` is falsy, so a single truthiness check covers it.
    if self.debugger_hook_config:
        raise ValueError(
            "Current region does not support debugger but debugger hook config is set!"
        )

    # Nothing usable was configured, so switch the debugger off explicitly.
    self.debugger_hook_config = False
826849

827850
def _prepare_debugger_rules(self):
828851
"""Set any necessary values in debugger rules, if they are provided."""
@@ -1766,6 +1789,8 @@ def enable_default_profiling(self):
17661789
Debugger monitoring is disabled.
17671790
"""
17681791
self._ensure_latest_training_job()
1792+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1793+
raise ValueError("Current region does not support profiler / debugger!")
17691794

17701795
training_job_details = self.latest_training_job.describe()
17711796

@@ -1799,6 +1824,8 @@ def disable_profiling(self):
17991824
18001825
"""
18011826
self._ensure_latest_training_job()
1827+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1828+
raise ValueError("Current region does not support profiler / debugger!")
18021829

18031830
training_job_details = self.latest_training_job.describe()
18041831

@@ -1852,6 +1879,8 @@ def update_profiler(
18521879
18531880
"""
18541881
self._ensure_latest_training_job()
1882+
if not _region_supports_debugger(self.sagemaker_session.boto_region_name):
1883+
raise ValueError("Current region does not support profiler / debugger!")
18551884

18561885
if (
18571886
not rules
@@ -2872,13 +2901,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
28722901

28732902
def _validate_and_set_debugger_configs(self):
28742903
"""Set defaults for debugging."""
2875-
if self.debugger_hook_config is None and _region_supports_debugger(
2876-
self.sagemaker_session.boto_region_name
2877-
):
2878-
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
2879-
elif not self.debugger_hook_config:
2880-
# set hook config to False if _region_supports_debugger is False
2881-
self.debugger_hook_config = False
2904+
super(Framework, self)._validate_and_set_debugger_configs()
28822905

28832906
# Disable debugger if checkpointing is enabled by the customer
28842907
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
@@ -2901,11 +2924,6 @@ def _validate_and_set_debugger_configs(self):
29012924
)
29022925
self.debugger_hook_config = False
29032926

2904-
if self.debugger_hook_config is False:
2905-
if self.environment is None:
2906-
self.environment = {}
2907-
self.environment[DEBUGGER_FLAG] = "0"
2908-
29092927
def _model_source_dir(self):
29102928
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
29112929

src/sagemaker/fw_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,26 @@
5353
"only one worker per host regardless of the number of GPUs."
5454
)
5555

56-
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
57-
PROFILER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
56+
# Regions listed as not supporting debugger.
DEBUGGER_UNSUPPORTED_REGIONS = (
    "us-iso-east-1",
    "ap-southeast-3",
    "ap-southeast-4",
    "eu-south-2",
    "me-central-1",
    "ap-south-2",
    "eu-central-2",
    "us-gov-east-1",
)
# The profiler list was a verbatim copy of the debugger list; alias it so the
# two cannot drift apart by accident. Replace the alias with its own tuple if
# the two features ever need different region lists.
PROFILER_UNSUPPORTED_REGIONS = DEBUGGER_UNSUPPORTED_REGIONS
5876

5977
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
6078
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def test_fit_ps(time, strftime, sagemaker_session):
483483
expected_train_args = _create_train_job("1.11", ps=True, py_version="py2")
484484
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
485485
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True)
486+
expected_train_args["environment"] = {"USE_SMDEBUG": "0"}
486487

487488
actual_train_args = sagemaker_session.method_calls[0][2]
488489
assert actual_train_args == expected_train_args

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,10 @@ def test_training_step_base_estimator(sagemaker_session):
370370
},
371371
"RoleArn": ROLE,
372372
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
373+
"DebugHookConfig": {
374+
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},
375+
"CollectionConfigurations": [],
376+
},
373377
"ProfilerConfig": {
374378
"ProfilingIntervalInMilliseconds": 500,
375379
"S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}},

tests/unit/test_estimator.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,110 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region):
726726
assert args.get("profiler_rule_configs") is None
727727

728728

729+
@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_with_debugger_config_set_up_in_unsupported_region(region):
    """An explicit debugger hook config in an unsupported region raises ValueError."""
    # Mock setup stays outside ``pytest.raises`` so an error raised here
    # cannot be mistaken for the error under test.
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )

    # NOTE(review): construction and fit() are both kept inside the block
    # because it is not visible from here which of the two performs the
    # region validation - confirm and narrow further if possible.
    with pytest.raises(ValueError) as error:
        f = DummyFramework(
            entry_point=SCRIPT_PATH,
            role=ROLE,
            sagemaker_session=sms,
            instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            debugger_hook_config=DebuggerHookConfig(s3_output_path="s3://output"),
        )
        f.fit("s3://mydata")

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support debugger but debugger hook config is set!" in str(
        error.value
    )
753+
754+
755+
@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_enable_profiling_in_unsupported_region(region):
    """``enable_default_profiling`` raises ValueError in an unsupported region."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # fit() is setup only; it runs outside ``pytest.raises`` so that a
    # ValueError from it fails the test instead of being caught.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.enable_default_profiling()

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)
779+
780+
781+
@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_update_profiling_in_unsupported_region(region):
    """``update_profiler`` raises ValueError in an unsupported region."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # fit() is setup only; it runs outside ``pytest.raises`` so that a
    # ValueError from it fails the test instead of being caught.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.update_profiler(system_monitor_interval_millis=1000)

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)
805+
806+
807+
@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
def test_framework_disable_profiling_in_unsupported_region(region):
    """``disable_profiling`` raises ValueError in an unsupported region."""
    boto_mock = Mock(name="boto_session", region_name=region)
    sms = MagicMock(
        name="sagemaker_session",
        boto_session=boto_mock,
        boto_region_name=region,
        config=None,
        local_mode=False,
        s3_client=None,
        s3_resource=None,
    )
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sms,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # fit() is setup only; it runs outside ``pytest.raises`` so that a
    # ValueError from it fails the test instead of being caught.
    f.fit("s3://mydata")

    with pytest.raises(ValueError) as error:
        f.disable_profiling()

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "Current region does not support profiler / debugger!" in str(error.value)
831+
832+
729833
def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session):
730834
with pytest.raises(RuntimeError) as error:
731835
f = DummyFramework(
@@ -2683,6 +2787,7 @@ def test_generic_to_fit_no_input(time, sagemaker_session):
26832787

26842788
args.pop("job_name")
26852789
args.pop("role")
2790+
args.pop("debugger_hook_config")
26862791

26872792
assert args == NO_INPUT_TRAIN_CALL
26882793

@@ -2707,6 +2812,7 @@ def test_generic_to_fit_no_hps(time, sagemaker_session):
27072812

27082813
args.pop("job_name")
27092814
args.pop("role")
2815+
args.pop("debugger_hook_config")
27102816

27112817
assert args == BASE_TRAIN_CALL
27122818

@@ -2733,6 +2839,7 @@ def test_generic_to_fit_with_hps(time, sagemaker_session):
27332839

27342840
args.pop("job_name")
27352841
args.pop("role")
2842+
args.pop("debugger_hook_config")
27362843

27372844
assert args == HP_TRAIN_CALL
27382845

@@ -2764,6 +2871,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session):
27642871

27652872
args.pop("job_name")
27662873
args.pop("role")
2874+
args.pop("debugger_hook_config")
27672875

27682876
assert args == EXP_TRAIN_CALL
27692877

@@ -2917,6 +3025,7 @@ def test_generic_to_deploy(time, sagemaker_session):
29173025

29183026
args.pop("job_name")
29193027
args.pop("role")
3028+
args.pop("debugger_hook_config")
29203029

29213030
assert args == HP_TRAIN_CALL
29223031

@@ -3727,7 +3836,6 @@ def test_script_mode_estimator_same_calls_as_framework(
37273836
source_dir=script_uri,
37283837
image_uri=IMAGE_URI,
37293838
model_uri=model_uri,
3730-
environment={"USE_SMDEBUG": "0"},
37313839
dependencies=[],
37323840
debugger_hook_config={},
37333841
)

tests/unit/test_fw_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,13 @@ def test_region_supports_debugger_feature_returns_true_for_supported_regions():
552552

553553
def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
    """Every region checked here must be reported as not supporting debugger."""
    regions_without_debugger = (
        "us-iso-east-1",
        "ap-southeast-3",
        "ap-southeast-4",
        "eu-south-2",
        "me-central-1",
        "ap-south-2",
        "eu-central-2",
        "us-gov-east-1",
    )
    # Same order as the original one-assert-per-region sequence, so the first
    # failing region is the same.
    for region_name in regions_without_debugger:
        assert fw_utils._region_supports_debugger(region_name) is False
556562

557563
def test_warn_if_parameter_server_with_multi_gpu(caplog):
558564
instance_type = "ml.p2.8xlarge"

0 commit comments

Comments
 (0)