Skip to content

Commit a4043be

Browse files
more robust usage tracker in async (#8329)
1 parent ef244c0 commit a4043be

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

dspy/primitives/program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import magicattr
44

5-
from dspy.dsp.utils.settings import settings
5+
from dspy.dsp.utils.settings import settings, thread_local_overrides
66
from dspy.predict.parallel import Parallel
77
from dspy.primitives.module import BaseModule
88
from dspy.utils.callback import with_callbacks
@@ -51,7 +51,7 @@ def __call__(self, *args, **kwargs):
5151
caller_modules.append(self)
5252

5353
with settings.context(caller_modules=caller_modules):
54-
if settings.track_usage and settings.usage_tracker is None:
54+
if settings.track_usage and thread_local_overrides.get().get("usage_tracker") is None:
5555
with track_usage() as usage_tracker:
5656
output = self.forward(*args, **kwargs)
5757
output.set_lm_usage(usage_tracker.get_total_tokens())
@@ -66,7 +66,7 @@ async def acall(self, *args, **kwargs):
6666
caller_modules.append(self)
6767

6868
with settings.context(caller_modules=caller_modules):
69-
if settings.track_usage and settings.usage_tracker is None:
69+
if settings.track_usage and thread_local_overrides.get().get("usage_tracker") is None:
7070
with track_usage() as usage_tracker:
7171
output = await self.aforward(*args, **kwargs)
7272
output.set_lm_usage(usage_tracker.get_total_tokens())

tests/primitives/test_module.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pytest
77
from litellm import Choices, Message, ModelResponse
8+
from litellm.types.utils import Usage
9+
import asyncio
810

911
import dspy
1012
from dspy.utils.dummies import DummyLM
@@ -307,6 +309,54 @@ def __call__(self, question: str) -> str:
307309
assert results[1].get_lm_usage().keys() == set(["openai/gpt-3.5-turbo"])
308310

309311

312+
@pytest.mark.asyncio
313+
async def test_usage_tracker_async_parallel():
314+
program = dspy.Predict("question -> answer")
315+
316+
with patch("litellm.acompletion") as mock_completion:
317+
mock_completion.return_value = ModelResponse(
318+
choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
319+
usage=Usage(
320+
**{
321+
"prompt_tokens": 1117,
322+
"completion_tokens": 46,
323+
"total_tokens": 1163,
324+
"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
325+
"completion_tokens_details": {
326+
"reasoning_tokens": 0,
327+
"audio_tokens": 0,
328+
"accepted_prediction_tokens": 0,
329+
"rejected_prediction_tokens": 0,
330+
},
331+
},
332+
),
333+
model="openai/gpt-4o-mini",
334+
)
335+
336+
coroutines = [
337+
program.acall(question="What is the capital of France?"),
338+
program.acall(question="What is the capital of France?"),
339+
program.acall(question="What is the capital of France?"),
340+
program.acall(question="What is the capital of France?"),
341+
]
342+
with dspy.settings.context(
343+
lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True, adapter=dspy.JSONAdapter()
344+
):
345+
results = await asyncio.gather(*coroutines)
346+
347+
assert results[0].get_lm_usage() is not None
348+
assert results[1].get_lm_usage() is not None
349+
350+
lm_usage0 = results[0].get_lm_usage()["openai/gpt-4o-mini"]
351+
lm_usage1 = results[1].get_lm_usage()["openai/gpt-4o-mini"]
352+
assert lm_usage0["prompt_tokens"] == 1117
353+
assert lm_usage1["prompt_tokens"] == 1117
354+
assert lm_usage0["completion_tokens"] == 46
355+
assert lm_usage1["completion_tokens"] == 46
356+
assert lm_usage0["total_tokens"] == 1163
357+
assert lm_usage1["total_tokens"] == 1163
358+
359+
310360
def test_module_history():
311361
class MyProgram(dspy.Module):
312362
def __init__(self, **kwargs):

0 commit comments

Comments (0)