
Prevent threads from being stuck in DynamicBatcher #1915


Merged
11 commits merged on Mar 5, 2021
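For context, a minimal usage sketch of how concurrent request threads drive the batcher (not part of this PR; the EchoPredictor class and the import path are assumptions for illustration). Each caller blocks inside predict() until the batch engine has produced its result:

import threading
from cortex_internal.lib.api.batching import DynamicBatcher  # assumed import path

class EchoPredictor:
    # hypothetical predictor; with server-side batching it receives a list of samples
    def predict(self, payload):
        return [p * 2 for p in payload]

batcher = DynamicBatcher(EchoPredictor(), max_batch_size=4, batch_interval=0.1)

def handle_request(value):
    # each request thread blocks here until its own prediction is available
    print(batcher.predict(payload=value))

threads = [threading.Thread(target=handle_request, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# note: the internal batch engine thread runs forever, so this script will not exit on its own

With eight request threads and max_batch_size=4, the predictor receives lists of up to four payloads at a time.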
39 changes: 18 additions & 21 deletions pkg/cortex/serve/cortex_internal/lib/api/batching.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import threading as td
 import time
 import traceback
@@ -32,16 +33,14 @@ def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval
         self.batch_max_size = max_batch_size
         self.batch_interval = batch_interval  # measured in seconds
 
-        # waiter prevents new threads from modifying the input batch while a batch prediction is in progress
-        self.waiter = td.Event()
-        self.waiter.set()
-
-        self.barrier = td.Barrier(self.batch_max_size + 1, action=self.waiter.clear)
+        self.barrier = td.Barrier(self.batch_max_size + 1)
 
         self.samples = {}
         self.predictions = {}
         td.Thread(target=self._batch_engine).start()
 
+        self.thread_id_generator = itertools.count()
+
     def _batch_engine(self):
         while True:
             if len(self.predictions) > 0:
@@ -54,42 +53,40 @@ def _batch_engine(self):
                 pass
 
             self.predictions = {}
+            thread_ids = list(self.samples.keys())
 
             try:
                 if self.samples:
-                    batch = self._make_batch(self.samples)
+                    batch = self._make_batch(thread_ids)
 
                     predictions = self.predictor_impl.predict(**batch)
                     if not isinstance(predictions, list):
                         raise UserRuntimeException(
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
-                    self.predictions = dict(zip(self.samples.keys(), predictions))
+                    self.predictions = dict(zip(thread_ids, predictions))
             except Exception as e:
-                self.predictions = {thread_id: e for thread_id in self.samples}
+                self.predictions = {thread_id: e for thread_id in thread_ids}
                 logger.error(traceback.format_exc())
             finally:
-                self.samples = {}
+                for thread_id in thread_ids:
+                    del self.samples[thread_id]
                 self.barrier.reset()
-                self.waiter.set()
 
-    @staticmethod
-    def _make_batch(samples: Dict[int, Dict[str, Any]]) -> Dict[str, List[Any]]:
+    def _make_batch(self, thread_ids: List[int]) -> Dict[str, List[Any]]:
         batched_samples = defaultdict(list)
-        for thread_id in samples:
-            for key, sample in samples[thread_id].items():
+        for thread_id in thread_ids:
+            for key, sample in self.samples[thread_id].items():
                 batched_samples[key].append(sample)
 
         return dict(batched_samples)
 
-    def _enqueue_request(self, **kwargs):
+    def _enqueue_request(self, thread_id: int, **kwargs):
         """
         Enqueue sample for batch inference. This is a blocking method.
         """
-        thread_id = td.get_ident()
 
-        self.waiter.wait()
         self.samples[thread_id] = kwargs
         try:
             self.barrier.wait()
@@ -101,15 +98,15 @@ def predict(self, **kwargs):
         Queues a request to be batched with other incoming request, waits for the response
         and returns the prediction result. This is a blocking method.
         """
-        self._enqueue_request(**kwargs)
-        prediction = self._get_prediction()
+        thread_id = next(self.thread_id_generator)
+        self._enqueue_request(thread_id, **kwargs)
+        prediction = self._get_prediction(thread_id)
         return prediction
 
-    def _get_prediction(self) -> Any:
+    def _get_prediction(self, thread_id: int) -> Any:
         """
         Return the prediction. This is a blocking method.
         """
-        thread_id = td.get_ident()
         while thread_id not in self.predictions:
             time.sleep(0.001)
 
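The core of the change visible above: the waiter Event is removed, each call to predict() now draws a unique request ID from a shared itertools.count() instead of keying samples and predictions by td.get_ident(), and the batch engine deletes only the IDs it actually batched from self.samples rather than clearing the whole dict. One property of the counter approach, sketched below (not part of the PR): CPython may recycle a thread's identifier after the thread exits, whereas the counter never hands out the same ID twice.

import itertools
import threading as td

id_generator = itertools.count()
seen_idents = []

def record_ident():
    seen_idents.append(td.get_ident())

for _ in range(5):
    t = td.Thread(target=record_ident)
    t.start()
    t.join()  # once a thread exits, its ident may be reused by a later thread

print("thread idents:", seen_idents)                           # duplicates are possible
print("counter ids:", [next(id_generator) for _ in range(5)])  # 0, 1, 2, 3, 4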