Skip to content

Commit e816567

Browse files
committed
simplify and refactor djl model for latest container releases
1 parent d444b7b commit e816567

File tree

6 files changed

+265
-1915
lines changed

6 files changed

+265
-1915
lines changed

src/sagemaker/djl_inference/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,5 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.djl_inference.model import DJLPredictor # noqa: F401
16+
from sagemaker.djl_inference.djl_predictor import DJLPredictor # noqa: F401
1717
from sagemaker.djl_inference.model import DJLModel # noqa: F401
18-
from sagemaker.djl_inference.model import DeepSpeedModel # noqa: F401
19-
from sagemaker.djl_inference.model import HuggingFaceAccelerateModel # noqa: F401
20-
from sagemaker.djl_inference.model import FasterTransformerModel # noqa: F401

src/sagemaker/djl_inference/defaults.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from sagemaker.predictor import Predictor
2+
from sagemaker import Session
3+
from sagemaker.serializers import BaseSerializer, JSONSerializer
4+
from sagemaker.deserializers import BaseDeserializer, JSONDeserializer
5+
6+
7+
class DJLPredictor(Predictor):
    """A Predictor for inference against DJL Model Endpoints.

    This is able to serialize Python lists, dictionaries, and numpy arrays to
    multidimensional tensors for DJL inference. Unless a different serializer
    or deserializer is supplied, requests are serialized to JSON and responses
    are parsed from JSON.
    """

    def __init__(
        self,
        endpoint_name: str,
        sagemaker_session: Session = None,
        serializer: BaseSerializer = None,
        deserializer: BaseDeserializer = None,
        component_name=None,
    ):
        """Initialize a ``DJLPredictor``

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object that
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. Forwarded unchanged to the base
                ``Predictor``.
            serializer (sagemaker.serializers.BaseSerializer): Optional. When
                omitted (or ``None``), a new ``JSONSerializer`` is created so
                that input data is serialized to json format.
            deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
                When omitted (or ``None``), a new ``JSONDeserializer`` is
                created so that the response is parsed from json format.
            component_name (str): Optional. Name of the Amazon SageMaker inference
                component corresponding to the predictor.
        """
        # Build the defaults per call rather than in the signature: a default
        # expression is evaluated once at definition time, which would make
        # every predictor share (and be able to mutate) the same instances.
        if serializer is None:
            serializer = JSONSerializer()
        if deserializer is None:
            deserializer = JSONDeserializer()

        super().__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
            component_name=component_name,
        )

0 commit comments

Comments
 (0)