|
| 1 | +"""Module for Local TEI Serving""" |
| 2 | + |
| 3 | +from __future__ import absolute_import |
| 4 | + |
| 5 | +import requests |
| 6 | +import logging |
| 7 | +from pathlib import Path |
| 8 | +from docker.types import DeviceRequest |
| 9 | +from sagemaker import Session, fw_utils |
| 10 | +from sagemaker.serve.utils.exceptions import LocalModelInvocationException |
| 11 | +from sagemaker.base_predictor import PredictorBase |
| 12 | +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join |
| 13 | +from sagemaker.s3 import S3Uploader |
| 14 | +from sagemaker.local.utils import get_docker_host |
| 15 | + |
| 16 | + |
| 17 | +MODE_DIR_BINDING = "/opt/ml/model/" |
| 18 | +_SHM_SIZE = "2G" |
| 19 | +_DEFAULT_ENV_VARS = { |
| 20 | + "TRANSFORMERS_CACHE": "/opt/ml/model/", |
| 21 | + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", |
| 22 | +} |
| 23 | + |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
class LocalTeiServing:
    """Run and health-check a Text Embeddings Inference (TEI) container locally."""

    def _start_tei_serving(
        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
    ):
        """Starts a local tei serving container.

        Args:
            client: Docker client
            image: Image to use
            model_path: Path to the model; its ``code`` subdirectory is
                mounted at the container's model directory
            secret_key: Secret key to use for authentication; injected as
                ``SAGEMAKER_SERVE_SECRET_KEY``
            env_vars: Environment variables to set
        """
        # Copy before mutating so the caller's dict is untouched. Previously
        # the secret key was silently dropped when env_vars was None/empty.
        env_vars = dict(env_vars) if env_vars else {}
        if secret_key:
            env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

        self.container = client.containers.run(
            image,
            shm_size=_SHM_SIZE,
            # count=-1 requests all available GPUs on the host.
            device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
            network_mode="host",
            detach=True,
            auto_remove=True,
            volumes={
                Path(model_path).joinpath("code"): {
                    "bind": MODE_DIR_BINDING,
                    "mode": "rw",
                },
            },
            environment=_update_env_vars(env_vars),
        )

    def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
        """Invokes a local tei serving container.

        Args:
            request: Request to send
            content_type: Content type to use
            accept: Accept to use

        Returns:
            bytes: Raw response body from the container.

        Raises:
            Exception: when the request fails or the container returns an
                HTTP error status (original error kept as ``__cause__``).
        """
        try:
            response = requests.post(
                f"http://{get_docker_host()}:8080/invocations",
                data=request,
                headers={"Content-Type": content_type, "Accept": accept},
                timeout=600,
            )
            response.raise_for_status()
            return response.content
        except Exception as e:
            raise Exception("Unable to send request to the local container server") from e

    def _tei_deep_ping(self, predictor: PredictorBase):
        """Checks if the local tei serving container is up and running.

        If the container is not up and running, it will raise an exception.

        Returns:
            tuple: ``(True, response)`` when the sample invocation succeeds,
            ``(False, None)`` when the container is unreachable/not ready.

        Raises:
            LocalModelInvocationException: when the container is up but
                rejects the sample payload (HTTP 422 Unprocessable Entity).
        """
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
            return (True, response)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                # Chain the original error for easier debugging.
                raise LocalModelInvocationException(str(e)) from e
            return (False, response)
        # NOTE: removed an unreachable trailing `return (True, response)` —
        # both the try and except arms already return.
| 97 | + |
| 98 | + |
class SageMakerTeiServing:
    """Prepare TEI model artifacts for SageMaker endpoint deployment."""

    def _upload_tei_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        s3_model_data_url: str = None,
        image: str = None,
        env_vars: dict = None,
    ):
        """Uploads the model artifacts to S3.

        The model's ``code`` subdirectory is uploaded uncompressed under
        ``s3://<bucket>/<prefix>/code``.

        Args:
            model_path: Path to the model
            sagemaker_session: SageMaker session
            s3_model_data_url: S3 model data URL (optional bucket/prefix override)
            image: Image to use (used to derive the code key prefix)
            env_vars: Environment variables to set

        Returns:
            tuple: (``model_data`` dict in ``S3DataSource`` form, merged
            environment-variable dict).
        """
        if s3_model_data_url:
            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
        else:
            bucket, key_prefix = None, None

        code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

        # Fall back to the session's default bucket when none was provided.
        bucket, code_key_prefix = determine_bucket_and_prefix(
            bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
        )

        code_dir = Path(model_path).joinpath("code")

        s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")

        # Fixed copy-paste: this module serves TEI, not TGI.
        logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)

        model_data_url = S3Uploader.upload(
            str(code_dir),
            s3_location,
            None,
            sagemaker_session,
        )

        # Trailing "/" marks an uncompressed S3 prefix for S3DataSource
        # with CompressionType "None".
        model_data = {
            "S3DataSource": {
                "CompressionType": "None",
                "S3DataType": "S3Prefix",
                "S3Uri": model_data_url + "/",
            }
        }

        return (model_data, _update_env_vars(env_vars))
| 152 | + |
| 153 | + |
def _update_env_vars(env_vars: dict) -> dict:
    """Merge *env_vars* over the default TEI environment variables.

    Defaults come from ``_DEFAULT_ENV_VARS``; any key present in
    *env_vars* overrides the default. Neither input mapping is mutated,
    and a ``None``/empty *env_vars* yields just the defaults.
    """
    return {**_DEFAULT_ENV_VARS, **(env_vars or {})}