Skip to content

feature: support remote debug for sagemaker training job #4315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
container_entry_point: Optional[List[str]] = None,
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -540,6 +541,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job.
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -777,6 +780,8 @@ def __init__(

self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)

self._enable_remote_debug = enable_remote_debug

@abstractmethod
def training_image_uri(self):
"""Return the Docker image to use for training.
Expand Down Expand Up @@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
if max_wait:
init_params["max_wait"] = max_wait

if "RemoteDebugConfig" in job_details:
init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
"EnableRemoteDebug"
)
return init_params

def _get_instance_type(self):
Expand Down Expand Up @@ -2292,6 +2302,32 @@ def update_profiler(

_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)

def get_remote_debug_config(self):
    """Build the RemoteDebug configuration for this estimator.

    Returns:
        dict: ``{"EnableRemoteDebug": <value>}`` carrying the stored
        ``_enable_remote_debug`` setting, or ``None`` when that setting
        is ``None``.
    """
    setting = self._enable_remote_debug
    # An unset flag means no RemoteDebug section is produced at all.
    if setting is None:
        return None
    return {"EnableRemoteDebug": setting}

def enable_remote_debug(self):
    """Turn remote debug on for a training job.

    Delegates to ``_update_remote_debug`` with ``True``.
    """
    turn_on = True
    self._update_remote_debug(turn_on)

def disable_remote_debug(self):
    """Turn remote debug off for a training job.

    Delegates to ``_update_remote_debug`` with ``False``.
    """
    turn_on = False
    self._update_remote_debug(turn_on)

def _update_remote_debug(self, enable_remote_debug: bool):
    """Switch remote debug on or off for a training job.

    Calls ``self._ensure_latest_training_job()`` first, then issues a
    ``_TrainingJob.update`` carrying the new RemoteDebug configuration, and
    finally stores the flag on ``_enable_remote_debug``.

    Args:
        enable_remote_debug (bool): ``True`` to enable remote debug,
            ``False`` to disable it.
    """
    # NOTE(review): presumably fails when no training job exists yet -- confirm.
    self._ensure_latest_training_job()

    new_config = {"EnableRemoteDebug": enable_remote_debug}
    _TrainingJob.update(self, remote_debug_config=new_config)

    # The local flag is recorded only after the update call has returned.
    self._enable_remote_debug = enable_remote_debug

def get_app_url(
self,
app_type,
Expand Down Expand Up @@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.profiler_config:
train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

if estimator.get_remote_debug_config() is not None:
train_args["remote_debug_config"] = estimator.get_remote_debug_config()

return train_args

@classmethod
Expand Down Expand Up @@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri):

@classmethod
def update(
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
cls,
estimator,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Update a running Amazon SageMaker training job.

Expand All @@ -2562,20 +2606,31 @@ def update(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: None).
remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,

.. code:: python

remote_debug_config = {
"EnableRemoteDebug": True,
}

Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the updated training job.
"""
update_args = cls._get_update_args(
estimator, profiler_rule_configs, profiler_config, resource_config
estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
)
estimator.sagemaker_session.update_training_job(**update_args)

return estimator.latest_training_job

@classmethod
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
def _get_update_args(
cls, estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
):
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.

Args:
Expand All @@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
update_args.update(build_dict("profiler_config", profiler_config))
update_args.update(build_dict("resource_config", resource_config))
update_args.update(build_dict("remote_debug_config", remote_debug_config))

return update_args

Expand Down Expand Up @@ -2694,6 +2750,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3055,6 +3112,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job.
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3106,6 +3165,7 @@ def __init__(
container_entry_point=container_entry_point,
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
container_entry_point: Optional[List[str]] = None,
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.

Expand Down Expand Up @@ -495,6 +496,8 @@ def __init__(
a training job.
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
to Amazon S3 without compression after training finishes.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job.

Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -569,6 +572,7 @@ def _is_valid_model_id_hook():
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
)

self.model_id = estimator_init_kwargs.model_id
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def get_init_kwargs(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -183,6 +184,7 @@ def get_init_kwargs(
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
)

estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"container_arguments",
"disable_output_compression",
"enable_infra_check",
"enable_remote_debug",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1344,6 +1345,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -1401,6 +1403,7 @@ def __init__(
self.container_arguments = container_arguments
self.disable_output_compression = disable_output_compression
self.enable_infra_check = enable_infra_check
self.enable_remote_debug = enable_remote_debug


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
48 changes: 48 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def train( # noqa: C901
profiler_config=None,
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
):
"""Create an Amazon SageMaker training job.

Expand Down Expand Up @@ -858,6 +859,15 @@ def train( # noqa: C901
configurations.
profiler_config (dict): Configuration for how profiling information is emitted
with SageMaker Profiler. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,

.. code:: python

remote_debug_config = {
"EnableRemoteDebug": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -950,6 +960,7 @@ def train( # noqa: C901
enable_sagemaker_metrics=enable_sagemaker_metrics,
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
remote_debug_config=remote_debug_config,
environment=environment,
retry_strategy=retry_strategy,
)
Expand Down Expand Up @@ -992,6 +1003,7 @@ def _get_train_request( # noqa: C901
enable_sagemaker_metrics=None,
profiler_rule_configs=None,
profiler_config=None,
remote_debug_config=None,
environment=None,
retry_strategy=None,
):
Expand Down Expand Up @@ -1103,6 +1115,15 @@ def _get_train_request( # noqa: C901
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
profiler_config(dict): Configuration for how profiling information is emitted with
SageMaker Profiler. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,

.. code:: python

remote_debug_config = {
"EnableRemoteDebug": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -1206,6 +1227,9 @@ def _get_train_request( # noqa: C901
if profiler_config is not None:
train_request["ProfilerConfig"] = profiler_config

if remote_debug_config is not None:
train_request["RemoteDebugConfig"] = remote_debug_config

if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy

Expand All @@ -1217,6 +1241,7 @@ def update_training_job(
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Calls the UpdateTrainingJob API for the given job name and returns the response.

Expand All @@ -1228,6 +1253,15 @@ def update_training_job(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,

.. code:: python

remote_debug_config = {
"EnableRemoteDebug": True,
}
"""
# No injections from sagemaker_config because the UpdateTrainingJob API's resource_config
# object accepts fewer parameters than the CreateTrainingJob API, and none that the
Expand All @@ -1240,6 +1274,7 @@ def update_training_job(
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
resource_config=resource_config,
remote_debug_config=remote_debug_config,
)
LOGGER.info("Updating training job with name %s", job_name)
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
Expand All @@ -1251,6 +1286,7 @@ def _get_update_training_job_request(
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Constructs a request compatible for updating an Amazon SageMaker training job.

Expand All @@ -1262,6 +1298,15 @@ def _get_update_training_job_request(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,

.. code:: python

remote_debug_config = {
"EnableRemoteDebug": True,
}

Returns:
Dict: an update training request dict
Expand All @@ -1279,6 +1324,9 @@ def _get_update_training_job_request(
if resource_config is not None:
update_training_job_request["ResourceConfig"] = resource_config

if remote_debug_config is not None:
update_training_job_request["RemoteDebugConfig"] = remote_debug_config

return update_training_job_request

def process(
Expand Down
Loading