Skip to content

Commit 4955d22

Browse files
authored
feat: SDK Defaults - DebugHookConfig defaults in TrainingJob API (#3947)
1 parent bc948e5 commit 4955d22

File tree

8 files changed

+264
-2
lines changed

8 files changed

+264
-2
lines changed

src/sagemaker/config/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sagemaker.config.config_schema import ( # noqa: F401
1818
KEY,
1919
TRAINING_JOB,
20+
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
2021
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
2122
TRAINING_JOB_ROLE_ARN_PATH,
2223
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
@@ -158,4 +159,6 @@
158159
CONTAINERS,
159160
PRIMARY_CONTAINER,
160161
INFERENCE_SPECIFICATION,
162+
ESTIMATOR,
163+
DEBUG_HOOK_CONFIG,
161164
)

src/sagemaker/config/config_schema.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@
100100
INFERENCE_SPECIFICATION = "InferenceSpecification"
101101
PROFILER_CONFIG = "ProfilerConfig"
102102
DISABLE_PROFILER = "DisableProfiler"
103+
ESTIMATOR = "Estimator"
104+
DEBUG_HOOK_CONFIG = "DebugHookConfig"
103105

104106

105107
def _simple_path(*args: str):
@@ -338,6 +340,9 @@ def _simple_path(*args: str):
338340
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH = _simple_path(
339341
SAGEMAKER, PYTHON_SDK, MODULES, SESSION, DEFAULT_S3_OBJECT_KEY_PREFIX
340342
)
343+
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH = _simple_path(
344+
SAGEMAKER, PYTHON_SDK, MODULES, ESTIMATOR, DEBUG_HOOK_CONFIG
345+
)
341346

342347

343348
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
@@ -645,6 +650,27 @@ def _simple_path(*args: str):
645650
},
646651
},
647652
},
653+
ESTIMATOR: {
654+
TYPE: OBJECT,
655+
ADDITIONAL_PROPERTIES: False,
656+
PROPERTIES: {
657+
DEBUG_HOOK_CONFIG: {
658+
TYPE: "boolean",
659+
"description": (
660+
"Sets a boolean for `debugger_hook_config` of "
661+
"Estimator which will be then used for training job "
662+
"API call. Today, the config_schema doesn't support "
663+
"a dictionary as a valid value to be provided. "
664+
"In the future to add support for DebugHookConfig "
665+
"as a dictionary, schema should be added under "
666+
"the config path `SageMaker.TrainingJob` instead of "
667+
"here, since the TrainingJob API supports "
668+
"DebugHookConfig as a dictionary, we can add "
669+
"a schema for it at API level."
670+
),
671+
},
672+
},
673+
},
648674
REMOTE_FUNCTION: {
649675
TYPE: OBJECT,
650676
ADDITIONAL_PROPERTIES: False,
@@ -990,6 +1016,11 @@ def _simple_path(*args: str):
9901016
},
9911017
# Training Job
9921018
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html
1019+
# Please note that we currently support 'DebugHookConfig' as a boolean value
1020+
# which can be provided under [SageMaker.PythonSDK.Modules.Estimator] config path.
1021+
# As of today, config_schema does not support a dict as a valid value to be
1022+
# provided. If we decide to support it in the future, we can add a new schema
1023+
# for it under [SageMaker.TrainingJob] config path.
9931024
TRAINING_JOB: {
9941025
TYPE: OBJECT,
9951026
ADDITIONAL_PROPERTIES: False,

src/sagemaker/estimator.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sagemaker import git_utils, image_uris, vpc_utils, s3
3131
from sagemaker.analytics import TrainingJobAnalytics
3232
from sagemaker.config import (
33+
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
3334
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
3435
TRAINING_JOB_SECURITY_GROUP_IDS_PATH,
3536
TRAINING_JOB_SUBNETS_PATH,
@@ -675,7 +676,26 @@ def __init__(
675676
self.checkpoint_local_path = checkpoint_local_path
676677

677678
self.rules = rules
678-
self.debugger_hook_config = debugger_hook_config
679+
680+
# Today, we ONLY support debugger_hook_config to be provided as a boolean value
681+
# from sagemaker_config. The value for this parameter is resolved in this order:
682+
# 1. value from direct_input which can be a boolean or a dictionary
683+
# 2. value from sagemaker_config which can be a boolean
684+
# In the future, if we support debugger_hook_config to be provided as a dictionary
685+
# from sagemaker_config [SageMaker.TrainingJob] then we will need to update the
686+
# logic below to resolve the values as per the type of value received from
687+
# direct_input and sagemaker_config
688+
self.debugger_hook_config = resolve_value_from_config(
689+
direct_input=debugger_hook_config,
690+
config_path=ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
691+
sagemaker_session=sagemaker_session,
692+
)
693+
# If customer passes True from either direct_input or sagemaker_config, we will
694+
# create a default hook config as an empty dict which will later be populated
695+
# with default s3_output_path from _prepare_debugger_for_training function
696+
if self.debugger_hook_config is True:
697+
self.debugger_hook_config = {}
698+
679699
self.tensorboard_output_config = tensorboard_output_config
680700

681701
self.debugger_rule_configs = None

tests/data/config/config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ SageMaker:
55
Session:
66
DefaultS3Bucket: 'sagemaker-python-sdk-test-bucket'
77
DefaultS3ObjectKeyPrefix: 'test-prefix'
8+
Estimator:
9+
DebugHookConfig: false
810
RemoteFunction:
911
Dependencies: "./requirements.txt"
1012
EnvironmentVariables:

tests/unit/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
CONTAINERS,
8686
PRIMARY_CONTAINER,
8787
INFERENCE_SPECIFICATION,
88+
ESTIMATOR,
89+
DEBUG_HOOK_CONFIG,
8890
)
8991

9092
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -323,6 +325,13 @@
323325
SAGEMAKER_CONFIG_TRAINING_JOB = {
324326
SCHEMA_VERSION: "1.0",
325327
SAGEMAKER: {
328+
PYTHON_SDK: {
329+
MODULES: {
330+
ESTIMATOR: {
331+
DEBUG_HOOK_CONFIG: False,
332+
},
333+
},
334+
},
326335
TRAINING_JOB: {
327336
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True,
328337
ENABLE_NETWORK_ISOLATION: True,
@@ -337,6 +346,32 @@
337346
},
338347
}
339348

349+
SAGEMAKER_CONFIG_TRAINING_JOB_WITH_DEBUG_HOOK_CONFIG_AS_FALSE = {
350+
SCHEMA_VERSION: "1.0",
351+
SAGEMAKER: {
352+
PYTHON_SDK: {
353+
MODULES: {
354+
ESTIMATOR: {
355+
DEBUG_HOOK_CONFIG: False,
356+
},
357+
},
358+
},
359+
},
360+
}
361+
362+
SAGEMAKER_CONFIG_TRAINING_JOB_WITH_DEBUG_HOOK_CONFIG_AS_TRUE = {
363+
SCHEMA_VERSION: "1.0",
364+
SAGEMAKER: {
365+
PYTHON_SDK: {
366+
MODULES: {
367+
ESTIMATOR: {
368+
DEBUG_HOOK_CONFIG: True,
369+
},
370+
},
371+
},
372+
},
373+
}
374+
340375
SAGEMAKER_CONFIG_TRANSFORM_JOB = {
341376
SCHEMA_VERSION: "1.0",
342377
SAGEMAKER: {

tests/unit/sagemaker/config/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def valid_session_config():
4545
}
4646

4747

48+
@pytest.fixture()
49+
def valid_estimator_config():
50+
return {
51+
"DebugHookConfig": False,
52+
}
53+
54+
4855
@pytest.fixture()
4956
def valid_environment_config():
5057
return {
@@ -251,10 +258,12 @@ def valid_config_with_all_the_scopes(
251258
valid_training_job_config,
252259
valid_edge_packaging_config,
253260
valid_remote_function_config,
261+
valid_estimator_config,
254262
):
255263
return {
256264
"PythonSDK": {
257265
"Modules": {
266+
"Estimator": valid_estimator_config,
258267
"RemoteFunction": valid_remote_function_config,
259268
"Session": valid_session_config,
260269
}

tests/unit/sagemaker/config/test_config_schema.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,25 @@ def test_valid_remote_function_schema(base_config_with_schema, valid_remote_func
102102
)
103103

104104

105+
def test_valid_estimator_schema(base_config_with_schema, valid_estimator_config):
106+
_validate_config(
107+
base_config_with_schema,
108+
{"PythonSDK": {"Modules": {"Estimator": valid_estimator_config}}},
109+
)
110+
111+
112+
def test_invalid_estimator_schema(base_config_with_schema, valid_estimator_config):
113+
invalid_estimator_config = {
114+
"DebugHookConfig": {
115+
"S3OutputPath": "s3://somepath",
116+
}
117+
}
118+
config = base_config_with_schema
119+
config["SageMaker"] = {"PythonSDK": {"Modules": {"Estimator": invalid_estimator_config}}}
120+
with pytest.raises(exceptions.ValidationError):
121+
validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
122+
123+
105124
def test_tags_with_invalid_schema(base_config_with_schema, valid_edge_packaging_config):
106125
edge_packaging_config = valid_edge_packaging_config.copy()
107126
edge_packaging_config["Tags"] = [{"Key": "somekey"}]

0 commit comments

Comments
 (0)