Add telemetry support for mlflow models #4674

Merged: 4 commits (May 13, 2024)
Changes from all commits
3 changes: 3 additions & 0 deletions src/sagemaker/serve/builder/model_builder.py
@@ -669,6 +669,9 @@ def _initialize_for_mlflow(self) -> None:
        mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
        if not _mlflow_input_is_local_path(mlflow_path):
            # TODO: extend to package arn, run id and etc.
            logger.info(
                "Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
            )
            _download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
        else:
            _copy_directory_contents(mlflow_path, self.model_path)
6 changes: 3 additions & 3 deletions src/sagemaker/serve/model_format/mlflow/constants.py
@@ -19,12 +19,12 @@
"py39": "1.13.1",
"py310": "2.2.0",
}
MODEL_PACAKGE_ARN_REGEX = (
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
MODEL_PACKAGE_ARN_REGEX = (
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
)
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
MLFLOW_METADATA_FILE = "MLmodel"
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
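The two pattern changes can be sanity-checked in isolation: the model-package pattern now captures the package name and an optional numeric version as groups, and the S3 pattern now accepts bucket-only URIs. A quick sketch (not part of the diff), using only Python's re module and made-up ARN/bucket values:

import re

MODEL_PACKAGE_ARN_REGEX = (
    r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
)
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"

# Versioned and unversioned model-package ARNs both match; name and version are captured.
m = re.match(MODEL_PACKAGE_ARN_REGEX,
             "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-pkg/3")
assert m and m.groups() == ("my-pkg", "3")
assert re.match(MODEL_PACKAGE_ARN_REGEX,
                "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-pkg")

# A bucket-only S3 URI (no key, no trailing slash) is now accepted as well.
assert re.match(S3_PATH_REGEX, "s3://my-bucket")
assert re.match(S3_PATH_REGEX, "s3://my-bucket/mlflow-model/")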
11 changes: 10 additions & 1 deletion src/sagemaker/serve/model_format/mlflow/utils.py
@@ -278,7 +278,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non
        os.makedirs(local_file_dir, exist_ok=True)

        # Download the file
-        print(f"Downloading {key} to {local_file_path}")
+        logger.info(f"Downloading {key} to {local_file_path}")
        s3.download_file(s3_bucket, key, local_file_path)


@@ -356,6 +356,15 @@ def _select_container_for_mlflow_model(
logger.info("Auto-detected framework to use is %s", framework_to_use)
logger.info("Auto-detected framework version is %s", framework_version)

if framework_version is None:
raise ValueError(
(
"Unable to auto detect framework version. Please provide framework %s as part of the "
"requirements.txt file for deployment flavor %s"
)
% (framework_to_use, deployment_flavor)
)

casted_versions = (
_cast_to_compatible_version(framework_to_use, framework_version)
if framework_version
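The new guard turns a missing framework pin into an immediate, explicit failure instead of letting container selection continue with framework_version=None. A minimal sketch of how that surfaces to a caller (not part of the diff; the local path is made up, and the call signature follows the unit test added further down):

from sagemaker.serve.model_format.mlflow.utils import _select_container_for_mlflow_model

# Assumes /path/to/mlflow_model exists and its requirements.txt has no pin for the
# flavor's framework (e.g. no "torch==2.x" line for the pytorch flavor).
try:
    _select_container_for_mlflow_model(
        "/path/to/mlflow_model", "pytorch", "us-west-2", "ml.m5.xlarge"
    )
except ValueError as err:
    # "Unable to auto detect framework version. Please provide framework pytorch
    # as part of the requirements.txt file for deployment flavor pytorch"
    print(err)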
4 changes: 2 additions & 2 deletions src/sagemaker/serve/utils/lineage_utils.py
@@ -28,7 +28,7 @@
from sagemaker.lineage.query import LineageSourceEnum
from sagemaker.serve.model_format.mlflow.constants import (
    MLFLOW_RUN_ID_REGEX,
-    MODEL_PACAKGE_ARN_REGEX,
+    MODEL_PACKAGE_ARN_REGEX,
    S3_PATH_REGEX,
    MLFLOW_REGISTRY_PATH_REGEX,
)
@@ -107,7 +107,7 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
"""
mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX
mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
sagemaker_arn_pattern = MODEL_PACAKGE_ARN_REGEX
sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
s3_pattern = S3_PATH_REGEX

if re.match(mlflow_rub_id_pattern, mlflow_model_path):
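_get_mlflow_model_path_type is what both lineage tracking and the new telemetry code rely on to classify the MLflow model source. A short illustration (not part of the diff; all inputs are made up) of the classification driven by the regexes above:

from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type

for path in (
    "runs:/abc123/model",                                               # MLflow run
    "models:/my-model/1",                                               # MLflow registry
    "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-pkg/3",  # model package
    "s3://my-bucket/mlflow-model/",                                     # S3 artifacts
):
    print(path, "->", _get_mlflow_model_path_type(path))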
22 changes: 22 additions & 0 deletions src/sagemaker/serve/utils/telemetry_logger.py
@@ -19,7 +19,16 @@

from sagemaker import Session, exceptions
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
from sagemaker.serve.utils.exceptions import ModelBuilderException
from sagemaker.serve.utils.lineage_constants import (
    MLFLOW_LOCAL_PATH,
    MLFLOW_S3_PATH,
    MLFLOW_MODEL_PACKAGE_PATH,
    MLFLOW_RUN_ID,
    MLFLOW_REGISTRY_PATH,
)
from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
from sagemaker.user_agent import SDK_VERSION
@@ -51,6 +60,14 @@
    str(ModelServer.TGI): 6,
}

MLFLOW_MODEL_PATH_CODE = {
    MLFLOW_LOCAL_PATH: 1,
    MLFLOW_S3_PATH: 2,
    MLFLOW_MODEL_PACKAGE_PATH: 3,
    MLFLOW_RUN_ID: 4,
    MLFLOW_REGISTRY_PATH: 5,
}


def _capture_telemetry(func_name: str):
    """Placeholder docstring"""
@@ -78,6 +95,11 @@ def wrapper(self, *args, **kwargs):
        if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
            extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"

        if getattr(self, "_is_mlflow_model", False):
            mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH]
            mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path)
            extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}"

        start_timer = perf_counter()
        try:
            response = func(self, *args, **kwargs)
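Putting the pieces together, the decorator looks up the detected path type in MLFLOW_MODEL_PATH_CODE and appends it to the telemetry query string. A minimal sketch of that encoding (not part of the diff; the S3 URI is made up):

from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type
from sagemaker.serve.utils.telemetry_logger import MLFLOW_MODEL_PATH_CODE

path_type = _get_mlflow_model_path_type("s3://my-bucket/mlflow-model/")
extra = f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[path_type]}"
print(extra)  # "&x-mlflowModelPathType=2" for an S3-sourced MLflow model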
@@ -418,6 +418,61 @@ def test_select_container_for_mlflow_model_no_dlc_detected(
)


@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version")
@patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements")
@patch(
"sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file"
)
@patch("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata")
@patch("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path")
def test_select_container_for_mlflow_model_no_framework_version_detected(
mock_generate_mlflow_artifact_path,
mock_get_all_flavor_metadata,
mock_get_python_version_from_parsed_mlflow_model_file,
mock_get_framework_version_from_requirements,
mock_cast_to_compatible_version,
mock_image_uris_retrieve,
):
mlflow_model_src_path = "/path/to/mlflow_model"
deployment_flavor = "pytorch"
region = "us-west-2"
instance_type = "ml.m5.xlarge"

mock_requirements_path = "/path/to/requirements.txt"
mock_metadata_path = "/path/to/mlmodel"
mock_flavor_metadata = {"pytorch": {"some_key": "some_value"}}
mock_python_version = "3.8.6"

mock_generate_mlflow_artifact_path.side_effect = lambda path, artifact: (
mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path
)
mock_get_all_flavor_metadata.return_value = mock_flavor_metadata
mock_get_python_version_from_parsed_mlflow_model_file.return_value = mock_python_version
mock_get_framework_version_from_requirements.return_value = None

with pytest.raises(
ValueError,
match="Unable to auto detect framework version. Please provide framework "
"pytorch as part of the requirements.txt file for deployment flavor "
"pytorch",
):
_select_container_for_mlflow_model(
mlflow_model_src_path, deployment_flavor, region, instance_type
)

mock_generate_mlflow_artifact_path.assert_any_call(
mlflow_model_src_path, "requirements.txt"
)
mock_generate_mlflow_artifact_path.assert_any_call(mlflow_model_src_path, "MLmodel")
mock_get_all_flavor_metadata.assert_called_once_with(mock_metadata_path)
mock_get_framework_version_from_requirements.assert_called_once_with(
deployment_flavor, mock_requirements_path
)
mock_cast_to_compatible_version.assert_not_called()
mock_image_uris_retrieve.assert_not_called()


def test_validate_input_for_mlflow():
    _validate_input_for_mlflow(ModelServer.TORCHSERVE, "pytorch")

36 changes: 36 additions & 0 deletions tests/unit/sagemaker/serve/utils/test_telemetry_logger.py
@@ -14,6 +14,7 @@
import unittest
from unittest.mock import Mock, patch
from sagemaker.serve import Mode, ModelServer
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
from sagemaker.serve.utils.telemetry_logger import (
    _send_telemetry,
    _capture_telemetry,
@@ -32,9 +33,13 @@
"763104351884.dkr.ecr.us-east-1.amazonaws.com/"
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
)
MOCK_PYTORCH_CONTAINER = (
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310"
)
MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf"
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test"
MOCK_MODEL_METADATA_FOR_MLFLOW = {MLFLOW_MODEL_PATH: "s3://some_path"}


class ModelBuilderMock:
@@ -239,3 +244,34 @@ def test_construct_url_with_failure_reason_and_extra_info(self):
f"&x-extra={mock_extra_info}"
)
self.assertEquals(ret_url, expected_base_url)

@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
mock_model_builder = ModelBuilderMock()
mock_model_builder.serve_settings.telemetry_opt_out = False
mock_model_builder.image_uri = MOCK_PYTORCH_CONTAINER
mock_model_builder._is_mlflow_model = True
mock_model_builder.model_metadata = MOCK_MODEL_METADATA_FOR_MLFLOW
mock_model_builder._is_custom_image_uri = False
mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
mock_model_builder.model_server = ModelServer.TORCHSERVE
mock_model_builder.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN

mock_model_builder.mock_deploy()

args = mock_send_telemetry.call_args.args
latency = str(args[5]).split("latency=")[1]
expected_extra_str = (
f"{MOCK_FUNC_NAME}"
"&x-modelServer=1"
"&x-imageTag=pytorch-inference:2.0.1-cpu-py310"
f"&x-sdkVersion={SDK_VERSION}"
f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
f"&x-mlflowModelPathType=2"
f"&x-latency={latency}"
)

mock_send_telemetry.assert_called_once_with(
"1", 3, MOCK_SESSION, None, None, expected_extra_str
)