-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: custom base job name for jumpstart models/estimators #2970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f166b60
086258d
a39b750
28fd737
20df3d7
b9f90dc
c639b19
3542bb9
747e234
5677dcb
6db3774
d610bfb
169dffd
b47b1d5
4325fcd
95e2ca9
ea9490b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
version: 2 | ||
|
||
python: | ||
version: 3.6 | ||
version: 3.9 | ||
install: | ||
- method: pip | ||
path: . | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -232,6 +232,22 @@ def add_single_jumpstart_tag( | |
return curr_tags | ||
|
||
|
||
def get_jumpstart_base_name_if_jumpstart_model( | ||
*uris: Optional[str], | ||
) -> Optional[str]: | ||
"""Return default JumpStart base name if a URI belongs to JumpStart. | ||
|
||
If no URIs belong to JumpStart, return None. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please address Mufis' comment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did ( |
||
|
||
Args: | ||
*uris (Optional[str]): URI to test for association with JumpStart. | ||
""" | ||
for uri in uris: | ||
if is_jumpstart_model_uri(uri): | ||
return constants.JUMPSTART_RESOURCE_BASE_NAME | ||
return None | ||
|
||
|
||
def add_jumpstart_tags( | ||
tags: Optional[List[Dict[str, str]]] = None, | ||
inference_model_uri: Optional[str] = None, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
from sagemaker.predictor import PredictorBase | ||
from sagemaker.serverless import ServerlessInferenceConfig | ||
from sagemaker.transformer import Transformer | ||
from sagemaker.jumpstart.utils import add_jumpstart_tags | ||
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model | ||
from sagemaker.utils import unique_name_from_base | ||
from sagemaker.async_inference import AsyncInferenceConfig | ||
from sagemaker.predictor_async import AsyncPredictor | ||
|
@@ -466,7 +466,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: | |
) | ||
|
||
def _script_mode_env_vars(self): | ||
"""Placeholder docstring""" | ||
"""Returns a mapping of environment variables for script mode execution""" | ||
script_name = None | ||
dir_name = None | ||
if self.uploaded_code: | ||
|
@@ -478,8 +478,11 @@ def _script_mode_env_vars(self): | |
elif self.entry_point is not None: | ||
script_name = self.entry_point | ||
if self.source_dir is not None: | ||
dir_name = "file://" + self.source_dir | ||
|
||
dir_name = ( | ||
self.source_dir | ||
if self.source_dir.startswith("s3://") | ||
else "file://" + self.source_dir | ||
) | ||
return { | ||
SCRIPT_PARAM_NAME.upper(): script_name or str(), | ||
DIR_PARAM_NAME.upper(): dir_name or str(), | ||
|
@@ -514,7 +517,9 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag | |
""" | ||
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type) | ||
|
||
self._ensure_base_name_if_needed(container_def["Image"]) | ||
self._ensure_base_name_if_needed( | ||
image_uri=container_def["Image"], script_uri=self.source_dir, model_uri=self.model_data | ||
) | ||
self._set_model_name_if_needed() | ||
|
||
enable_network_isolation = self.enable_network_isolation() | ||
|
@@ -529,10 +534,17 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag | |
tags=tags, | ||
) | ||
|
||
def _ensure_base_name_if_needed(self, image_uri): | ||
"""Create a base name from the image URI if there is no model name provided.""" | ||
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri): | ||
"""Create a base name from the image URI if there is no model name provided. | ||
|
||
If a JumpStart script or model uri is used, select the JumpStart base name. | ||
""" | ||
if self.name is None: | ||
self._base_name = self._base_name or utils.base_name_from_image(image_uri) | ||
self._base_name = ( | ||
self._base_name | ||
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri) | ||
or utils.base_name_from_image(image_uri) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: I would prefer to define a |
||
|
||
def _set_model_name_if_needed(self): | ||
"""Generate a new model name if ``self._base_name`` is present.""" | ||
|
@@ -963,7 +975,9 @@ def deploy( | |
|
||
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1]) | ||
if self._is_compiled_model and not is_serverless: | ||
self._ensure_base_name_if_needed(self.image_uri) | ||
self._ensure_base_name_if_needed( | ||
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data | ||
) | ||
if self._base_name is not None: | ||
self._base_name = "-".join((self._base_name, compiled_model_suffix)) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.