-
Notifications
You must be signed in to change notification settings - Fork 1.2k
tests: Implement integration tests covering JumpStart PrivateHub workflows #4883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8910f50
bfeb2c0
705ceb9
386d836
fa7e47c
7c50ee8
cb5f1c7
52991a0
4bd94ec
3fed9f4
6456883
bac00dd
556d120
8ff04d3
d79b8a3
9e29524
e0b8467
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,24 +16,43 @@ | |
import boto3 | ||
import pytest | ||
from botocore.config import Config | ||
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME | ||
from sagemaker.jumpstart.hub.hub import Hub | ||
from sagemaker.session import Session | ||
from tests.integ.sagemaker.jumpstart.constants import ( | ||
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, | ||
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, | ||
HUB_NAME_PREFIX, | ||
JUMPSTART_TAG, | ||
) | ||
|
||
from sagemaker.jumpstart.types import ( | ||
HubContentType, | ||
) | ||
|
||
|
||
from tests.integ.sagemaker.jumpstart.utils import ( | ||
get_test_artifact_bucket, | ||
get_test_suite_id, | ||
get_sm_session, | ||
with_exponential_backoff, | ||
) | ||
|
||
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME | ||
|
||
|
||
def _setup(): | ||
print("Setting up...") | ||
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()}) | ||
test_suite_id = get_test_suite_id() | ||
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}" | ||
test_hub_description = "PySDK Integ Test Private Hub" | ||
|
||
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id}) | ||
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name}) | ||
|
||
# Create a private hub to use for the test session | ||
hub = Hub( | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() | ||
) | ||
hub.create(description=test_hub_description) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we necessarily create a Hub every time a JS integ test is run? Does this bring any problems? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think of any problems tbh with this strategy, we're cleaning it up in the end. do you think we should be approaching it differently? |
||
|
||
|
||
def _teardown(): | ||
|
@@ -43,6 +62,8 @@ def _teardown(): | |
|
||
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] | ||
|
||
test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] | ||
|
||
boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) | ||
|
||
sagemaker_client = boto3_session.client( | ||
|
@@ -113,6 +134,29 @@ def _teardown(): | |
bucket = s3_resource.Bucket(test_cache_bucket) | ||
bucket.objects.filter(Prefix=test_suite_id + "/").delete() | ||
|
||
# delete private hubs | ||
_delete_hubs(sagemaker_session, test_hub_name) | ||
|
||
|
||
def _delete_hubs(sagemaker_session, hub_name): | ||
# list and delete all hub contents first | ||
list_hub_content_response = sagemaker_session.list_hub_contents( | ||
hub_name=hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value | ||
) | ||
for model in list_hub_content_response["HubContentSummaries"]: | ||
_delete_hub_contents(sagemaker_session, hub_name, model) | ||
|
||
sagemaker_session.delete_hub(hub_name) | ||
|
||
|
||
@with_exponential_backoff() | ||
def _delete_hub_contents(sagemaker_session, hub_name, model): | ||
sagemaker_session.delete_hub_content_reference( | ||
hub_name=hub_name, | ||
hub_content_type=HubContentType.MODEL_REFERENCE.value, | ||
hub_content_name=model["HubContentName"], | ||
) | ||
|
||
|
||
@pytest.fixture(scope="session", autouse=True) | ||
def setup(request): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
import time | ||
|
||
import pytest | ||
from sagemaker.enums import EndpointType | ||
from sagemaker.jumpstart.hub.hub import Hub | ||
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs | ||
from sagemaker.predictor import retrieve_default | ||
|
||
import tests.integ | ||
|
||
from sagemaker.jumpstart.model import JumpStartModel | ||
from tests.integ.sagemaker.jumpstart.constants import ( | ||
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, | ||
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, | ||
JUMPSTART_TAG, | ||
) | ||
from tests.integ.sagemaker.jumpstart.utils import ( | ||
get_public_hub_model_arn, | ||
get_sm_session, | ||
with_exponential_backoff, | ||
) | ||
|
||
MAX_INIT_TIME_SECONDS = 5 | ||
|
||
TEST_MODEL_IDS = { | ||
"catboost-classification-model", | ||
"huggingface-txt2img-conflictx-complex-lineart", | ||
"meta-textgeneration-llama-2-7b", | ||
"meta-textgeneration-llama-3-2-1b", | ||
"catboost-regression-model", | ||
} | ||
Comment on lines
+40
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the integ test runs in PDX. Double check these are available in PDX region. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these should be, I have chosen these models from the existing Jumpstart hub integ tests |
||
|
||
|
||
@with_exponential_backoff() | ||
def create_model_reference(hub_instance, model_arn): | ||
hub_instance.create_model_reference(model_arn=model_arn) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def add_model_references(): | ||
# Create Model References to test in Hub | ||
hub_instance = Hub( | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() | ||
) | ||
for model in TEST_MODEL_IDS: | ||
model_arn = get_public_hub_model_arn(hub_instance, model) | ||
create_model_reference(hub_instance, model_arn) | ||
|
||
|
||
def test_jumpstart_hub_model(setup, add_model_references): | ||
|
||
model_id = "catboost-classification-model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we make |
||
|
||
sagemaker_session = get_sm_session() | ||
|
||
model = JumpStartModel( | ||
model_id=model_id, | ||
role=sagemaker_session.get_caller_identity_arn(), | ||
sagemaker_session=sagemaker_session, | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we test via the HubArn path as well, since we noticed an issue around that once. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry didn't get it, this is the HubArn path right? we ask customers to provide hub_name in the jumpstart model parameter, but we convert it into HubArn right after it. Sure customers can provide arn directly to model class but in that case we just leave it as it is and that gets passed to the rest of the code. |
||
) | ||
|
||
predictor = model.deploy( | ||
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], | ||
) | ||
|
||
assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) | ||
|
||
|
||
def test_jumpstart_hub_gated_model(setup, add_model_references): | ||
|
||
model_id = "meta-textgeneration-llama-3-2-1b" | ||
|
||
model = JumpStartModel( | ||
model_id=model_id, | ||
role=get_sm_session().get_caller_identity_arn(), | ||
sagemaker_session=get_sm_session(), | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], | ||
) | ||
|
||
predictor = model.deploy( | ||
accept_eula=True, | ||
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], | ||
) | ||
|
||
payload = model.retrieve_example_payload() | ||
|
||
response = predictor.predict(payload) | ||
|
||
assert response is not None | ||
|
||
|
||
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): | ||
|
||
model_id = "meta-textgeneration-llama-2-7b" | ||
|
||
hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] | ||
|
||
region = tests.integ.test_region() | ||
|
||
sagemaker_session = get_sm_session() | ||
|
||
hub_arn = generate_hub_arn_for_init_kwargs( | ||
hub_name=hub_name, region=region, session=sagemaker_session | ||
) | ||
|
||
model = JumpStartModel( | ||
model_id=model_id, | ||
role=get_sm_session().get_caller_identity_arn(), | ||
sagemaker_session=sagemaker_session, | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], | ||
) | ||
|
||
model.deploy( | ||
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], | ||
accept_eula=True, | ||
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, | ||
) | ||
|
||
predictor = retrieve_default( | ||
endpoint_name=model.endpoint_name, | ||
sagemaker_session=sagemaker_session, | ||
tolerate_vulnerable_model=True, | ||
hub_arn=hub_arn, | ||
) | ||
|
||
payload = model.retrieve_example_payload() | ||
|
||
response = predictor.predict(payload) | ||
|
||
assert response is not None | ||
|
||
model = JumpStartModel.attach( | ||
predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name | ||
) | ||
assert model.model_id == model_id | ||
assert model.endpoint_name == predictor.endpoint_name | ||
assert model.inference_component_name == predictor.component_name | ||
|
||
|
||
def test_instantiating_model(setup, add_model_references): | ||
|
||
model_id = "catboost-regression-model" | ||
|
||
start_time = time.perf_counter() | ||
|
||
JumpStartModel( | ||
model_id=model_id, | ||
role=get_sm_session().get_caller_identity_arn(), | ||
sagemaker_session=get_sm_session(), | ||
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], | ||
) | ||
|
||
elapsed_time = time.perf_counter() - start_time | ||
|
||
assert elapsed_time <= MAX_INIT_TIME_SECONDS |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
import pytest | ||
from sagemaker.jumpstart.hub.hub import Hub | ||
|
||
from tests.integ.sagemaker.jumpstart.utils import ( | ||
get_sm_session, | ||
) | ||
from tests.integ.sagemaker.jumpstart.utils import ( | ||
get_test_suite_id, | ||
) | ||
from tests.integ.sagemaker.jumpstart.constants import ( | ||
HUB_NAME_PREFIX, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def hub_instance(): | ||
HUB_NAME = f"{HUB_NAME_PREFIX}-{get_test_suite_id()}" | ||
hub = Hub(HUB_NAME, sagemaker_session=get_sm_session()) | ||
yield hub | ||
|
||
|
||
def test_private_hub(setup, hub_instance): | ||
# Createhub | ||
create_hub_response = hub_instance.create( | ||
description="This is a Test Private Hub.", | ||
display_name="PySDK integration tests Hub", | ||
search_keywords=["jumpstart-sdk-integ-test"], | ||
) | ||
|
||
# Create Hub Verifications | ||
assert create_hub_response is not None | ||
|
||
# Describe Hub | ||
hub_description = hub_instance.describe() | ||
assert hub_description is not None | ||
|
||
# Delete Hub | ||
delete_hub_response = hub_instance.delete() | ||
assert delete_hub_response is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add a unit test for this? seems like the current coverage didn't cover this bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
synced with @evakravi offline, this needs to be covered through unit tests and I'll add it as a fast follow.