Skip to content

Commit e3361c2

Browse files
authored
Merge pull request #1 from makungaj1/schema-builder-detector
ModelBuilder to fetch local schema when no SchemaBuilder present
2 parents 5559ba3 + 17ae5d9 commit e3361c2

File tree

6 files changed

+245
-2
lines changed

6 files changed

+245
-2
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"description": "Sample Task Inputs and Outputs",
3+
"fill-mask": {
4+
"ref": "https://huggingface.co/tasks/fill-mask",
5+
"inputs": {
6+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/input.json",
7+
"properties": {
8+
"inputs": "Paris is the <mask> of France.",
9+
"parameters": {}
10+
}
11+
},
12+
"outputs": {
13+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/output.json",
14+
"properties": [
15+
{
16+
"sequence": "Paris is the capital of France.",
17+
"score": 0.7
18+
}
19+
]
20+
}
21+
},
22+
"question-answering": {
23+
"ref": "https://huggingface.co/tasks/question-answering",
24+
"inputs": {
25+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/input.json",
26+
"properties": {
27+
"context": "I have a German Shepherd dog, named Coco.",
28+
"question": "What is my dog's breed?"
29+
}
30+
},
31+
"outputs": {
32+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/output.json",
33+
"properties": [
34+
{
35+
"answer": "German Shepherd",
36+
"score": 0.972,
37+
"start": 9,
38+
"end": 24
39+
}
40+
]
41+
}
42+
},
43+
"text-classification": {
44+
"ref": "https://huggingface.co/tasks/text-classification",
45+
"inputs": {
46+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/input.json",
47+
"properties": {
48+
"inputs": "Where is the capital of France?, Paris is the capital of France.",
49+
"parameters": {}
50+
}
51+
},
52+
"outputs": {
53+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/output.json",
54+
"properties": [
55+
{
56+
"label": "entailment",
57+
"score": 0.997
58+
}
59+
]
60+
}
61+
},
62+
"text-generation": {
63+
"ref": "https://huggingface.co/tasks/text-generation",
64+
"inputs": {
65+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/input.json",
66+
"properties": {
67+
"inputs": "Hello, I'm a language model",
68+
"parameters": {}
69+
}
70+
},
71+
"outputs": {
72+
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/output.json",
73+
"properties": [
74+
{
75+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
76+
}
77+
]
78+
}
79+
}
80+
}

src/sagemaker/serve/builder/model_builder.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pathlib import Path
2222

23-
from sagemaker import Session
23+
from sagemaker import Session, task
2424
from sagemaker.model import Model
2525
from sagemaker.base_predictor import PredictorBase
2626
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
@@ -38,6 +38,7 @@
3838
from sagemaker.predictor import Predictor
3939
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4040
from sagemaker.serve.spec.inference_spec import InferenceSpec
41+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4142
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
4243
from sagemaker.serve.detector.image_detector import (
4344
auto_detect_container,
@@ -614,7 +615,12 @@ def build(
614615
hf_model_md = get_huggingface_model_metadata(
615616
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
616617
)
617-
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
618+
619+
model_task = hf_model_md.get("pipeline_tag")
620+
if self.schema_builder is None:
621+
self._schema_builder_init(model_task)
622+
623+
if model_task == "text-generation": # pylint: disable=R1705
618624
return self._build_for_tgi()
619625
else:
620626
return self._build_for_transformers()
@@ -672,3 +678,18 @@ def validate(self, model_dir: str) -> Type[bool]:
672678
"""
673679

674680
return get_metadata(model_dir)
681+
682+
def _schema_builder_init(self, model_task: str):
    """Initialize the schema builder from the locally stored task samples.

    Looks up the sample inputs and outputs for ``model_task`` through
    ``task.retrieve_local_schemas`` and stores a ``SchemaBuilder`` built from
    them on ``self.schema_builder``.

    Args:
        model_task (str): Required, the task name

    Raises:
        TaskNotFoundException: If the I/O schema for the given task is not found.
    """
    try:
        sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
        # Kept inside the ``try`` on purpose: a ``ValueError`` raised while
        # constructing the SchemaBuilder is reported the same way as a missing task.
        self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
    except ValueError as err:
        # Chain explicitly so the traceback shows the lookup failure as the cause
        # rather than as an error that happened "during handling" of another one.
        raise TaskNotFoundException(
            f"Schema builder for {model_task} could not be found."
        ) from err

src/sagemaker/serve/utils/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):
6060

6161
def __init__(self, message):
6262
super().__init__(message=message)
63+
64+
65+
class TaskNotFoundException(ModelBuilderException):
    """Raised when the requested task could not be found."""

    # NOTE(review): presumably the template the base class fills in with the
    # ``message`` keyword -- confirm against ModelBuilderException.
    fmt = "Error Message: {message}"

    def __init__(self, message):
        """Forward ``message`` to the base exception as a keyword argument."""
        super().__init__(message=message)

src/sagemaker/task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve task fallback input/output schema"""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
from typing import Any, Tuple
19+
20+
21+
def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
    """Retrieves task sample inputs and outputs locally.

    The samples are read from ``image_uri_config/tasks.json``, resolved
    relative to the directory that contains this module.

    Args:
        task (str): Required, the task name

    Returns:
        Tuple[Any, Any]: A tuple that contains the sample input,
        at index 0, and output schema, at index 1.

    Raises:
        ValueError: If no tasks config found or the task does not exist in the local config.
    """
    task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")

    # Guard only the file access: it is the single step that can raise
    # FileNotFoundError. The encoding is explicit so decoding does not depend
    # on the platform's locale default.
    try:
        with open(task_io_config_path, encoding="utf-8") as f:
            task_io_config = json.load(f)
    except FileNotFoundError as err:
        # Chain the original error so the missing path stays visible in tracebacks.
        raise ValueError("Could not find tasks config file.") from err

    task_io_schemas = task_io_config.get(task)
    if task_io_schemas is None:
        raise ValueError(f"Could not find {task} I/O schema.")

    return (
        task_io_schemas["inputs"]["properties"],
        task_io_schemas["outputs"]["properties"],
    )
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import task
16+
from sagemaker.serve.builder.model_builder import ModelBuilder
17+
18+
import logging
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
    """Build from a model id alone and compare the schema builder to the local fill-mask samples."""
    builder = ModelBuilder(model="bert-base-uncased")
    built_model = builder.build(sagemaker_session=sagemaker_session)

    assert built_model is not None
    assert builder.schema_builder is not None

    # The schema builder created during build() must carry the locally stored samples.
    expected_input, expected_output = task.retrieve_local_schemas("fill-mask")
    assert builder.schema_builder.sample_input == expected_input
    assert builder.schema_builder.sample_output == expected_output
34+
35+
36+
def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session):
    """Build from a model id alone and compare the schema builder to the local question-answering samples."""
    builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")
    built_model = builder.build(sagemaker_session=sagemaker_session)

    assert built_model is not None
    assert builder.schema_builder is not None

    # The schema builder created during build() must carry the locally stored samples.
    expected_input, expected_output = task.retrieve_local_schemas("question-answering")
    assert builder.schema_builder.sample_input == expected_input
    assert builder.schema_builder.sample_output == expected_output

tests/unit/sagemaker/test_task.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker import task
17+
18+
# Sample input and output expected for the "fill-mask" task.
EXPECTED_INPUTS = {
    "inputs": "Paris is the <mask> of France.",
    "parameters": {},
}
EXPECTED_OUTPUTS = [
    {
        "sequence": "Paris is the capital of France.",
        "score": 0.7,
    },
]
20+
21+
22+
def test_retrieve_local_schemas_success():
    """The fill-mask lookup returns the expected sample input and output."""
    sample_input, sample_output = task.retrieve_local_schemas("fill-mask")

    assert sample_input == EXPECTED_INPUTS
    assert sample_output == EXPECTED_OUTPUTS
27+
28+
29+
def test_retrieve_local_schemas_text_generation_success():
    """The text-generation lookup returns a non-None sample input and output."""
    sample_input, sample_output = task.retrieve_local_schemas("text-generation")

    assert sample_input is not None
    assert sample_output is not None
34+
35+
36+
def test_retrieve_local_schemas_throws():
    """Looking up a task that is not in the local config raises ValueError."""
    unknown_task = "not-present-task"
    with pytest.raises(ValueError):
        task.retrieve_local_schemas(unknown_task)

0 commit comments

Comments
 (0)