Skip to content

feat: jumpstart default payloads #4149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Oct 10, 2023

Conversation

evakravi
Copy link
Member

@evakravi evakravi commented Oct 2, 2023

Issue #, if available:

Description of changes:
Adds default payload functionality for JumpStart Models. The following workflow is now possible:

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id=<model_id>)
default_payload = model.retrieve_default_payload()
predictor = model.deploy()
response = predictor.predict(default_payload)

The JumpStartModel class now exposes a function retrieve_default_payload to give customers a starting point for interacting with the model.

Standalone utilities has also been created in sagemaker.payloads module:

  • retrieve_all_examples
  • retrieve_example

Testing done:
Integration tests run locally against staged metadata.

Merge Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.

General

  • I have read the CONTRIBUTING doc
  • I certify that the changes I am introducing will be backward compatible, and I have discussed concerns about this, if any, with the Python SDK team
  • I used the commit message format described in CONTRIBUTING
  • I have passed the region in to all S3 and STS clients that I've initialized as part of this change.
  • I have updated any necessary documentation, including READMEs and API docs (if appropriate)

Tests

  • I have added tests that prove my fix is effective or that my feature works (if appropriate)
  • I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes
  • I have checked that my tests are not configured for a specific region or account (if appropriate)
  • I have used unique_name_from_base to create resource names in integ tests (if appropriate)

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

@evakravi evakravi requested a review from a team as a code owner October 2, 2023 18:40
@evakravi evakravi requested review from jgoyani1 and removed request for a team October 2, 2023 18:40
@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: cba9034
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@evakravi evakravi marked this pull request as draft October 2, 2023 18:43
@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: 6082e59
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: bf13065
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: 5dc4aac
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

self.accept = accept

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartDefaultPayloads object."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: supposed to be JumpStartSerializedPayload?

logger = logging.getLogger(__name__)


def retrieve_samples(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: reasoning for switching to samples noun?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use defaults? I thought it would be confusing to use default in the context of returning both a single and multiple payloads

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haven't we used retrieve_options in other accessors?

"""
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError(
"Must specify JumpStart `model_id` and `model_version` when retrieving model URIs."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: when retrieving payloads?

self.region = region
self.s3_client = s3_client

def get_bytes_payload_with_s3_references(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think MH has similar logic atm in the other PR, is the plan to reuse the functions in this PR post-merge?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's the idea. I was working on this PR when I realized a lot of code is duplicated. So MH will eventually call these utilities.


if sample_payloads is None or len(sample_payloads) == 0:
raise NotImplementedError(
f"No default payload supported for model ID '{self.model_id}'."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: error vs returning None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, i'm cool with both

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you now? That's something you should definitely have an opinion on Evan.

Copy link
Member Author

@evakravi evakravi Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that both experiences can be justified from a UX perspective. I guess, to maintain consistency, we should return None, since we don't typically use NotImplementedError

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-local-mode-tests
  • Commit ID: ca20d87
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-notebook-tests
  • Commit ID: ca20d87
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: ca20d87
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: ca20d87
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-slow-tests
  • Commit ID: ca20d87
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-local-mode-tests
  • Commit ID: b2a6374
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-notebook-tests
  • Commit ID: b2a6374
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-slow-tests
  • Commit ID: b2a6374
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: b2a6374
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: b2a6374
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@@ -201,20 +204,42 @@ def _create_request_args(
custom_attributes=None,
):
"""Placeholder docstring"""

js_accept = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typing please

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also consider spelling out for better readability: js_ -> jumpstart_


Requests are cached so that the same s3 request is never made more
than once, unless a different region or client is used.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hum, not sure about this, you are effectively caching objects in memory aren't you?

Could you:
(a) determine the max memory you are willing to use for such cache?
(b) add a head object with a size limits for such objects
(c) derive the max number of items in the @lru_cache(max_items) from round((a)/(b)) ?

Copy link
Member Author

@evakravi evakravi Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The maximum memory is very system dependent and the payloads will come in all different sizes. How about we expose a function that clears the cache JumpStartS3Accessor.clear_cache()? This can call JumpStartS3Accessor.get_object_cached.cache_clear() under the hood. See: https://stackoverflow.com/questions/37653784/how-do-i-use-cache-clear-on-python-functools-lru-cache

if s3_client is None:
s3_client = JumpStartS3Accessor._get_default_s3_client(region)

return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line will buffer the whole object in memory? Is that acceptable and shouldn't you build in safeguards?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What kind of safeguards are you referring to?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for posterity as you addressed above: inference file size mainly.

# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functions for obtaining JumpStart payloads."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review with @judyheflin please. Maybe obtaining example payloads for JumpStart models.

def retrieve_default_payload(self) -> JumpStartSerializablePayload:
"""Returns default payload associated with the model.

Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is customer facing: please add a Raises section.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if payloads.retrieve_options can throw, please port over the appropriate item of its Raises section.

Related question: do you need a way to propagate the tolerate_vulnerable and tolerate_deprecated flags?

logger = logging.getLogger(__name__)


def retrieve_samples(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haven't we used retrieve_options in other accessors?

Dict[str, JumpStartSerializablePayload]
] = artifacts._retrieve_default_payloads(
model_id,
model_version, # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: why do you need to silence the linter?


unserialized_payloads: List[JumpStartSerializablePayload] = list(
unserialized_payload_dict.values()
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here aren't we buffering a list of content-embedded payloads in memory?

Let's sync up offline on this point please.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default we embed the content in the payload, however we give the option not to.

def test_default_payload_jumpstart_model(setup):

# DO NOT COMMIT THIS LINE
os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't expose this bucket name in publicly available code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't get committed, but yes I will remove

"huggingface_transformers_version": "4.17",
},
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sanity-check: did you intend to jump a line here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was just to reduce the line length

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-local-mode-tests
  • Commit ID: 0a61992
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

Copy link
Member Author

@evakravi evakravi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/bot run slow-tests

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: c869f30
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: c869f30
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-slow-tests
  • Commit ID: c869f30
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@@ -312,6 +314,27 @@ def _is_valid_model_id_hook():

super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())

def retrieve_default_payload(self) -> JumpStartSerializablePayload:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I missed this: this instance method needs be renamed retrieve_example_payload()

Please also add a retrieve_all_example_payloads() instance method.

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-local-mode-tests
  • Commit ID: 2aecf36
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-notebook-tests
  • Commit ID: 2aecf36
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-slow-tests
  • Commit ID: 2aecf36
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: 2aecf36
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: 2aecf36
  • Result: FAILED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

Copy link
Member Author

@evakravi evakravi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/bot run notebook-tests, pr, unit-tests

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-notebook-tests
  • Commit ID: 2aecf36
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: 2aecf36
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: 2aecf36
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@evakravi evakravi dismissed JGuinegagne’s stale review October 10, 2023 15:02

problems addressed

@evakravi evakravi added the do-not-merge Use when PR needs to be marked as do not merge label Oct 10, 2023
@evakravi evakravi marked this pull request as draft October 10, 2023 16:28
@evakravi evakravi marked this pull request as ready for review October 10, 2023 16:56
@evakravi evakravi removed the do-not-merge Use when PR needs to be marked as do not merge label Oct 10, 2023
@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-local-mode-tests
  • Commit ID: e5a7058
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-notebook-tests
  • Commit ID: e5a7058
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-pr
  • Commit ID: e5a7058
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-slow-tests
  • Commit ID: e5a7058
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@sagemaker-bot
Copy link
Collaborator

AWS CodeBuild CI Report

  • CodeBuild project: sagemaker-python-sdk-unit-tests
  • Commit ID: e5a7058
  • Result: SUCCEEDED
  • Build Logs (available for 30 days)

Powered by github-codebuild-logs, available on the AWS Serverless Application Repository

@jgoyani1 jgoyani1 merged commit 7f6f3f9 into aws:master Oct 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants