Skip to content

Commit beb23ec

Browse files
andjsmiAditi2424
andauthored
Added parsing string support for situations where custom code might be used (ie. mlflow) (#4960)
Co-authored-by: Aditi Sharma <[email protected]>
1 parent e57c850 commit beb23ec

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ def input_fn(input_data, content_type):
4545
try:
4646
if hasattr(schema_builder, "custom_input_translator"):
4747
deserialized_data = schema_builder.custom_input_translator.deserialize(
48-
io.BytesIO(input_data), content_type
48+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
4949
)
5050
else:
5151
deserialized_data = schema_builder.input_deserializer.deserialize(
52-
io.BytesIO(input_data), content_type[0]
52+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
5353
)
5454

5555
# Check if preprocess method is defined and call it

src/sagemaker/serve/model_server/torchserve/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def input_fn(input_data, content_type):
6767
try:
6868
if hasattr(schema_builder, "custom_input_translator"):
6969
deserialized_data = schema_builder.custom_input_translator.deserialize(
70-
io.BytesIO(input_data), content_type
70+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
7171
)
7272
else:
7373
deserialized_data = schema_builder.input_deserializer.deserialize(
74-
io.BytesIO(input_data), content_type[0]
74+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
7575
)
7676

7777
# Check if preprocess method is defined and call it

src/sagemaker/serve/model_server/torchserve/xgboost_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def input_fn(input_data, content_type):
7070
try:
7171
if hasattr(schema_builder, "custom_input_translator"):
7272
return schema_builder.custom_input_translator.deserialize(
73-
io.BytesIO(input_data), content_type
73+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
7474
)
7575
else:
7676
return schema_builder.input_deserializer.deserialize(
77-
io.BytesIO(input_data), content_type[0]
77+
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
7878
)
7979
except Exception as e:
8080
raise Exception("Encountered error in deserialize_request.") from e

0 commit comments

Comments
 (0)