@@ -605,19 +605,24 @@ def build(
605
605
606
606
self .serve_settings = self ._get_serve_setting ()
607
607
608
- sample_input , sample_output = task .retrieve_local_schemas ("text-generation" )
609
- self .schema_builder = SchemaBuilder (sample_input , sample_output )
610
-
611
608
if isinstance (self .model , str ):
612
609
if self ._is_jumpstart_model_id ():
613
610
return self ._build_for_jumpstart ()
614
611
if self ._is_djl (): # pylint: disable=R1705
615
612
return self ._build_for_djl ()
616
613
else :
614
+ logger .info ("******************************************************" )
615
+ logger .info (f"schema_builder is None: { self .schema_builder is None } " )
616
+
617
617
hf_model_md = get_huggingface_model_metadata (
618
618
self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
619
619
)
620
- if hf_model_md .get ("pipeline_tag" ) == "text-generation" : # pylint: disable=R1705
620
+
621
+ hf_task = hf_model_md .get ("pipeline_tag" )
622
+ logger .info (f"hf_task: { hf_task } " )
623
+ self ._schema_builder_init (hf_task )
624
+
625
+ if hf_task == "text-generation" : # pylint: disable=R1705
621
626
return self ._build_for_tgi ()
622
627
else :
623
628
return self ._build_for_transformers ()
@@ -678,16 +683,19 @@ def validate(self, model_dir: str) -> Type[bool]:
678
683
679
684
def _schema_builder_init (self , model_task : str ):
680
685
"""Initialize the"""
681
- sample_input , sample_output = None , None
686
+ sample_inputs , sample_outputs = None , None
682
687
683
688
try :
684
- sample_input , sample_output = task .retrieve_local_schemas (model_task )
689
+ sample_inputs , sample_outputs = task .retrieve_local_schemas (model_task )
690
+ logger .info (f"Sample input: { sample_inputs } " )
691
+ logger .info (f"Sample output: { sample_outputs } " )
685
692
except ValueError :
686
693
# TODO: try to retrieve schemas remotely
687
694
pass
688
695
689
- if sample_input and sample_output :
690
- self .schema_builder = SchemaBuilder (sample_input , sample_output )
696
+ if sample_inputs and sample_outputs :
697
+ self .schema_builder = SchemaBuilder (sample_inputs , sample_outputs )
698
+ logger .info (f"schema_builder is not None: { self .schema_builder is None } " )
691
699
else :
692
700
# TODO: Raise ClientError
693
701
pass
0 commit comments