ModelBuilder to fetch local schema when no SchemaBuilder present. #4434
Merged
18 commits
b2d0929
Fetch Schema locally
c4fb9f7
Fetch Schema locally
3b5e699
Local schema
edb5716
Test local schemas
1b2a4fb
Testing
b2acab3
Testing Schema
ad2303f
Schema for DJL
6ad7982
Merge branch 'aws:master' into schema-builder-detector
makungaj1 17ae5d9
Add Integ tests
e3361c2
Merge pull request #1 from makungaj1/schema-builder-detector
makungaj1 98dd0a9
address PR comments
854db04
Address PR Review Comments
967c9cc
Merge branch 'master' into master
makungaj1 99d11cb
Address PR Review Comments
38eed2f
Add Unit tests
4b3f617
Merge branch 'master' into master
makungaj1 f78637c
Address PR Comments
8f6e6b2
Address PR Comments
@@ -0,0 +1,67 @@
{
  "fill-mask": {
    "sample_inputs": {
      "properties": {
        "inputs": "Paris is the <mask> of France.",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "sequence": "Paris is the capital of France.",
          "score": 0.7
        }
      ]
    }
  },
  "question-answering": {
    "sample_inputs": {
      "properties": {
        "context": "I have a German Shepherd dog, named Coco.",
        "question": "What is my dog's breed?"
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "answer": "German Shepherd",
          "score": 0.972,
          "start": 9,
          "end": 24
        }
      ]
    }
  },
  "text-classification": {
    "sample_inputs": {
      "properties": {
        "inputs": "Where is the capital of France?, Paris is the capital of France.",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "label": "entailment",
          "score": 0.997
        }
      ]
    }
  },
  "text-generation": {
    "sample_inputs": {
      "properties": {
        "inputs": "Hello, I'm a language model",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
        }
      ]
    }
  }
}
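Every task entry in the config above follows the same shape: a `sample_inputs` and a `sample_outputs` section, each holding a `properties` value. A minimal sketch of a structural check is shown here, assuming that shape is the only contract. The helper name `find_malformed_tasks` is hypothetical and is not part of this PR, and only one task is reproduced to keep the example short.

```python
import json

# Abbreviated copy of the task config above (one task only, for brevity).
TASK_CONFIG_JSON = """
{
  "fill-mask": {
    "sample_inputs": {
      "properties": {
        "inputs": "Paris is the <mask> of France.",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {"sequence": "Paris is the capital of France.", "score": 0.7}
      ]
    }
  }
}
"""

REQUIRED_SECTIONS = ("sample_inputs", "sample_outputs")


def find_malformed_tasks(config: dict) -> list:
    """Return the names of tasks missing a section or its 'properties' key.

    Hypothetical helper for illustration; not part of the SDK.
    """
    malformed = []
    for task_name, schema in config.items():
        for section in REQUIRED_SECTIONS:
            body = schema.get(section) if isinstance(schema, dict) else None
            if not isinstance(body, dict) or "properties" not in body:
                malformed.append(task_name)
                break  # one report per task is enough
    return malformed


config = json.loads(TASK_CONFIG_JSON)
malformed_tasks = find_malformed_tasks(config)
print(malformed_tasks)
```

A check like this could run in a unit test so that a newly added task with a missing section fails early instead of raising a `KeyError` at lookup time.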
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve task fallback input/output schema"""
from __future__ import absolute_import

import json
import os
from typing import Any, Tuple


def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
    """Retrieves task sample inputs and outputs locally.

    Args:
        task (str): Required, the task name

    Returns:
        Tuple[Any, Any]: A tuple that contains the sample input,
        at index 0, and output schema, at index 1.

    Raises:
        ValueError: If no tasks config found or the task does not exist in the local config.
    """
    config_dir = os.path.dirname(os.path.dirname(__file__))
    task_io_config_path = os.path.join(config_dir, "schema", "task.json")
    try:
        with open(task_io_config_path) as f:
            task_io_config = json.load(f)
            task_io_schemas = task_io_config.get(task, None)

            if task_io_schemas is None:
                raise ValueError(f"Could not find {task} I/O schema.")

            sample_schema = (
                task_io_schemas["sample_inputs"]["properties"],
                task_io_schemas["sample_outputs"]["properties"],
            )
        return sample_schema
    except FileNotFoundError:
        raise ValueError("Could not find tasks config file.")
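The lookup above can be exercised without installing the SDK. The following is a rough sketch, not the PR's code: `load_task_schemas` is a hypothetical stand-in that takes the config path as an argument instead of deriving it from `__file__`, and the config is written to a temporary directory. The error messages are copied from the function above.

```python
import json
import os
import tempfile
from typing import Any, Tuple


def load_task_schemas(config_path: str, task: str) -> Tuple[Any, Any]:
    """Hypothetical stand-in for retrieve_local_schemas with an explicit path."""
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError as e:
        raise ValueError("Could not find tasks config file.") from e

    schemas = config.get(task)
    if schemas is None:
        raise ValueError(f"Could not find {task} I/O schema.")
    return schemas["sample_inputs"]["properties"], schemas["sample_outputs"]["properties"]


# One task from the config in this PR, enough to show both outcomes.
SAMPLE_CONFIG = {
    "question-answering": {
        "sample_inputs": {
            "properties": {
                "context": "I have a German Shepherd dog, named Coco.",
                "question": "What is my dog's breed?",
            }
        },
        "sample_outputs": {
            "properties": [
                {"answer": "German Shepherd", "score": 0.972, "start": 9, "end": 24}
            ]
        },
    }
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "task.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(SAMPLE_CONFIG, f)

    # Known task: returns (sample input, sample output).
    sample_input, sample_output = load_task_schemas(path, "question-answering")

    # Unknown task: surfaces as ValueError.
    try:
        load_task_schemas(path, "text-to-image")
    except ValueError as err:
        missing_task_error = str(err)

    # Missing config file: also surfaces as ValueError.
    try:
        load_task_schemas(os.path.join(tmp_dir, "missing.json"), "fill-mask")
    except ValueError as err:
        missing_file_error = str(err)
```

Both failure modes surface as `ValueError`, so a caller has to read the message to tell a missing task from a missing config file.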
@@ -0,0 +1,101 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.utils import task

import pytest

from sagemaker.serve.utils.exceptions import TaskNotFoundException
from tests.integ.sagemaker.serve.constants import (
    PYTHON_VERSION_IS_NOT_310,
    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
)

from tests.integ.timeout import timeout
from tests.integ.utils import cleanup_model_resources

import logging

logger = logging.getLogger(__name__)


def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
    model_builder = ModelBuilder(model="bert-base-uncased")

    model = model_builder.build(sagemaker_session=sagemaker_session)

    assert model is not None
    assert model_builder.schema_builder is not None

    inputs, outputs = task.retrieve_local_schemas("fill-mask")
    assert model_builder.schema_builder.sample_input == inputs
    assert model_builder.schema_builder.sample_output == outputs


@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing Schema Builder Simplification feature",
)
def test_model_builder_happy_path_with_only_model_id_question_answering(
    sagemaker_session, gpu_instance_type
):
    model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")

    model = model_builder.build(sagemaker_session=sagemaker_session)

    assert model is not None
    assert model_builder.schema_builder is not None

    inputs, outputs = task.retrieve_local_schemas("question-answering")
    assert model_builder.schema_builder.sample_input == inputs
    assert model_builder.schema_builder.sample_output == outputs

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        caught_ex = None
        try:
            iam_client = sagemaker_session.boto_session.client("iam")
            role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(
                role=role_arn, instance_count=1, instance_type=gpu_instance_type
            )

            predicted_outputs = predictor.predict(inputs)
            assert predicted_outputs is not None

        except Exception as e:
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert (
                    False
                ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"


def test_model_builder_negative_path(sagemaker_session):
    model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

    with pytest.raises(
        TaskNotFoundException,
        match="Error Message: Schema builder for text-to-image could not be found.",
    ):
        model_builder.build(sagemaker_session=sagemaker_session)
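One detail of the negative-path test may be worth noting: the `match` argument of `pytest.raises` is applied as a regular expression with `re.search` against the string form of the exception, so the `.` characters in the expected message act as wildcards. The sketch below illustrates that behaviour with the standard `re` module only. All three message strings are hypothetical and invented for illustration; the actual text raised by `TaskNotFoundException` may differ.

```python
import re

# Pattern copied from the negative-path test above.
pattern = "Error Message: Schema builder for text-to-image could not be found."

# Hypothetical exception texts, invented for illustration only.
exact_message = "Error Message: Schema builder for text-to-image could not be found."
longer_message = "Task lookup failed. " + exact_message + " See logs."
lookalike_message = "Error Message: Schema builder for text-to-image could not be found!"

# re.search finds the pattern anywhere in the string, so extra context still matches.
matches_exact = re.search(pattern, exact_message) is not None
matches_longer = re.search(pattern, longer_message) is not None

# An unescaped "." matches any character, so the lookalike also passes.
matches_lookalike = re.search(pattern, lookalike_message) is not None

# Escaping the pattern makes the comparison literal.
strict_pattern = re.escape(pattern)
strict_matches_exact = re.search(strict_pattern, exact_message) is not None
strict_matches_lookalike = re.search(strict_pattern, lookalike_message) is not None
```

For a test that should pin the exact wording, passing `re.escape(...)` of the expected message as `match` is one option; the looser form used in the PR is still sufficient to confirm the task name appears in the error.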