
Prevent threads from being stuck in DynamicBatcher #1915

Merged
merged 11 commits on Mar 5, 2021
4 changes: 4 additions & 0 deletions images/test/Dockerfile
@@ -10,6 +10,10 @@ RUN pip install --upgrade pip && \
COPY pkg /src
COPY images/test/run.sh /src/run.sh

COPY pkg/cortex/serve/log_config.yaml /src/cortex/serve/log_config.yaml
ENV CORTEX_LOG_LEVEL DEBUG
ENV CORTEX_LOG_CONFIG_FILE /src/cortex/serve/log_config.yaml

RUN pip install --no-deps /src/cortex/serve/ && \
rm -rf /root/.cache/pip*

6 changes: 6 additions & 0 deletions images/test/run.sh
@@ -18,6 +18,12 @@
err=0
trap 'err=1' ERR

function substitute_env_vars() {
file_to_run_substitution=$1
python -c "from cortex_internal.lib import util; import os; util.expand_environment_vars_on_file('$file_to_run_substitution')"
}

substitute_env_vars $CORTEX_LOG_CONFIG_FILE
pytest lib/test

test $err = 0
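
For context, the substitute_env_vars helper above delegates to util.expand_environment_vars_on_file from cortex_internal, whose implementation is not shown in this diff. The following is only a minimal sketch, assuming the helper expands $VAR / ${VAR} references in the file in place:

# hypothetical sketch -- not the cortex_internal implementation
import os

def expand_environment_vars_on_file(path: str) -> None:
    with open(path, "r") as f:
        contents = f.read()
    # os.path.expandvars substitutes $VAR and ${VAR} from the environment,
    # leaving unknown variables untouched
    with open(path, "w") as f:
        f.write(os.path.expandvars(contents))
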
1 change: 1 addition & 0 deletions pkg/cortex/serve/cortex_internal.requirements.txt
@@ -1,3 +1,4 @@
grpcio==1.32.0
boto3==1.14.53
google-cloud-storage==1.32.0
datadog==0.39.0
72 changes: 42 additions & 30 deletions pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import threading as td
import time
import traceback
@@ -21,26 +22,32 @@

from starlette.responses import Response

from ..exceptions import UserRuntimeException
from ..log import logger
from cortex_internal.lib.exceptions import UserRuntimeException
from cortex_internal.lib.log import logger


class DynamicBatcher:
def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval: int):
def __init__(
self,
predictor_impl: Callable,
max_batch_size: int,
batch_interval: int,
test_mode: bool = False,
):
self.predictor_impl = predictor_impl

self.batch_max_size = max_batch_size
self.batch_interval = batch_interval # measured in seconds
self.test_mode = test_mode # only for unit testing
self._test_batch_lengths = [] # only when unit testing

# waiter prevents new threads from modifying the input batch while a batch prediction is in progress
self.waiter = td.Event()
self.waiter.set()

self.barrier = td.Barrier(self.batch_max_size + 1, action=self.waiter.clear)
self.barrier = td.Barrier(self.batch_max_size + 1)

self.samples = {}
self.predictions = {}
td.Thread(target=self._batch_engine).start()
td.Thread(target=self._batch_engine, daemon=True).start()

self.sample_id_generator = itertools.count()

def _batch_engine(self):
while True:
@@ -54,43 +61,48 @@ def _batch_engine(self):
pass

self.predictions = {}

sample_ids = self._get_sample_ids(self.batch_max_size)
try:
if self.samples:
batch = self._make_batch(self.samples)
batch = self._make_batch(sample_ids)

predictions = self.predictor_impl.predict(**batch)
if not isinstance(predictions, list):
raise UserRuntimeException(
f"please return a list when using server side batching, got {type(predictions)}"
)

self.predictions = dict(zip(self.samples.keys(), predictions))
if self.test_mode:
self._test_batch_lengths.append(len(predictions))

self.predictions = dict(zip(sample_ids, predictions))
except Exception as e:
self.predictions = {thread_id: e for thread_id in self.samples}
self.predictions = {sample_id: e for sample_id in sample_ids}
logger.error(traceback.format_exc())
finally:
self.samples = {}
for sample_id in sample_ids:
del self.samples[sample_id]
self.barrier.reset()
self.waiter.set()

@staticmethod
def _make_batch(samples: Dict[int, Dict[str, Any]]) -> Dict[str, List[Any]]:
def _get_sample_ids(self, max_number: int) -> List[int]:
if len(self.samples) <= max_number:
return list(self.samples.keys())
return sorted(self.samples)[:max_number]

def _make_batch(self, sample_ids: List[int]) -> Dict[str, List[Any]]:
batched_samples = defaultdict(list)
for thread_id in samples:
for key, sample in samples[thread_id].items():
for sample_id in sample_ids:
for key, sample in self.samples[sample_id].items():
batched_samples[key].append(sample)

return dict(batched_samples)

def _enqueue_request(self, **kwargs):
def _enqueue_request(self, sample_id: int, **kwargs):
"""
Enqueue sample for batch inference. This is a blocking method.
"""
thread_id = td.get_ident()

self.waiter.wait()
self.samples[thread_id] = kwargs
self.samples[sample_id] = kwargs
try:
self.barrier.wait()
except td.BrokenBarrierError:
@@ -101,20 +113,20 @@ def predict(self, **kwargs):
Queues a request to be batched with other incoming requests, waits for the response
and returns the prediction result. This is a blocking method.
"""
self._enqueue_request(**kwargs)
prediction = self._get_prediction()
sample_id = next(self.sample_id_generator)
self._enqueue_request(sample_id, **kwargs)
prediction = self._get_prediction(sample_id)
return prediction

def _get_prediction(self) -> Any:
def _get_prediction(self, sample_id: int) -> Any:
"""
Return the prediction. This is a blocking method.
"""
thread_id = td.get_ident()
while thread_id not in self.predictions:
while sample_id not in self.predictions:
time.sleep(0.001)

prediction = self.predictions[thread_id]
del self.predictions[thread_id]
prediction = self.predictions[sample_id]
del self.predictions[sample_id]

if isinstance(prediction, Exception):
return Response(
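For reference, a minimal usage sketch of the updated DynamicBatcher based on the constructor and predict() signatures shown above; EchoPredictor and its payload argument are hypothetical stand-ins:

# hypothetical usage sketch; not part of this diff
from cortex_internal.lib.api.batching import DynamicBatcher

class EchoPredictor:
    # any object exposing predict() works; it must return a list with one
    # prediction per sample in the batch
    def predict(self, payload):
        return list(payload)

batcher = DynamicBatcher(EchoPredictor(), max_batch_size=8, batch_interval=0.1)

# each request-handling thread calls predict(); the call blocks until the
# batch containing this sample has been processed and its result is available
result = batcher.predict(payload="hello")
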
112 changes: 112 additions & 0 deletions pkg/cortex/serve/cortex_internal/lib/test/dynamic_batching_test.py
@@ -0,0 +1,112 @@
# Copyright 2021 Cortex Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import threading as td
import itertools
import time

import cortex_internal.lib.api.batching as batching


class Predictor:
def predict(self, payload):
time.sleep(0.2)
return payload


def test_dynamic_batching_while_hitting_max_batch_size():
max_batch_size = 32
dynamic_batcher = batching.DynamicBatcher(
Predictor(), max_batch_size=max_batch_size, batch_interval=0.1, test_mode=True
)
counter = itertools.count(1)
event = td.Event()
global_list = []

def submitter():
while not event.is_set():
global_list.append(dynamic_batcher.predict(payload=next(counter)))
time.sleep(0.1)

running_threads = []
for _ in range(128):
thread = td.Thread(target=submitter, daemon=True)
thread.start()
running_threads.append(thread)

time.sleep(60)
event.set()

# if this fails, then the submitter threads are getting stuck
for thread in running_threads:
thread.join(3.0)
if thread.is_alive():
raise TimeoutError(f"thread {thread.name} got stuck")

sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
sum2 = sum(global_list)
assert sum1 == sum2

# get the last 80% of batch lengths
# we ignore the first 20% because it may take some time for all threads to start making requests
batch_lengths = dynamic_batcher._test_batch_lengths
batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]

# verify that the batch size is always equal to the max batch size
assert len(set(batch_lengths)) == 1
assert max_batch_size in batch_lengths


def test_dynamic_batching_while_hitting_max_interval():
max_batch_size = 32
dynamic_batcher = batching.DynamicBatcher(
Predictor(), max_batch_size=max_batch_size, batch_interval=1.0, test_mode=True
)
counter = itertools.count(1)
event = td.Event()
global_list = []

def submitter():
while not event.is_set():
global_list.append(dynamic_batcher.predict(payload=next(counter)))
time.sleep(0.1)

running_threads = []
for _ in range(2):
thread = td.Thread(target=submitter, daemon=True)
thread.start()
running_threads.append(thread)

time.sleep(30)
event.set()

# if this fails, then the submitter threads are getting stuck
for thread in running_threads:
thread.join(3.0)
if thread.is_alive():
raise TimeoutError(f"thread {thread.name} got stuck")

sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
sum2 = sum(global_list)
assert sum1 == sum2

# get the last 80% of batch lengths
# we ignore the first 20% because it may take some time for all threads to start making requests
batch_lengths = dynamic_batcher._test_batch_lengths
batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]

# verify that the batch size is always equal to the number of running threads
assert len(set(batch_lengths)) == 1
assert len(running_threads) in batch_lengths
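
As a closing note on the core change: samples and predictions are now keyed by a monotonically increasing sample id from itertools.count() instead of by td.get_ident(), so two in-flight requests can never collide on a key even when they come from the same worker thread. A minimal sketch of that id pattern, mirroring the code above (handle_request is a hypothetical caller):

# minimal sketch of the per-request id pattern used by DynamicBatcher
import itertools
import threading as td

sample_id_generator = itertools.count()

def handle_request(payload):
    # in CPython, next() on itertools.count is effectively atomic, so every
    # request gets a unique id even when many threads call this concurrently;
    # thread ids, by contrast, repeat across requests served by the same thread
    sample_id = next(sample_id_generator)
    return sample_id, td.get_ident()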