 METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


-def model_fn(model_dir):
+def model_fn(model_dir, context=None):
     """Overrides default method for loading a model"""
     shared_libs_path = Path(model_dir + "/shared_libs")

@@ -40,7 +40,7 @@ def model_fn(model_dir):
     return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


-def input_fn(input_data, content_type):
+def input_fn(input_data, content_type, context=None):
     """Deserializes the bytes that were received from the model server"""
     try:
         if hasattr(schema_builder, "custom_input_translator"):
@@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
         raise Exception("Encountered error in deserialize_request.") from e


-def predict_fn(input_data, predict_callable):
+def predict_fn(input_data, predict_callable, context=None):
     """Invokes the model that is taken in by model server"""
     return predict_callable(input_data)


-def output_fn(predictions, accept_type):
+def output_fn(predictions, accept_type, context=None):
     """Prediction is serialized to bytes and sent back to the customer"""
     try:
         if hasattr(inference_spec, "postprocess"):
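The change is backward compatible: each handler gains an optional context argument defaulting to None, so a model server that still invokes the old two- or three-argument form keeps working. Below is a minimal sketch (not part of this diff) of a custom input_fn that consults the context when one is supplied; it assumes a TorchServe-style context object, and the system_properties lookup is illustrative rather than an API guaranteed by this change.

import json


def input_fn(input_data, content_type, context=None):
    """Deserialize request bytes; context is optional and server-provided."""
    # Assumed attribute: TorchServe-style contexts expose a system_properties
    # mapping (e.g. a server-assigned "gpu_id"); adjust for your model server.
    device = "cpu"
    if context is not None:
        device = context.system_properties.get("gpu_id", "cpu")
    if content_type == "application/json":
        return {"data": json.loads(input_data), "device": device}
    raise ValueError(f"Unsupported content type: {content_type}")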