Skip to content

Commit 74bc3e1

Browse files
committed
Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (aws#4662)
* Initial commit for tensorflow_serving support of MLflow
* Add integ tests for mlflow tf_serving
* Fix style issues
* Remove unused attributes from tf builder
* Add deep ping for tf_serving local mode
* Initial commit for lineage implementation
* Add integ tests and unit tests
* Fix local mode for tf_serving
* Allow lineage tracking only in SageMaker endpoint mode
* Fix regex pattern and hard-coded py version in unit test
* Fix missing session
* Resolve PR comments and fix regex for mlflow registry and ids
1 parent 9b85ab6 commit 74bc3e1

33 files changed

+2281
-32
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ onnx>=1.15.0
3636
nbformat>=5.9,<6
3737
accelerate>=0.24.1,<=0.27.0
3838
schema==0.7.5
39+
tensorflow>=2.1,<=2.16

src/sagemaker/serve/builder/model_builder.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
3030
from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
3131
from sagemaker.serve.builder.schema_builder import SchemaBuilder
32+
from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
3233
from sagemaker.serve.mode.function_pointers import Mode
3334
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
3435
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
@@ -60,6 +61,7 @@
6061
from sagemaker.serve.spec.inference_spec import InferenceSpec
6162
from sagemaker.serve.utils import task
6263
from sagemaker.serve.utils.exceptions import TaskNotFoundException
64+
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
6365
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
6466
from sagemaker.serve.utils.hardware_detector import (
6567
_get_gpu_info,
@@ -91,12 +93,13 @@
9193
ModelServer.TRITON,
9294
ModelServer.DJL_SERVING,
9395
ModelServer.FASTAPI,
96+
ModelServer.TENSORFLOW_SERVING,
9497
}
9598

9699

97-
# pylint: disable=attribute-defined-outside-init, disable=E1101
100+
# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
98101
@dataclass
99-
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, FastAPIServe):
102+
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, FastAPIServe):
100103
"""Class that builds a deployable model.
101104
102105
Args:
@@ -495,6 +498,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
495498
self.pysdk_model.model_package_arn = new_model_package.model_package_arn
496499
new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
497500
self.model_package = new_model_package
501+
if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
502+
_maintain_lineage_tracking_for_mlflow_model(
503+
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
504+
s3_upload_path=self.s3_upload_path,
505+
sagemaker_session=self.sagemaker_session,
506+
)
498507
return new_model_package
499508

500509
def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
@@ -553,12 +562,19 @@ def _model_builder_deploy_wrapper(
553562

554563
if "endpoint_logging" not in kwargs:
555564
kwargs["endpoint_logging"] = True
556-
return self._original_deploy(
565+
predictor = self._original_deploy(
557566
*args,
558567
instance_type=instance_type,
559568
initial_instance_count=initial_instance_count,
560569
**kwargs,
561570
)
571+
if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
572+
_maintain_lineage_tracking_for_mlflow_model(
573+
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
574+
s3_upload_path=self.s3_upload_path,
575+
sagemaker_session=self.sagemaker_session,
576+
)
577+
return predictor
562578

563579
def _overwrite_mode_in_deploy(self, overwrite_mode: str):
564580
"""Mode overwritten by customer during model.deploy()"""
@@ -730,7 +746,7 @@ def build( # pylint: disable=R0911
730746
" for production at this moment."
731747
)
732748
self._initialize_for_mlflow()
733-
_validate_input_for_mlflow(self.model_server)
749+
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
734750

735751
if isinstance(self.model, str):
736752
model_task = None
@@ -775,6 +791,9 @@ def build( # pylint: disable=R0911
775791
if self.model_server == ModelServer.TRITON:
776792
return self._build_for_triton()
777793

794+
if self.model_server == ModelServer.TENSORFLOW_SERVING:
795+
return self._build_for_tensorflow_serving()
796+
778797
raise ValueError("%s model server is not supported" % self.model_server)
779798

780799
def save(
src/sagemaker/serve/builder/tf_serving_builder.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds mixin logic to support deployment of Model ID"""
14+
from __future__ import absolute_import
15+
import logging
16+
import os
17+
from pathlib import Path
18+
from abc import ABC, abstractmethod
19+
20+
from sagemaker import Session
21+
from sagemaker.serve.detector.pickler import save_pkl
22+
from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
23+
from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor
24+
25+
logger = logging.getLogger(__name__)
26+
27+
_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
28+
_CODE_FOLDER = "code"
29+
30+
31+
# pylint: disable=attribute-defined-outside-init, disable=E1101
class TensorflowServing(ABC):
    """TensorflowServing build logic for ModelBuilder().

    Mixin that lets ModelBuilder build and deploy models through SageMaker's
    TensorFlow Serving container. Currently only MLflow-sourced models are
    supported (enforced by ``_validate_for_tensorflow_serving``).
    """

    def __init__(self):
        # State shared with ModelBuilder (which mixes this class in); the
        # values are populated by ModelBuilder before the build entry point
        # (_build_for_tensorflow_serving) runs.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.secret_key = None
        self.engine = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare model artifacts based on mode."""

    @abstractmethod
    def _get_client_translators(self):
        """Set up client marshaller based on schema builder."""

    def _save_schema_builder(self):
        """Pickle the schema builder under ``<model_path>/code``.

        The pickled schema builder is what the serving-side code later loads
        to (de)serialize request/response payloads.
        """
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() guard.
        os.makedirs(self.model_path, exist_ok=True)

        code_path = Path(self.model_path).joinpath("code")
        save_pkl(code_path, self.schema_builder)

    def _get_tensorflow_predictor(
        self, endpoint_name: str, sagemaker_session: Session
    ) -> TensorFlowPredictor:
        """Create a TensorFlowPredictor wired to this builder's marshallers.

        NOTE: passed as ``predictor_cls`` to TensorFlowModel, so the
        (endpoint_name, sagemaker_session) signature must not change.
        """
        serializer, deserializer = self._get_client_translators()

        return TensorFlowPredictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )

    def _validate_for_tensorflow_serving(self):
        """Raise ValueError unless the model being built came from MLflow."""
        if not getattr(self, "_is_mlflow_model", False):
            raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")

    def _create_tensorflow_model(self):
        """Create the TensorFlowModel and hook deploy()/register() wrappers.

        Returns:
            TensorFlowModel: model whose ``deploy`` and ``register`` have been
            replaced by ModelBuilder wrappers so that mode handling (and, for
            MLflow models, lineage tracking) runs around the originals.
        """
        self.pysdk_model = TensorFlowModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            model_data=self.s3_upload_path,
            role=self.serve_settings.role_arn,
            env=self.env_vars,
            sagemaker_session=self.sagemaker_session,
            predictor_cls=self._get_tensorflow_predictor,
        )

        self.pysdk_model.mode = self.mode
        self.pysdk_model.modes = self.modes
        self.pysdk_model.serve_settings = self.serve_settings

        # Keep references to the original methods so the wrappers can
        # delegate to them.
        self._original_deploy = self.pysdk_model.deploy
        self.pysdk_model.deploy = self._model_builder_deploy_wrapper
        self._original_register = self.pysdk_model.register
        self.pysdk_model.register = self._model_builder_register_wrapper
        self.model_package = None
        return self.pysdk_model

    def _build_for_tensorflow_serving(self):
        """Build the model for Tensorflow Serving.

        Returns:
            TensorFlowModel: deployable model object.

        Raises:
            ValueError: if the model is not an MLflow model, or if no
                image_uri was resolved for the TensorFlow Serving container.
        """
        self._validate_for_tensorflow_serving()
        self._save_schema_builder()

        if not self.image_uri:
            raise ValueError("image_uri is not set for tensorflow serving")

        self.secret_key = prepare_for_tf_serving(
            model_path=self.model_path,
            shared_libs=self.shared_libs,
            dependencies=self.dependencies,
        )

        self._prepare_for_mode()

        return self._create_tensorflow_model()

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import docker
1212

1313
from sagemaker.base_predictor import PredictorBase
14+
from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
1415
from sagemaker.serve.spec.inference_spec import InferenceSpec
1516
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1617
from sagemaker.serve.utils.logging_agent import pull_logs
@@ -34,7 +35,18 @@
3435
)
3536

3637

38+
<<<<<<< HEAD
3739
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalFastApi, LocalMultiModelServer):
40+
=======
41+
class LocalContainerMode(
42+
LocalTorchServe,
43+
LocalDJLServing,
44+
LocalTritonServer,
45+
LocalTgiServing,
46+
LocalMultiModelServer,
47+
LocalTensorflowServing,
48+
):
49+
>>>>>>> a5c6229b0 (Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (#4662))
3850
"""A class that holds methods to deploy model to a container in local environment"""
3951

4052
def __init__(
@@ -140,6 +152,15 @@ def create_server(
140152
env_vars=env_vars if env_vars else self.env_vars,
141153
)
142154
self._ping_container = self._multi_model_server_deep_ping
155+
elif self.model_server == ModelServer.TENSORFLOW_SERVING:
156+
self._start_tensorflow_serving(
157+
client=self.client,
158+
image=image,
159+
model_path=model_path if model_path else self.model_path,
160+
secret_key=secret_key,
161+
env_vars=env_vars if env_vars else self.env_vars,
162+
)
163+
self._ping_container = self._tensorflow_serving_deep_ping
143164
elif self.model_server == ModelServer.FASTAPI:
144165
self._start_fast_api(
145166
client=self.client,
@@ -150,6 +171,7 @@ def create_server(
150171
)
151172
self._ping_container = self._fastapi_deep_ping
152173

174+
153175
# allow some time for container to be ready
154176
time.sleep(10)
155177

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
from typing import Type
88

9+
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
910
from sagemaker.session import Session
1011
from sagemaker.serve.utils.types import ModelServer
1112
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -26,6 +27,7 @@ class SageMakerEndpointMode(
2627
SageMakerTgiServing,
2728
SageMakerMultiModelServer,
2829
SageMakerFastApi,
30+
SageMakerTensorflowServing,
2931
):
3032
"""Holds the required method to deploy a model to a SageMaker Endpoint"""
3133

@@ -118,4 +120,13 @@ def prepare(
118120
image=image,
119121
)
120122

123+
if self.model_server == ModelServer.TENSORFLOW_SERVING:
124+
return self._upload_tensorflow_serving_artifacts(
125+
model_path=model_path,
126+
sagemaker_session=sagemaker_session,
127+
secret_key=secret_key,
128+
s3_model_data_url=s3_model_data_url,
129+
image=image,
130+
)
131+
121132
raise ValueError("%s model server is not supported" % self.model_server)

src/sagemaker/serve/model_format/mlflow/constants.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
"py39": "1.13.1",
2020
"py310": "2.2.0",
2121
}
22+
MODEL_PACAKGE_ARN_REGEX = (
23+
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
24+
)
25+
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
26+
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
27+
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
2228
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
2329
MLFLOW_METADATA_FILE = "MLmodel"
2430
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
@@ -34,8 +40,12 @@
3440
"spark": "pyspark",
3541
"onnx": "onnxruntime",
3642
}
37-
FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [ # will extend to keras and tf
38-
"sklearn",
39-
"pytorch",
40-
"xgboost",
41-
]
43+
TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"
44+
FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = {
45+
"sklearn": "sklearn",
46+
"pytorch": "pytorch",
47+
"xgboost": "xgboost",
48+
"tensorflow": "tensorflow",
49+
"keras": "tensorflow",
50+
}
51+
FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"]

0 commit comments

Comments
 (0)