Skip to content

Commit 1b2a4fb

Browse files
author
Jonathan Makunga
committed
Testing
1 parent edb5716 commit 1b2a4fb

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -605,19 +605,24 @@ def build(
605605

606606
self.serve_settings = self._get_serve_setting()
607607

608-
sample_input, sample_output = task.retrieve_local_schemas("text-generation")
609-
self.schema_builder = SchemaBuilder(sample_input, sample_output)
610-
611608
if isinstance(self.model, str):
612609
if self._is_jumpstart_model_id():
613610
return self._build_for_jumpstart()
614611
if self._is_djl(): # pylint: disable=R1705
615612
return self._build_for_djl()
616613
else:
614+
logger.info("******************************************************")
615+
logger.info(f"schema_builder is None: {self.schema_builder is None}")
616+
617617
hf_model_md = get_huggingface_model_metadata(
618618
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
619619
)
620-
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
620+
621+
hf_task = hf_model_md.get("pipeline_tag")
622+
logger.info(f"hf_task: {hf_task}")
623+
self._schema_builder_init(hf_task)
624+
625+
if hf_task == "text-generation": # pylint: disable=R1705
621626
return self._build_for_tgi()
622627
else:
623628
return self._build_for_transformers()
@@ -678,16 +683,19 @@ def validate(self, model_dir: str) -> Type[bool]:
678683

679684
def _schema_builder_init(self, model_task: str):
680685
"""Initialize the"""
681-
sample_input, sample_output = None, None
686+
sample_inputs, sample_outputs = None, None
682687

683688
try:
684-
sample_input, sample_output = task.retrieve_local_schemas(model_task)
689+
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
690+
logger.info(f"Sample input: {sample_inputs}")
691+
logger.info(f"Sample output: {sample_outputs}")
685692
except ValueError:
686693
# TODO: try to retrieve schemas remotely
687694
pass
688695

689-
if sample_input and sample_output:
690-
self.schema_builder = SchemaBuilder(sample_input, sample_output)
696+
if sample_inputs and sample_outputs:
697+
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
698+
logger.info(f"schema_builder is not None: {self.schema_builder is None}")
691699
else:
692700
# TODO: Raise ClientError
693701
pass

0 commit comments

Comments
 (0)