-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: client cache for jumpstart models #2756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: client cache for jumpstart models #2756
Conversation
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Codecov Report
@@ Coverage Diff @@
## master-jumpstart #2756 +/- ##
====================================================
+ Coverage 88.61% 88.76% +0.14%
====================================================
Files 173 179 +6
Lines 15410 15733 +323
====================================================
+ Hits 13656 13965 +309
- Misses 1754 1768 +14
Continue to review full report at Codecov.
|
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, but could you:
- add unit tests for
JumpStartModelsCache
- test various failure cases (semantic raises an exception, s3:getObject raises a client error/ regular exception, VPC error etc)
src/sagemaker/jumpstart/cache.py
Outdated
from sagemaker.utilities.cache import LRUCache | ||
import boto3 | ||
import json | ||
import semantic_version |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you please add the dependency imports at the top of the module, before the local imports?
src/sagemaker/jumpstart/cache.py
Outdated
DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20 | ||
DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6) | ||
|
||
DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please move all numeric constants to a parameters.py
and the key to a config.py
or constants.py
module?
src/sagemaker/jumpstart/cache.py
Outdated
|
||
Args: | ||
region (Optional[str]): AWS region to associate with cache. Default: region associated | ||
with botocore session. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: with boto3 session.
src/sagemaker/jumpstart/cache.py
Outdated
Args: | ||
region (Optional[str]): AWS region to associate with cache. Default: region associated | ||
with botocore session. | ||
max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "file" is confusing since this is an in-memory cache. Refer to items?
src/sagemaker/jumpstart/cache.py
Outdated
semantic version cache. Default: 20. | ||
semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold | ||
items in semantic version cache before invalidation. Default: 6 hours. | ||
bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean S3 bucket name, correct? please specify, it could be an arn for example.
src/sagemaker/utilities/cache.py
Outdated
element_age = curr_time - element.creation_time | ||
if element_age > self._expiration_time: | ||
if fail_on_old_value: | ||
raise KeyError(f"{key} is old! Created at {element.creation_time}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: f"{key} is beyond allowed time {self._expiration_time}. Created at {element.creation_time}."
src/sagemaker/utilities/cache.py
Outdated
""" | ||
self._max_cache_items = max_cache_items | ||
self._lru_cache: collections.OrderedDict = collections.OrderedDict() | ||
self._expiration_time = expiration_time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_expiration_time
is a bit misleading because it could be interpreted as an absolute time, whereas it is a time delta. I am not sure what to replace it with though, maybe horizon
?
retrieval_function=retrieval_function, | ||
) | ||
|
||
curr_time = datetime.datetime.now() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for date test, try to avoid .now()
or any relative date. Instead pin the date and timezone. Otherwise, you can get your test randomly failing one day, for example in a pipeline. This usually happens over gap years/daylight saving day or other edge cases.
|
||
my_cache._retrieval_function.reset_mock() | ||
my_cache.get(0) | ||
my_cache._retrieval_function.assert_called() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice tests!
optional: would you like to check what it was called with?
my_cache.clear() | ||
assert len(my_cache) == 0 | ||
with pytest.raises(KeyError): | ||
my_cache.get(1, False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a few more test cases:
retrieve_function
isNone
retrieve_function
does not have the right signatureretrieve_function
raises anException
retrieve_function
raises abotocore.exceptions.ClientError
retrieve_function
raises a VPC-related error (get ETIMEDOUT)
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
…r jumpstart cache
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just two remaining comments:
- revisit etag logic
- unit tests for JumpStart cache
The other comments are nit.
if value is not None and etag == value.md5_hash: | ||
return value | ||
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) | ||
formatted_body = json.loads(response["Body"].read().decode("utf-8")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Several issues here:
- if
value
isNone
, you are making an unnecessary http call - if
value
is notNone
but theetag
are different, you risk caching the wrongetag
Does this work?
if value is not None:
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
if etag == value.md5_hash:
return value
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
etag = response["ETag"]
return JumpStartCachedS3ContentValue(
formatted_file_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
While we are at it, can you unit test this behavior please?
src/sagemaker/jumpstart/cache.py
Outdated
return self._get_header_impl(model_id, 0, semantic_version_str) | ||
|
||
def _get_header_impl( | ||
self, model_id: str, attempt: int, semantic_version_str: Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: how about setting a default for attempt
at 0?
src/sagemaker/jumpstart/utils.py
Outdated
rejected. | ||
|
||
Raises: | ||
RuntimeError: If the SageMaker version is not readable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: you can also mention the exception raised by semantic_version.Version
if you pass it an invalid version, for example semantic_version.Version("..")
src/sagemaker/utilities/cache.py
Outdated
def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: | ||
"""Returns value from cache corresponding to key. | ||
|
||
If ``fail_on_old_value``, a KeyError is thrown instead of a new value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit^2: python linguo calls for a KeyError is raised
retrieval_function=retrieval_function, | ||
) | ||
|
||
curr_time = datetime.datetime.fromtimestamp(1636730651.079551) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: calling this curr_time
made me to a double-take. How about mock_curr_time
?
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
"head_object", service_error_code="EndpointConnectionError" | ||
) | ||
with pytest.raises(botocore.exceptions.ClientError): | ||
cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: we may want to create a class that extends unittest.caseCase
and break down these tests into several class methods.
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
src/sagemaker/jumpstart/constants.py
Outdated
@@ -27,3 +27,5 @@ | |||
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name | |||
|
|||
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" | |||
|
|||
PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation is a bit misleading: this isn't a constant but a state variable.
how about moving this to the jumpstart/utils.py
and prefix it with an underline so that mypy
picks it up as an invalid import if it's used elsewhere?
_sagemaker_version = ""
def get_sagemaker_version() -> str:
if _sagemaker_version == "":
_sagemaker_version = parse_sagemaker_version()
return _sagemaker_version
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
2f8c692
to
0a8395a
Compare
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
@@ -44,6 +44,7 @@ def read_version(): | |||
"packaging>=20.0", | |||
"pandas", | |||
"pathos", | |||
"semantic-version", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to add extra package? Whats the use of this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Merging this with a note:
|
Issue #, if available:
Description of changes:
This PR introduces a client-side cache for JumpStart models. The cache stores manifests and test specs associated with JumpStart models. In addition, this PR implements a low-level LRU cache class (with expiring elements). Currently, JumpStart SDK content is not launched in any region, so this code can only be used by plugging the cache to a pre-defined bucket (no default bucket).
In order to query models, the API for the cache requires a model id and semantic version. The semantic version relies on the semantic version library. This library allows us to perform version comparison and handle wildcard versions (e.g.
1.2.*
).The low-level cache class has been created in
sagemaker.utilities.cache
module. JumpStart-specific code is insagemaker.jumpstart
module.The cache can be used as follows:
Testing done:
Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_base
to create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.