Skip to content

feat: Support custom repack model settings #4328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def __init__(

# the real estimator and inputs
repacker = SKLearn(
framework_version=FRAMEWORK_VERSION,
instance_type=INSTANCE_TYPE,
framework_version=kwargs.pop("framework_version", None) or FRAMEWORK_VERSION,
instance_type=kwargs.pop("instance_type", None) or INSTANCE_TYPE,
entry_point=REPACK_SCRIPT_LAUNCHER,
source_dir=self._source_dir,
dependencies=self._dependencies,
Expand Down
84 changes: 74 additions & 10 deletions src/sagemaker/workflow/model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
_REGISTER_MODEL_NAME_BASE = "RegisterModel"
_CREATE_MODEL_NAME_BASE = "CreateModel"
_REPACK_MODEL_NAME_BASE = "RepackModel"
_IGNORED_REPACK_PARAM_LIST = ["entry_point", "source_dir", "hyperparameters", "dependencies"]

logger = logging.getLogger(__name__)


class ModelStep(StepCollection):
Expand All @@ -42,6 +45,7 @@ def __init__(
retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
repack_model_step_settings: Optional[Dict[str, any]] = None,
):
"""Constructs a `ModelStep`.

Expand Down Expand Up @@ -115,6 +119,15 @@ def __init__(
display_name (str): The display name of the `ModelStep`.
The display name provides better UI readability. (default: None).
description (str): The description of the `ModelStep` (default: None).
repack_model_step_settings (Dict[str, any]): The kwargs passed to the _RepackModelStep
to customize the configuration of the underlying repack model job (default: None).
Notes:
1. If the _RepackModelStep is unnecessary, the settings will be ignored.
2. If the _RepackModelStep is added, the repack_model_step_settings
is honored if set.
3. In repack_model_step_settings, the arguments with misspelled keys will be
ignored. Please refer to the expected parameters of repack model job in
:class:`~sagemaker.sklearn.estimator.SKLearn` and its base classes.
"""
from sagemaker.workflow.utilities import validate_step_args_input

Expand Down Expand Up @@ -148,6 +161,9 @@ def __init__(
self.display_name = display_name
self.description = description
self.steps: List[Step] = []
self._repack_model_step_settings = (
dict(repack_model_step_settings) if repack_model_step_settings else {}
)
self._model = step_args.model
self._create_model_args = self.step_args.create_model_request
self._register_model_args = self.step_args.create_model_package_request
Expand All @@ -157,6 +173,12 @@ def __init__(

if self._need_runtime_repack:
self._append_repack_model_step()
elif self._repack_model_step_settings:
logger.warning(
"Non-empty repack_model_step_settings is supplied but no repack model "
"step is needed. Ignoring the repack_model_step_settings."
)

if self._register_model_args:
self._append_register_model_step()
else:
Expand Down Expand Up @@ -235,14 +257,12 @@ def _append_repack_model_step(self):
elif isinstance(self._model, Model):
model_list = [self._model]
else:
logging.warning("No models to repack")
logger.warning("No models to repack")
return

security_group_ids = None
subnets = None
if self._model.vpc_config:
security_group_ids = self._model.vpc_config.get("SecurityGroupIds", None)
subnets = self._model.vpc_config.get("Subnets", None)
self._pop_out_non_configurable_repack_model_step_args()

security_group_ids, subnets = self._resolve_repack_model_step_vpc_configs()

for i, model in enumerate(model_list):
runtime_repack_flg = (
Expand All @@ -252,8 +272,16 @@ def _append_repack_model_step(self):
name_base = model.name or i
repack_model_step = _RepackModelStep(
name="{}-{}-{}".format(self.name, _REPACK_MODEL_NAME_BASE, name_base),
sagemaker_session=self._model.sagemaker_session or model.sagemaker_session,
role=self._model.role or model.role,
sagemaker_session=(
self._repack_model_step_settings.pop("sagemaker_session", None)
or self._model.sagemaker_session
or model.sagemaker_session
),
role=(
self._repack_model_step_settings.pop("role", None)
or self._model.role
or model.role
),
model_data=model.model_data,
entry_point=model.entry_point,
source_dir=model.source_dir,
Expand All @@ -266,8 +294,15 @@ def _append_repack_model_step(self):
),
depends_on=self.depends_on,
retry_policies=self._repack_model_retry_policies,
output_path=self._runtime_repack_output_prefix,
output_kms_key=model.model_kms_key,
output_path=(
self._repack_model_step_settings.pop("output_path", None)
or self._runtime_repack_output_prefix
),
output_kms_key=(
self._repack_model_step_settings.pop("output_kms_key", None)
or model.model_kms_key
),
**self._repack_model_step_settings
)
self.steps.append(repack_model_step)

Expand All @@ -282,3 +317,32 @@ def _append_repack_model_step(self):
"InferenceSpecification"
]["Containers"][i]
container["ModelDataUrl"] = repacked_model_data

def _pop_out_non_configurable_repack_model_step_args(self):
"""Pop out non-configurable args from _repack_model_step_settings"""
if not self._repack_model_step_settings:
return
for ignored_param in _IGNORED_REPACK_PARAM_LIST:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is directly configurable from the user, why would they include params in their step settings that they know would get ignored ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input interface to user is a dict so they can input anything they'd like. However, there are some keys/fields they should not touch, otherwise the repack model stack won't work as expected. For these field, we add this check as an early validation

if self._repack_model_step_settings.pop(ignored_param, None):
logger.warning(
"The repack model step parameter - %s is not configurable. Ignoring it.",
ignored_param,
)

def _resolve_repack_model_step_vpc_configs(self):
"""Resolve vpc configs for repack model step"""
# Note: the EstimatorBase constructor ensures that:
# "When setting up custom VPC, both subnets and security_group_ids must be set"
if self._repack_model_step_settings.get(
"security_group_ids", None
) or self._repack_model_step_settings.get("subnets", None):
security_group_ids = self._repack_model_step_settings.pop("security_group_ids", None)
subnets = self._repack_model_step_settings.pop("subnets", None)
return security_group_ids, subnets

if self._model.vpc_config:
security_group_ids = self._model.vpc_config.get("SecurityGroupIds", None)
subnets = self._model.vpc_config.get("Subnets", None)
return security_group_ids, subnets

return None, None
Loading