Skip to content

Commit 338f608

Browse files
authored
Merge branch 'master' into master
2 parents e1b9420 + fa178be commit 338f608

File tree

1 file changed

+4
-4
lines changed
  • src/sagemaker/serve/model_server/multi_model_server

1 file changed

+4
-4
lines changed

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
2222

2323

24-
def model_fn(model_dir):
24+
def model_fn(model_dir, context=None):
2525
"""Overrides default method for loading a model"""
2626
shared_libs_path = Path(model_dir + "/shared_libs")
2727

@@ -40,7 +40,7 @@ def model_fn(model_dir):
4040
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
4141

4242

43-
def input_fn(input_data, content_type):
43+
def input_fn(input_data, content_type, context=None):
4444
"""Deserializes the bytes that were received from the model server"""
4545
try:
4646
if hasattr(schema_builder, "custom_input_translator"):
@@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
7272
raise Exception("Encountered error in deserialize_request.") from e
7373

7474

75-
def predict_fn(input_data, predict_callable):
75+
def predict_fn(input_data, predict_callable, context=None):
7676
"""Invokes the model that is taken in by model server"""
7777
return predict_callable(input_data)
7878

7979

80-
def output_fn(predictions, accept_type):
80+
def output_fn(predictions, accept_type, context=None):
8181
"""Prediction is serialized to bytes and sent back to the customer"""
8282
try:
8383
if hasattr(inference_spec, "postprocess"):

0 commit comments

Comments
 (0)