Skip to content

Commit 58d912b

Browse files
grenmesterJacky Lee
authored andcommitted
feat: add support for mlflow inputs (aws#1441)
* feat: add support for mlflow inputs * fix: typo * fix: doc * fix: S3 regex * fix: refactor * fix: refactor typo * fix: pylint * fix: pylint * fix: black and pylint --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent 2b0be8d commit 58d912b

File tree

9 files changed

+413
-82
lines changed

9 files changed

+413
-82
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ nbformat>=5.9,<6
3737
accelerate>=0.24.1,<=0.27.0
3838
schema==0.7.5
3939
tensorflow>=2.1,<=2.16
40+
mlflow>=2.12.2,<2.13

src/sagemaker/serve/builder/model_builder.py

Lines changed: 127 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
"""Holds the ModelBuilder class and the ModelServer enum."""
1414
from __future__ import absolute_import
15+
16+
import importlib.util
1517
import uuid
1618
from typing import Any, Type, List, Dict, Optional, Union
1719
from dataclasses import dataclass, field
1820
import logging
1921
import os
22+
import re
2023

2124
from pathlib import Path
2225

@@ -43,12 +46,15 @@
4346
from sagemaker.predictor import Predictor
4447
from sagemaker.serve.model_format.mlflow.constants import (
4548
MLFLOW_MODEL_PATH,
49+
MLFLOW_TRACKING_ARN,
50+
MLFLOW_RUN_ID_REGEX,
51+
MLFLOW_REGISTRY_PATH_REGEX,
52+
MODEL_PACKAGE_ARN_REGEX,
4653
MLFLOW_METADATA_FILE,
4754
MLFLOW_PIP_DEPENDENCY_FILE,
4855
)
4956
from sagemaker.serve.model_format.mlflow.utils import (
5057
_get_default_model_server_for_mlflow,
51-
_mlflow_input_is_local_path,
5258
_download_s3_artifacts,
5359
_select_container_for_mlflow_model,
5460
_generate_mlflow_artifact_path,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276282
default=None,
277283
metadata={
278284
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
279-
"`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280-
"the Hub, Adding unsupported task types will throw an exception"
285+
"`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286+
"models without task metadata in the Hub, Adding unsupported task types will "
287+
"throw an exception"
281288
},
282289
)
283290

@@ -501,6 +508,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
501508
_maintain_lineage_tracking_for_mlflow_model(
502509
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
503510
s3_upload_path=self.s3_upload_path,
511+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
504512
sagemaker_session=self.sagemaker_session,
505513
)
506514
return new_model_package
@@ -571,6 +579,7 @@ def _model_builder_deploy_wrapper(
571579
_maintain_lineage_tracking_for_mlflow_model(
572580
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
573581
s3_upload_path=self.s3_upload_path,
582+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
574583
sagemaker_session=self.sagemaker_session,
575584
)
576585
return predictor
@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625634

626635
return wrapper
627636

628-
def _check_if_input_is_mlflow_model(self) -> bool:
629-
"""Checks whether an MLmodel file exists in the given directory.
637+
def _handle_mlflow_input(self):
638+
"""Check whether an MLflow model is present and handle accordingly"""
639+
self._is_mlflow_model = self._has_mlflow_arguments()
640+
if not self._is_mlflow_model:
641+
return
642+
643+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
644+
artifact_path = self._get_artifact_path(mlflow_model_path)
645+
if not self._mlflow_metadata_exists(artifact_path):
646+
logger.info(
647+
"MLflow model metadata not detected in %s. ModelBuilder is not "
648+
"handling MLflow model input",
649+
mlflow_model_path,
650+
)
651+
return
652+
653+
self._initialize_for_mlflow(artifact_path)
654+
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
655+
656+
def _has_mlflow_arguments(self) -> bool:
657+
"""Check whether MLflow model arguments are present
630658
631659
Returns:
632-
bool: True if the MLmodel file exists, False otherwise.
660+
bool: True if MLflow arguments are present, False otherwise.
633661
"""
634662
if self.inference_spec or self.model:
635663
logger.info(
@@ -644,16 +672,80 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644672
)
645673
return False
646674

647-
path = self.model_metadata.get(MLFLOW_MODEL_PATH)
648-
if not path:
675+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
676+
if not mlflow_model_path:
649677
logger.info(
650678
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651679
"input",
652680
MLFLOW_MODEL_PATH,
653681
)
654682
return False
655683

656-
# Check for S3 path
684+
return True
685+
686+
def _get_artifact_path(self, mlflow_model_path: str) -> str:
687+
"""Retrieves the model artifact location given the Mlflow model input.
688+
689+
Args:
690+
mlflow_model_path (str): The MLflow model path input.
691+
692+
Returns:
693+
str: The path to the model artifact.
694+
"""
695+
if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match(
696+
MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path
697+
):
698+
mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
699+
if not mlflow_tracking_arn:
700+
raise ValueError(
701+
"%s is not provided in ModelMetadata or through set_tracking_arn "
702+
"but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
703+
)
704+
705+
if not importlib.util.find_spec("awsmlflow"):
706+
raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")
707+
708+
import mlflow
709+
710+
mlflow.set_tracking_uri(mlflow_tracking_arn)
711+
if is_run_id_type:
712+
_, run_id, model_path = mlflow_model_path.split("/", 2)
713+
artifact_uri = mlflow.get_run(run_id).info.artifact_uri
714+
if not artifact_uri.endswith("/"):
715+
artifact_uri += "/"
716+
return artifact_uri + model_path
717+
718+
mlflow_client = mlflow.MlflowClient()
719+
if not mlflow_model_path.endswith("/"):
720+
mlflow_model_path += "/"
721+
722+
if "@" in mlflow_model_path:
723+
_, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2)
724+
model_name, model_alias = model_name_and_alias.split("@")
725+
model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias)
726+
else:
727+
_, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3)
728+
model_metadata = mlflow_client.get_model_version(model_name, model_version)
729+
730+
source = model_metadata.source
731+
if not source.endswith("/"):
732+
source += "/"
733+
return source + artifact_uri
734+
735+
if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
736+
model_package = self.sagemaker_session.sagemaker_client.describe_model_package(
737+
ModelPackageName=mlflow_model_path
738+
)
739+
return model_package["SourceUri"]
740+
741+
return mlflow_model_path
742+
743+
def _mlflow_metadata_exists(self, path: str) -> bool:
744+
"""Checks whether an MLmodel file exists in the given directory.
745+
746+
Returns:
747+
bool: True if the MLmodel file exists, False otherwise.
748+
"""
657749
if path.startswith("s3://"):
658750
s3_downloader = S3Downloader()
659751
if not path.endswith("/"):
@@ -665,17 +757,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665757
file_path = os.path.join(path, MLFLOW_METADATA_FILE)
666758
return os.path.isfile(file_path)
667759

668-
def _initialize_for_mlflow(self) -> None:
669-
"""Initialize mlflow model artifacts, image uri and model server."""
670-
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
671-
if not _mlflow_input_is_local_path(mlflow_path):
672-
# TODO: extend to package arn, run id and etc.
673-
logger.info(
674-
"Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
675-
)
676-
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
760+
def _initialize_for_mlflow(self, artifact_path: str) -> None:
761+
"""Initialize mlflow model artifacts, image uri and model server.
762+
763+
Args:
764+
artifact_path (str): The path to the artifact store.
765+
"""
766+
if artifact_path.startswith("s3://"):
767+
_download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session)
768+
elif os.path.exists(artifact_path):
769+
_copy_directory_contents(artifact_path, self.model_path)
677770
else:
678-
_copy_directory_contents(mlflow_path, self.model_path)
771+
raise ValueError("Invalid path: %s" % artifact_path)
679772
mlflow_model_metadata_path = _generate_mlflow_artifact_path(
680773
self.model_path, MLFLOW_METADATA_FILE
681774
)
@@ -728,6 +821,8 @@ def build( # pylint: disable=R0911
728821
self.role_arn = role_arn
729822
self.sagemaker_session = sagemaker_session or Session()
730823

824+
self.sagemaker_session.settings._local_download_dir = self.model_path
825+
731826
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732827
# decorate to_string() due to
733828
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +834,8 @@ def build( # pylint: disable=R0911
739834
self.serve_settings = self._get_serve_setting()
740835

741836
self._is_custom_image_uri = self.image_uri is not None
742-
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
743-
if self._is_mlflow_model:
744-
logger.warning(
745-
"Support of MLflow format models is experimental and is not intended"
746-
" for production at this moment."
747-
)
748-
self._initialize_for_mlflow()
749-
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
837+
838+
self._handle_mlflow_input()
750839

751840
if isinstance(self.model, str):
752841
model_task = None
@@ -836,6 +925,17 @@ def validate(self, model_dir: str) -> Type[bool]:
836925

837926
return get_metadata(model_dir)
838927

928+
def set_tracking_arn(self, arn: str):
929+
"""Set tracking server ARN"""
930+
# TODO: support native MLflow URIs
931+
if importlib.util.find_spec("awsmlflow"):
932+
import mlflow
933+
934+
mlflow.set_tracking_uri(arn)
935+
self.model_metadata[MLFLOW_TRACKING_ARN] = arn
936+
else:
937+
raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")
938+
839939
def _hf_schema_builder_init(self, model_task: str):
840940
"""Initialize the schema builder for the given HF_TASK
841941

src/sagemaker/serve/model_format/mlflow/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
MODEL_PACKAGE_ARN_REGEX = (
2323
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
2424
)
25-
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
26-
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
25+
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+/[/a-zA-Z0-9\-_\.]+$"
26+
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$"
2727
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
28+
MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN"
2829
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
2930
MLFLOW_METADATA_FILE = "MLmodel"
3031
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"

src/sagemaker/serve/model_format/mlflow/utils.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file(
227227
raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.")
228228

229229

230-
def _mlflow_input_is_local_path(model_path: str) -> bool:
231-
"""Checks if the given model_path is a local filesystem path.
232-
233-
Args:
234-
- model_path (str): The model path to check.
235-
236-
Returns:
237-
- bool: True if model_path is a local path, False otherwise.
238-
"""
239-
if model_path.startswith("s3://"):
240-
return False
241-
242-
if "/runs/" in model_path or model_path.startswith("runs:"):
243-
return False
244-
245-
# Check if it's not a local file path
246-
if not os.path.exists(model_path):
247-
return False
248-
249-
return True
250-
251-
252230
def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None:
253231
"""Downloads all artifacts from a specified S3 path to a local destination path.
254232

src/sagemaker/serve/utils/lineage_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
LINEAGE_POLLER_INTERVAL_SECS = 15
1818
LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
19+
TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$"
20+
TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
1921
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
2022
MLFLOW_S3_PATH = "S3"
2123
MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"

0 commit comments

Comments
 (0)