12
12
# language governing permissions and limitations under the License.
13
13
"""Holds the ModelBuilder class and the ModelServer enum."""
14
14
from __future__ import absolute_import
15
+
16
+ import importlib .util
15
17
import uuid
16
18
from typing import Any , Type , List , Dict , Optional , Union
17
19
from dataclasses import dataclass , field
18
20
import logging
19
21
import os
22
+ import re
20
23
21
24
from pathlib import Path
22
25
43
46
from sagemaker .predictor import Predictor
44
47
from sagemaker .serve .model_format .mlflow .constants import (
45
48
MLFLOW_MODEL_PATH ,
49
+ MLFLOW_TRACKING_ARN ,
50
+ MLFLOW_RUN_ID_REGEX ,
51
+ MLFLOW_REGISTRY_PATH_REGEX ,
52
+ MODEL_PACKAGE_ARN_REGEX ,
46
53
MLFLOW_METADATA_FILE ,
47
54
MLFLOW_PIP_DEPENDENCY_FILE ,
48
55
)
49
56
from sagemaker .serve .model_format .mlflow .utils import (
50
57
_get_default_model_server_for_mlflow ,
51
- _mlflow_input_is_local_path ,
52
58
_download_s3_artifacts ,
53
59
_select_container_for_mlflow_model ,
54
60
_generate_mlflow_artifact_path ,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276
282
default = None ,
277
283
metadata = {
278
284
"help" : "Define the model metadata to override, currently supports `HF_TASK`, "
279
- "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280
- "the Hub, Adding unsupported task types will throw an exception"
285
+ "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286
+ "models without task metadata in the Hub, Adding unsupported task types will "
287
+ "throw an exception"
281
288
},
282
289
)
283
290
@@ -501,6 +508,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
501
508
_maintain_lineage_tracking_for_mlflow_model (
502
509
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
503
510
s3_upload_path = self .s3_upload_path ,
511
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
504
512
sagemaker_session = self .sagemaker_session ,
505
513
)
506
514
return new_model_package
@@ -571,6 +579,7 @@ def _model_builder_deploy_wrapper(
571
579
_maintain_lineage_tracking_for_mlflow_model (
572
580
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
573
581
s3_upload_path = self .s3_upload_path ,
582
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
574
583
sagemaker_session = self .sagemaker_session ,
575
584
)
576
585
return predictor
@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625
634
626
635
return wrapper
627
636
628
- def _check_if_input_is_mlflow_model (self ) -> bool :
629
- """Checks whether an MLmodel file exists in the given directory.
637
+ def _handle_mlflow_input (self ):
638
+ """Check whether an MLflow model is present and handle accordingly"""
639
+ self ._is_mlflow_model = self ._has_mlflow_arguments ()
640
+ if not self ._is_mlflow_model :
641
+ return
642
+
643
+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
644
+ artifact_path = self ._get_artifact_path (mlflow_model_path )
645
+ if not self ._mlflow_metadata_exists (artifact_path ):
646
+ logger .info (
647
+ "MLflow model metadata not detected in %s. ModelBuilder is not "
648
+ "handling MLflow model input" ,
649
+ mlflow_model_path ,
650
+ )
651
+ return
652
+
653
+ self ._initialize_for_mlflow (artifact_path )
654
+ _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
655
+
656
+ def _has_mlflow_arguments (self ) -> bool :
657
+ """Check whether MLflow model arguments are present
630
658
631
659
Returns:
632
- bool: True if the MLmodel file exists , False otherwise.
660
+ bool: True if MLflow arguments are present , False otherwise.
633
661
"""
634
662
if self .inference_spec or self .model :
635
663
logger .info (
@@ -644,16 +672,80 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644
672
)
645
673
return False
646
674
647
- path = self .model_metadata .get (MLFLOW_MODEL_PATH )
648
- if not path :
675
+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
676
+ if not mlflow_model_path :
649
677
logger .info (
650
678
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651
679
"input" ,
652
680
MLFLOW_MODEL_PATH ,
653
681
)
654
682
return False
655
683
656
- # Check for S3 path
684
+ return True
685
+
686
+ def _get_artifact_path (self , mlflow_model_path : str ) -> str :
687
+ """Retrieves the model artifact location given the Mlflow model input.
688
+
689
+ Args:
690
+ mlflow_model_path (str): The MLflow model path input.
691
+
692
+ Returns:
693
+ str: The path to the model artifact.
694
+ """
695
+ if (is_run_id_type := re .match (MLFLOW_RUN_ID_REGEX , mlflow_model_path )) or re .match (
696
+ MLFLOW_REGISTRY_PATH_REGEX , mlflow_model_path
697
+ ):
698
+ mlflow_tracking_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN )
699
+ if not mlflow_tracking_arn :
700
+ raise ValueError (
701
+ "%s is not provided in ModelMetadata or through set_tracking_arn "
702
+ "but MLflow model path was provided." % MLFLOW_TRACKING_ARN ,
703
+ )
704
+
705
+ if not importlib .util .find_spec ("awsmlflow" ):
706
+ raise ImportError ("Unable to import awsmlflow, check if awsmlflow is installed" )
707
+
708
+ import mlflow
709
+
710
+ mlflow .set_tracking_uri (mlflow_tracking_arn )
711
+ if is_run_id_type :
712
+ _ , run_id , model_path = mlflow_model_path .split ("/" , 2 )
713
+ artifact_uri = mlflow .get_run (run_id ).info .artifact_uri
714
+ if not artifact_uri .endswith ("/" ):
715
+ artifact_uri += "/"
716
+ return artifact_uri + model_path
717
+
718
+ mlflow_client = mlflow .MlflowClient ()
719
+ if not mlflow_model_path .endswith ("/" ):
720
+ mlflow_model_path += "/"
721
+
722
+ if "@" in mlflow_model_path :
723
+ _ , model_name_and_alias , artifact_uri = mlflow_model_path .split ("/" , 2 )
724
+ model_name , model_alias = model_name_and_alias .split ("@" )
725
+ model_metadata = mlflow_client .get_model_version_by_alias (model_name , model_alias )
726
+ else :
727
+ _ , model_name , model_version , artifact_uri = mlflow_model_path .split ("/" , 3 )
728
+ model_metadata = mlflow_client .get_model_version (model_name , model_version )
729
+
730
+ source = model_metadata .source
731
+ if not source .endswith ("/" ):
732
+ source += "/"
733
+ return source + artifact_uri
734
+
735
+ if re .match (MODEL_PACKAGE_ARN_REGEX , mlflow_model_path ):
736
+ model_package = self .sagemaker_session .sagemaker_client .describe_model_package (
737
+ ModelPackageName = mlflow_model_path
738
+ )
739
+ return model_package ["SourceUri" ]
740
+
741
+ return mlflow_model_path
742
+
743
+ def _mlflow_metadata_exists (self , path : str ) -> bool :
744
+ """Checks whether an MLmodel file exists in the given directory.
745
+
746
+ Returns:
747
+ bool: True if the MLmodel file exists, False otherwise.
748
+ """
657
749
if path .startswith ("s3://" ):
658
750
s3_downloader = S3Downloader ()
659
751
if not path .endswith ("/" ):
@@ -665,17 +757,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665
757
file_path = os .path .join (path , MLFLOW_METADATA_FILE )
666
758
return os .path .isfile (file_path )
667
759
668
- def _initialize_for_mlflow (self ) -> None :
669
- """Initialize mlflow model artifacts, image uri and model server."""
670
- mlflow_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
671
- if not _mlflow_input_is_local_path (mlflow_path ):
672
- # TODO: extend to package arn, run id and etc.
673
- logger .info (
674
- "Start downloading model artifacts from %s to %s" , mlflow_path , self .model_path
675
- )
676
- _download_s3_artifacts (mlflow_path , self .model_path , self .sagemaker_session )
760
+ def _initialize_for_mlflow (self , artifact_path : str ) -> None :
761
+ """Initialize mlflow model artifacts, image uri and model server.
762
+
763
+ Args:
764
+ artifact_path (str): The path to the artifact store.
765
+ """
766
+ if artifact_path .startswith ("s3://" ):
767
+ _download_s3_artifacts (artifact_path , self .model_path , self .sagemaker_session )
768
+ elif os .path .exists (artifact_path ):
769
+ _copy_directory_contents (artifact_path , self .model_path )
677
770
else :
678
- _copy_directory_contents ( mlflow_path , self . model_path )
771
+ raise ValueError ( "Invalid path: %s" % artifact_path )
679
772
mlflow_model_metadata_path = _generate_mlflow_artifact_path (
680
773
self .model_path , MLFLOW_METADATA_FILE
681
774
)
@@ -728,6 +821,8 @@ def build( # pylint: disable=R0911
728
821
self .role_arn = role_arn
729
822
self .sagemaker_session = sagemaker_session or Session ()
730
823
824
+ self .sagemaker_session .settings ._local_download_dir = self .model_path
825
+
731
826
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732
827
# decorate to_string() due to
733
828
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +834,8 @@ def build( # pylint: disable=R0911
739
834
self .serve_settings = self ._get_serve_setting ()
740
835
741
836
self ._is_custom_image_uri = self .image_uri is not None
742
- self ._is_mlflow_model = self ._check_if_input_is_mlflow_model ()
743
- if self ._is_mlflow_model :
744
- logger .warning (
745
- "Support of MLflow format models is experimental and is not intended"
746
- " for production at this moment."
747
- )
748
- self ._initialize_for_mlflow ()
749
- _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
837
+
838
+ self ._handle_mlflow_input ()
750
839
751
840
if isinstance (self .model , str ):
752
841
model_task = None
@@ -836,6 +925,17 @@ def validate(self, model_dir: str) -> Type[bool]:
836
925
837
926
return get_metadata (model_dir )
838
927
928
+ def set_tracking_arn (self , arn : str ):
929
+ """Set tracking server ARN"""
930
+ # TODO: support native MLflow URIs
931
+ if importlib .util .find_spec ("awsmlflow" ):
932
+ import mlflow
933
+
934
+ mlflow .set_tracking_uri (arn )
935
+ self .model_metadata [MLFLOW_TRACKING_ARN ] = arn
936
+ else :
937
+ raise ImportError ("Unable to import awsmlflow, check if awsmlflow is installed" )
938
+
839
939
def _hf_schema_builder_init (self , model_task : str ):
840
940
"""Initialize the schema builder for the given HF_TASK
841
941
0 commit comments