-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: jumpstart default payloads #4149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: jumpstart default payloads #4149
Conversation
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
src/sagemaker/jumpstart/types.py
Outdated
self.accept = accept | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Returns json representation of JumpStartDefaultPayloads object.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: supposed to be JumpStartSerializedPayload
?
src/sagemaker/payloads.py
Outdated
logger = logging.getLogger(__name__) | ||
|
||
|
||
def retrieve_samples( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: reasoning for switching to samples
noun?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use defaults
? I thought it would be confusing to use default
in the context of returning both a single and multiple payloads
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
haven't we used retrieve_options
in other accessors?
src/sagemaker/payloads.py
Outdated
""" | ||
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): | ||
raise ValueError( | ||
"Must specify JumpStart `model_id` and `model_version` when retrieving model URIs." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: when retrieving payloads?
self.region = region | ||
self.s3_client = s3_client | ||
|
||
def get_bytes_payload_with_s3_references( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think MH has similar logic atm in the other PR, is the plan to reuse the functions in this PR post-merge?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's the idea. I was working on this PR when I realized a lot of code is duplicated. So MH will eventually call these utilities.
src/sagemaker/jumpstart/model.py
Outdated
|
||
if sample_payloads is None or len(sample_payloads) == 0: | ||
raise NotImplementedError( | ||
f"No default payload supported for model ID '{self.model_id}'." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: error vs returning None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, i'm cool with both
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are you now? That's something you should definitely have an opinion on Evan.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I mean is that both experiences can be justified from a UX perspective. I guess, to maintain consistency, we should return None
, since we don't typically use NotImplementedError
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
src/sagemaker/base_predictor.py
Outdated
@@ -201,20 +204,42 @@ def _create_request_args( | |||
custom_attributes=None, | |||
): | |||
"""Placeholder docstring""" | |||
|
|||
js_accept = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typing please
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also consider spelling out for better readability: js_
-> jumpstart_
|
||
Requests are cached so that the same s3 request is never made more | ||
than once, unless a different region or client is used. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hum, not sure about this, you are effectively caching objects in memory aren't you?
Could you:
(a) determine the max memory you are willing to use for such cache?
(b) add a head object with a size limits for such objects
(c) derive the max number of items in the @lru_cache(max_items)
from round((a)/(b)) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The maximum memory is very system dependent and the payloads will come in all different sizes. How about we expose a function that clears the cache JumpStartS3Accessor.clear_cache()
? This can call JumpStartS3Accessor.get_object_cached.cache_clear()
under the hood. See: https://stackoverflow.com/questions/37653784/how-do-i-use-cache-clear-on-python-functools-lru-cache
if s3_client is None: | ||
s3_client = JumpStartS3Accessor._get_default_s3_client(region) | ||
|
||
return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this line will buffer the whole object in memory? Is that acceptable and shouldn't you build in safeguards?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What kind of safeguards are you referring to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for posterity as you addressed above: inference file size mainly.
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains functions for obtaining JumpStart payloads.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review with @judyheflin please. Maybe obtaining example payloads for JumpStart models.
src/sagemaker/jumpstart/model.py
Outdated
def retrieve_default_payload(self) -> JumpStartSerializablePayload: | ||
"""Returns default payload associated with the model. | ||
|
||
Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is customer facing: please add a Raises
section.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if payloads.retrieve_options
can throw, please port over the appropriate item of its Raises
section.
Related question: do you need a way to propagate the tolerate_vulnerable
and tolerate_deprecated
flags?
src/sagemaker/payloads.py
Outdated
logger = logging.getLogger(__name__) | ||
|
||
|
||
def retrieve_samples( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
haven't we used retrieve_options
in other accessors?
src/sagemaker/payloads.py
Outdated
Dict[str, JumpStartSerializablePayload] | ||
] = artifacts._retrieve_default_payloads( | ||
model_id, | ||
model_version, # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: why do you need to silence the linter?
|
||
unserialized_payloads: List[JumpStartSerializablePayload] = list( | ||
unserialized_payload_dict.values() | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here aren't we buffering a list of content-embedded payloads in memory?
Let's sync up offline on this point please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By default we embed the content in the payload, however we give the option not to.
def test_default_payload_jumpstart_model(setup): | ||
|
||
# DO NOT COMMIT THIS LINE | ||
os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please don't expose this bucket name in publicly available code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't get committed, but yes I will remove
"huggingface_transformers_version": "4.17", | ||
}, | ||
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" | ||
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sanity-check: did you intend to jump a line here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was just to reduce the line length
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/bot run slow-tests
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
src/sagemaker/jumpstart/model.py
Outdated
@@ -312,6 +314,27 @@ def _is_valid_model_id_hook(): | |||
|
|||
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) | |||
|
|||
def retrieve_default_payload(self) -> JumpStartSerializablePayload: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry I missed this: this instance method needs be renamed retrieve_example_payload()
Please also add a retrieve_all_example_payloads()
instance method.
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/bot run notebook-tests, pr, unit-tests
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Issue #, if available:
Description of changes:
Adds default payload functionality for JumpStart Models. The following workflow is now possible:
The
JumpStartModel
class now exposes a functionretrieve_default_payload
to give customers a starting point for interacting with the model.Standalone utilities has also been created in
sagemaker.payloads
module:retrieve_all_examples
retrieve_example
Testing done:
Integration tests run locally against staged metadata.
Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_base
to create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.