Commit 4184689

address reviewer feedback
1 parent 2cbecd6 commit 4184689

3 files changed: 7 additions, 4 deletions

doc/frameworks/djl/using_djl.rst

Lines changed: 4 additions & 0 deletions
@@ -49,7 +49,11 @@ Alternatively, you can provide full specifications to the DJLModel to have full
         },
         image_uri=<djl lmi image uri>,
     )
+    # Deploy the model to an Amazon SageMaker Endpoint and get a Predictor
+    predictor = djl_model.deploy("ml.g5.12xlarge",
+                                 initial_instance_count=1)
 
+Regardless of how you create your model, a ``Predictor`` object is returned.
 Each ``Predictor`` provides a ``predict`` method, which can do inference with json data, numpy arrays, or Python lists.
 Inference data are serialized and sent to the DJL Serving model server by an ``InvokeEndpoint`` SageMaker operation. The
 ``predict`` method returns the result of inference against your model.
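
The serialize, invoke, deserialize flow described in the added documentation can be sketched without a live endpoint. The snippet below is only an illustration: ``StubPredictor`` and ``echo_endpoint`` are hypothetical stand-ins rather than part of the SageMaker SDK, and the request shape (``inputs`` plus ``parameters``) is an assumption about a typical text-generation payload.

```python
import json


class StubPredictor:
    """Hypothetical stand-in for a SageMaker Predictor; makes no network calls."""

    def __init__(self, invoke_endpoint):
        # invoke_endpoint: a callable taking request bytes and returning
        # response bytes, standing in for the InvokeEndpoint operation.
        self._invoke_endpoint = invoke_endpoint

    def predict(self, data):
        body = json.dumps(data).encode("utf-8")  # serialize the request
        raw = self._invoke_endpoint(body)        # send it to the "endpoint"
        return json.loads(raw.decode("utf-8"))   # deserialize the response


def echo_endpoint(body: bytes) -> bytes:
    """Pretend model server that echoes the prompt back (assumed payload shape)."""
    request = json.loads(body)
    return json.dumps({"generated_text": request["inputs"]}).encode("utf-8")


predictor = StubPredictor(echo_endpoint)
result = predictor.predict({"inputs": "Hello", "parameters": {"max_new_tokens": 16}})
print(result)
```

With a real deployment, the ``predictor`` returned by ``djl_model.deploy(...)`` takes the place of the stub, and the response format depends on the model being served.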

src/sagemaker/djl_inference/model.py

Lines changed: 2 additions & 3 deletions
@@ -147,9 +147,8 @@ def _infer_engine(self) -> Optional[str]:
             logger.info("Using provided engine %s", self.engine)
             return self.engine
 
-        if self.task is not None:
-            if self.task == "text-embedding":
-                return "OnnxRuntime"
+        if self.task == "text-embedding":
+            return "OnnxRuntime"
         return "Python"
 
     def _infer_image_uri(self):
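
The removed ``is not None`` guard looks redundant, since ``None == "text-embedding"`` is already ``False``. A minimal sketch checking that both forms agree is below. The free functions are hypothetical rewrites of the method, and the ``engine is not None`` branch is an assumption about the code just above the hunk.

```python
from typing import Optional


def infer_engine_before(engine: Optional[str], task: Optional[str]) -> str:
    """Pre-change logic, with the explicit ``task is not None`` guard."""
    if engine is not None:  # assumed shape of the branch above the hunk
        return engine
    if task is not None:
        if task == "text-embedding":
            return "OnnxRuntime"
    return "Python"


def infer_engine_after(engine: Optional[str], task: Optional[str]) -> str:
    """Post-change logic: the equality test alone handles ``task=None``."""
    if engine is not None:  # assumed shape of the branch above the hunk
        return engine
    if task == "text-embedding":
        return "OnnxRuntime"
    return "Python"


cases = [
    (None, None),
    (None, "text-embedding"),
    (None, "text-generation"),
    ("MPI", "text-embedding"),
]
# Each entry pairs the old result with the new one for the same inputs.
results = [(infer_engine_before(e, t), infer_engine_after(e, t)) for e, t in cases]
print(results)
```

Every pair in ``results`` should hold two identical values, which is what makes the flattened condition a behavior-preserving cleanup.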

src/sagemaker/serve/model_server/djl_serving/utils.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def _get_default_max_tokens(sample_input, sample_output) -> tuple:
 
 def _get_default_djl_configurations(
     model_id: str, hf_model_config: dict, schema_builder: SchemaBuilder
-) -> tuple:
+) -> tuple[dict, int]:
     """Placeholder docstring"""
     default_tensor_parallel_degree = _get_default_tensor_parallel_degree(hf_model_config)
     if default_tensor_parallel_degree is None:
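
One point worth checking on the new annotation: subscripting the built-in ``tuple`` in a signature is evaluated when the function is defined, so ``tuple[dict, int]`` needs Python 3.9 or later unless the module uses ``from __future__ import annotations``. Whether this project still targets older interpreters is not visible in the diff. The sketch below uses ``typing.Tuple``, which means the same thing to a type checker and also runs on older versions. The function name, dictionary key, and token count are hypothetical and only illustrate the "configuration dict plus integer" return shape.

```python
from typing import Tuple


def get_default_configurations(tensor_parallel_degree: int) -> Tuple[dict, int]:
    """Hypothetical helper returning a (config dict, max token count) pair."""
    # The key name is illustrative only, not a documented DJL setting.
    config = {"tensor_parallel_degree": tensor_parallel_degree}
    max_new_tokens = 256  # illustrative value
    return config, max_new_tokens


# The precise annotation tells callers and type checkers how to unpack the result.
config, max_new_tokens = get_default_configurations(4)
print(config, max_new_tokens)
```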
