Skip to content

fix: disable Debugger defaults in unsupported regions #1272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
parse_s3_url,
UploadedCode,
validate_source_dir,
_region_supports_debugger,
)
from sagemaker.job import _Job
from sagemaker.local import LocalSession
Expand Down Expand Up @@ -1674,7 +1675,9 @@ def _validate_and_set_debugger_configs(self):
"""
Set defaults for debugging
"""
if self.debugger_hook_config is None:
if self.debugger_hook_config is None and _region_supports_debugger(
self.sagemaker_session.boto_region_name
):
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
elif not self.debugger_hook_config:
self.debugger_hook_config = None
Expand Down
15 changes: 15 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
"pytorch-serving": [1, 2, 0],
}

DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]


def is_version_equal_or_higher(lowest_version, framework_version):
"""Determine whether the ``framework_version`` is equal to or higher than
Expand Down Expand Up @@ -504,3 +506,16 @@ def python_deprecation_warning(framework, latest_supported_version):
return PYTHON_2_DEPRECATION_WARNING.format(
framework=framework, latest_supported_version=latest_supported_version
)


def _region_supports_debugger(region_name):
"""Returns boolean indicating whether the region supports Amazon SageMaker Debugger.

Args:
region_name (str): Name of the region to check against.

Returns:
bool: Whether or not the region supports Amazon SageMaker Debugger.

"""
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
4 changes: 3 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,9 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = None
self.debugger_rule_configs = None
elif self.debugger_hook_config is None:
elif self.debugger_hook_config is None and fw._region_supports_debugger(
self.sagemaker_session.boto_session.region_name
):
# Set defaults for debugging.
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)

Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,3 +1035,13 @@ def test_model_code_key_prefix_with_all_none_fail():
with pytest.raises(TypeError) as error:
fw_utils.model_code_key_prefix(None, None, None)
assert "expected string" in str(error)


def test_region_supports_debugger_feature_returns_true_for_supported_regions():
assert fw_utils._region_supports_debugger("us-west-2") is True
assert fw_utils._region_supports_debugger("us-east-2") is True


def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
assert fw_utils._region_supports_debugger("us-gov-west-1") is False
assert fw_utils._region_supports_debugger("us-iso-east-1") is False