Skip to content

Commit f7aa6f8

Browse files
ulrichkrAshwin Krishna
authored andcommitted
feat: jumpstart llm load test benchmarking (#277)
* feat: initial jumpstart llm benchmarking commit * fix: throughput robustness and pricing * feat: jumpstart benchmarking, deploy endpoint and model specs * feat: jumpstart benchmarking generalize concurrency probe * fix: jumpstart latency benchmarking logging changes * fix: adjust throughput computations * fix: variety of cleanup to jumpstart llm benchmarking * chore: clean up notebooks * chore: black * chore: cleanup jumpstart inference benchmarking * fix: adjust error logging in concurrency probe * chore: concurrency probe finalization * chore: clean up notebooks * chore: add tranformers requirement install * chore: black * chore: grammar * chore: pip install change * chore: adjust load metrics for missing token statistics
1 parent a7cc547 commit f7aa6f8

File tree

11 files changed

+1414
-0
lines changed

11 files changed

+1414
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.json

introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-benchmarking/benchmarking/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
import boto3
5+
6+
7+
SERVICE_CODE = "AmazonSageMaker"
8+
PRODUCT_FAMILY = "ML Instance"
9+
PRODUCT_FAMILY_KEY = "productFamily"
10+
PRICING_SERVICE_API_REGION = "us-east-1" # All pricing APIs are hosted in IAD
11+
REGION_KEY = "regionCode"
12+
INSTANCE_NAME_KEY = "instanceName"
13+
PLATO_INSTANCE_TYPE_KEY = "platoinstancetype"
14+
PLATO_INSTANCE_TYPE = "Hosting"
15+
16+
17+
def _create_pricing_filter(type: str, field: str, value: str) -> Dict[str, str]:
18+
return {"Type": type, "Field": field, "Value": value}
19+
20+
21+
class PricingClient:
22+
"""Boto3 client to access AWS Pricing."""
23+
24+
def __init__(self) -> None:
25+
"""Creates the boto3 client for AWS pricing."""
26+
self._client = boto3.client(service_name="pricing", region_name=PRICING_SERVICE_API_REGION)
27+
28+
def get_price_per_unit(self, instance_type: str, region: str) -> float:
29+
"""Returns the price per unit in USD of a SageMaker machine learning instance in a region."""
30+
filters = [
31+
_create_pricing_filter(type="TERM_MATCH", field=PRODUCT_FAMILY_KEY, value=PRODUCT_FAMILY),
32+
_create_pricing_filter(type="TERM_MATCH", field=REGION_KEY, value=region),
33+
_create_pricing_filter(type="TERM_MATCH", field=INSTANCE_NAME_KEY, value=instance_type),
34+
_create_pricing_filter(
35+
type="TERM_MATCH",
36+
field=PLATO_INSTANCE_TYPE_KEY,
37+
value=PLATO_INSTANCE_TYPE,
38+
),
39+
]
40+
response = self._client.get_products(ServiceCode=SERVICE_CODE, Filters=filters)
41+
price_list = json.loads(response["PriceList"][0])["terms"]["OnDemand"]
42+
price_dimensions = list(price_list.values())[0]["priceDimensions"]
43+
price_per_unit = list(price_dimensions.values())[0]["pricePerUnit"]["USD"]
44+
return float(price_per_unit)
45+
46+
47+
class SageMakerClient:
48+
"""Boto3 SageMaker client to access endpoint and model information."""
49+
50+
def __init__(self) -> None:
51+
self._client = boto3.client("sagemaker")
52+
53+
def describe_endpoint_config(self, endpoint_config_name: str) -> Dict[str, Any]:
54+
return self._client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
55+
56+
def describe_endpoint(self, endpoint_name: str) -> Dict[str, Any]:
57+
return self._client.describe_endpoint(EndpointName=endpoint_name)
58+
59+
def describe_model(self, endpoint_name: str) -> Dict[str, Any]:
60+
endpoint_config = self.describe_endpoint_config(endpoint_name)
61+
model_name = endpoint_config["ProductionVariants"][0]["ModelName"]
62+
return self._client.describe_model(ModelName=model_name)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from abc import abstractmethod
2+
from typing import Any, Dict, Optional
3+
4+
from sagemaker.predictor import Predictor
5+
6+
7+
class ConcurrentProbeIteratorBase:
8+
def __init__(self, model_id: str, payload_name: str):
9+
self.model_id = model_id
10+
self.payload_name = payload_name
11+
self.exception: Optional[Exception] = None
12+
self.stop_reason: str = "No stop reason set."
13+
self.result: Dict[str, Any] = None
14+
15+
def __iter__(self) -> "ConcurrentProbeIteratorBase":
16+
return self
17+
18+
@abstractmethod
19+
def __next__(self) -> int:
20+
raise NotImplementedError
21+
22+
def send(self, result: Dict[str, Any], predictor: Predictor) -> bool:
23+
"""Send load test results to the iterator and return whether to use results.
24+
25+
Some iterators may make internal adjustments (e.g., scale endpoint instances and repeat load test for the same
26+
conccurent request setting) before using the results.
27+
"""
28+
self.result = result
29+
return True
30+
31+
32+
class ConcurrentProbeExponentialScalingIterator(ConcurrentProbeIteratorBase):
33+
"""An iterator used during a concurrency probe to exponentially scale concurrent requests."""
34+
35+
def __init__(self, model_id: str, payload_name: str, start: int = 1, scale_factor: float = 2.0) -> None:
36+
self.concurrent_requests = start
37+
self.scale_factor = scale_factor
38+
super().__init__(model_id, payload_name)
39+
40+
def __next__(self) -> int:
41+
if self.exception is not None:
42+
e = self.exception
43+
self.stop_reason = "".join([type(e).__name__, f": {e}" if str(e) else ""])
44+
raise StopIteration
45+
46+
if self.result is None:
47+
return self.concurrent_requests
48+
49+
self.concurrent_requests = int(self.concurrent_requests * self.scale_factor)
50+
51+
return self.concurrent_requests
52+
53+
54+
def num_invocation_scaler(concurrent_requests: int, num_invocation_factor: int = 3) -> int:
55+
return concurrent_requests * num_invocation_factor
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import boto3
2+
from botocore.config import Config
3+
from sagemaker.session import Session
4+
from pathlib import Path
5+
6+
7+
SAVE_METRICS_FILE_PATH = Path.cwd() / "latency_benchmarking.json"
8+
CLOUDWATCH_PERIOD_SECONDS = 60.0
9+
MAX_CONCURRENT_INVOCATIONS_PER_MODEL = 30
10+
MAX_CONCURRENT_BENCHMARKS = 20
11+
RETRY_WAIT_TIME_SECONDS = 30.0
12+
MAX_TOTAL_RETRY_TIME_SECONDS = 120.0
13+
NUM_INVOCATIONS = 10
14+
SM_INVOCATION_TIMEOUT_SECONDS = 60.0
15+
SM_SESSION = Session(
16+
sagemaker_runtime_client=boto3.client(
17+
"sagemaker-runtime",
18+
config=Config(connect_timeout=5, retries={"mode": "standard", "total_max_attempts": 10}),
19+
),
20+
sagemaker_client=boto3.client(
21+
"sagemaker",
22+
config=Config(connect_timeout=5, read_timeout=60, retries={"total_max_attempts": 20}),
23+
),
24+
)

0 commit comments

Comments
 (0)