
Commit 576f33b

Add cache to async LM call (#8135)
* add cache for async lm calls
* Cache async
* fix tests
1 parent d41d8d1 commit 576f33b
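
In effect, asynchronous LM calls now go through the same DSPy request cache (in-memory plus disk) that synchronous calls use. A minimal sketch of the resulting behavior, not part of the commit, assuming an OpenAI API key is configured in the environment and the default cache settings are in place:

import asyncio

import dspy


async def main():
    lm = dspy.LM(model="openai/gpt-4o-mini", cache=True)

    # First call goes out to the provider and populates the cache.
    first = await lm.acall("What is 2 + 2?")
    # An identical request should now be answered from the cache.
    second = await lm.acall("What is 2 + 2?")
    print(first, second)


asyncio.run(main())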

File tree: 4 files changed (+159, −60 lines)

dspy/clients/cache.py

Lines changed: 46 additions & 8 deletions
@@ -1,4 +1,5 @@
 import copy
+import inspect
 import logging
 import threading
 from functools import wraps
@@ -119,14 +120,20 @@ def get(self, request: Dict[str, Any], ignored_args_for_cache_key: Optional[list
             response.usage = {}
         return response
 
-    def put(self, request: Dict[str, Any], value: Any, ignored_args_for_cache_key: Optional[list[str]] = None) -> None:
+    def put(
+        self,
+        request: Dict[str, Any],
+        value: Any,
+        ignored_args_for_cache_key: Optional[list[str]] = None,
+        enable_memory_cache: bool = True,
+    ) -> None:
         try:
             key = self.cache_key(request, ignored_args_for_cache_key)
         except Exception:
             logger.debug(f"Failed to generate cache key for request: {request}")
             return
 
-        if self.enable_memory_cache:
+        if self.enable_memory_cache and enable_memory_cache:
             with self._lock:
                 self.memory_cache[key] = value
 
@@ -164,6 +171,7 @@ def load_memory_cache(self, filepath: str) -> None:
 def request_cache(
     cache_arg_name: Optional[str] = None,
     ignored_args_for_cache_key: Optional[list[str]] = ["api_key", "api_base", "base_url"],
+    enable_memory_cache: bool = True,
     *,  # everything after this is keyword-only
     maxsize: Optional[int] = None,  # legacy / no-op
 ):
@@ -174,6 +182,8 @@ def request_cache(
         cache_arg_name: The name of the argument that contains the request. If not provided, the entire kwargs is used
             as the request.
         ignored_args_for_cache_key: A list of arguments to ignore when computing the cache key from the request.
+        enable_memory_cache: Whether to enable in-memory cache at call time. If False, the memory cache will not be
+            written to on new data.
     """
 
     # Deprecation notice
@@ -186,10 +196,7 @@ def request_cache(
 
     def decorator(fn):
         @wraps(fn)
-        def wrapper(*args, **kwargs):
-            import dspy
-
-            cache = dspy.cache
+        def process_request(args, kwargs):
             # Use fully qualified function name for uniqueness
             fn_identifier = f"{fn.__module__}.{fn.__qualname__}"
 
@@ -206,6 +213,15 @@ def wrapper(*args, **kwargs):
                 modified_request[f"positional_arg_{i}"] = arg
             modified_request["_fn_identifier"] = fn_identifier
 
+            return modified_request
+
+        @wraps(fn)
+        def sync_wrapper(*args, **kwargs):
+            import dspy
+
+            cache = dspy.cache
+            modified_request = process_request(args, kwargs)
+
             # Retrieve from cache if available
             cached_result = cache.get(modified_request, ignored_args_for_cache_key)
 
@@ -214,10 +230,32 @@ def wrapper(*args, **kwargs):
 
             # Otherwise, compute and store the result
             result = fn(*args, **kwargs)
-            cache.put(modified_request, result, ignored_args_for_cache_key)
+            # `enable_memory_cache` can be provided at call time to avoid indefinite growth.
+            cache.put(modified_request, result, ignored_args_for_cache_key, enable_memory_cache)
+
+            return result
+
+        @wraps(fn)
+        async def async_wrapper(*args, **kwargs):
+            import dspy
+
+            cache = dspy.cache
+            modified_request = process_request(args, kwargs)
+
+            # Retrieve from cache if available
+            cached_result = cache.get(modified_request, ignored_args_for_cache_key)
+            if cached_result is not None:
+                return cached_result
+
+            # Otherwise, compute and store the result
+            result = await fn(*args, **kwargs)
+            cache.put(modified_request, result, ignored_args_for_cache_key, enable_memory_cache)
 
             return result
 
-        return wrapper
+        if inspect.iscoroutinefunction(fn):
+            return async_wrapper
+        else:
+            return sync_wrapper
 
     return decorator
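
In short, `request_cache` now checks `inspect.iscoroutinefunction(fn)` and returns an async-aware wrapper when it decorates a coroutine function, and the new `enable_memory_cache` flag suppresses writes to the in-process cache while leaving the disk cache, when enabled, untouched. A small illustrative sketch, not part of the commit; the decorated functions and their arguments are made up:

import asyncio

from dspy.clients.cache import request_cache


# Decorating a coroutine function is now supported directly; the decorator
# returns async_wrapper, so awaited results are cached like sync results.
@request_cache()
async def fetch_answer(prompt: str, model: str) -> str:
    return f"Response for {prompt} with {model}"


# Skip the in-memory layer to keep long-running processes from growing
# dspy.cache.memory_cache without bound; disk caching is unaffected.
@request_cache(enable_memory_cache=False)
async def fetch_answer_disk_only(prompt: str, model: str) -> str:
    return f"Response for {prompt} with {model}"


async def main():
    print(await fetch_answer(prompt="Hello", model="openai/gpt-4o-mini"))
    # Same arguments, so this call should be served from dspy.cache.
    print(await fetch_answer(prompt="Hello", model="openai/gpt-4o-mini"))


asyncio.run(main())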

dspy/clients/lm.py

Lines changed: 45 additions & 52 deletions
@@ -54,7 +54,7 @@ def __init__(
             max_tokens: The maximum number of tokens to generate per response.
             cache: Whether to cache the model responses for reuse to improve performance
                 and reduce costs.
-            cache_in_memory: To enable additional caching with LRU in memory.
+            cache_in_memory (deprecated): To enable additional caching with LRU in memory.
             callbacks: A list of callback functions to run before and after each request.
             num_retries: The number of times to retry a request if it fails transiently due to
                 network error, rate limiting, etc. Requests are retried with exponential
@@ -92,44 +92,69 @@ def __init__(
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
+    def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):
+        ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
+        if cache and enable_memory_cache:
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+            )(completion_fn)
+        elif cache:
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+                enable_memory_cache=False,
+            )(completion_fn)
+        else:
+            completion_fn = completion_fn
+
+        if not cache or litellm.cache is None:
+            litellm_cache_args = {"no-cache": True, "no-store": True}
+        else:
+            litellm_cache_args = {"no-cache": False, "no-store": False}
+
+        return completion_fn, litellm_cache_args
+
     def forward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
-        # disable cache will also disable in memory cache
-        cache_in_memory = cache and kwargs.pop("cache_in_memory", self.cache_in_memory)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
+
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
-        # Make the request and handle LRU & disk caching.
-        if cache_in_memory:
-            completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
-
-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-            )
-        else:
-            completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
 
-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-                # only leverage LiteLLM cache in this case
-                cache={"no-cache": not cache, "no-store": not cache},
-            )
+        results = completion(
+            request=dict(model=self.model, messages=messages, **kwargs),
+            num_retries=self.num_retries,
+            cache=litellm_cache_args,
+        )
 
         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
             settings.usage_tracker.add_usage(self.model, dict(results.usage))
         return results
 
     async def aforward(self, prompt=None, messages=None, **kwargs):
-        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        # Build the request.
+        cache = kwargs.pop("cache", self.cache)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
 
         messages = messages or [{"role": "user", "content": prompt}]
+        kwargs = {**self.kwargs, **kwargs}
+
+        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
+
         results = await completion(
             request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
+            cache=litellm_cache_args,
        )
+
+        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
+            settings.usage_tracker.add_usage(self.model, dict(results.usage))
         return results
 
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
@@ -206,22 +231,6 @@ def dump_state(self):
         return {key: getattr(self, key) for key in state_keys} | self.kwargs
 
 
-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_completion(
-        request,
-        cache=litellm_cache_args,
-        num_retries=num_retries,
-    )
-
-
 def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     retry_kwargs = dict(
         retry_policy=_get_litellm_retry_policy(num_retries),
@@ -267,22 +276,6 @@ async def stream_completion():
     return stream_completion()
 
 
-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_text_completion(
-        request,
-        num_retries=num_retries,
-        cache=litellm_cache_args,
-    )
-
-
 def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
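
With `cached_litellm_completion` and `cached_litellm_text_completion` removed, both `forward` and `aforward` build their completion function through `_get_cached_completion_fn`, and caching can still be overridden per request via the `cache` and `cache_in_memory` kwargs that the methods pop. An illustrative sketch, not part of the commit, assuming a configured API key; the prompts are placeholders:

import dspy

lm = dspy.LM(model="openai/gpt-4o-mini")
messages = [{"role": "user", "content": "Hello"}]

# Default: the request cache wraps the litellm call, and litellm's own cache
# is used only if litellm.cache is configured.
lm.forward(messages=messages)

# Disable caching for this one request; litellm is called with
# {"no-cache": True, "no-store": True} and no request_cache wrapper.
lm.forward(messages=messages, cache=False)

# Keep disk caching but skip in-memory writes for this request. The same
# kwargs apply on the async path, e.g.
# `await lm.aforward(messages=messages, cache_in_memory=False)`.
lm.forward(messages=messages, cache_in_memory=False)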

tests/clients/test_cache.py

Lines changed: 28 additions & 0 deletions
@@ -254,3 +254,31 @@ def test_function2(prompt, model):
 
     # Because model arg is not ignored, the second call should return a different result
     assert result3 != result4
+
+
+@pytest.mark.asyncio
+async def test_request_cache_decorator_async(cache):
+    """Test the request_cache decorator with async functions."""
+    from dspy.clients.cache import request_cache
+
+    # Mock the dspy.cache attribute
+    with patch("dspy.cache", cache):
+        # Define a test function
+        @request_cache()
+        async def test_function(prompt, model):
+            return f"Response for {prompt} with {model}"
+
+        # First call should compute the result
+        result1 = await test_function(prompt="Hello", model="openai/gpt-4o-mini")
+        assert result1 == "Response for Hello with openai/gpt-4o-mini"
+
+        # Second call with same arguments should use cache
+        with patch.object(cache, "get") as mock_get:
+            mock_get.return_value = "Cached response"
+            result2 = await test_function(prompt="Hello", model="openai/gpt-4o-mini")
+            assert result2 == "Cached response"
+            mock_get.assert_called_once()
+
+        # Call with different arguments should compute again
+        result3 = await test_function(prompt="Different", model="openai/gpt-4o-mini")
+        assert result3 == "Response for Different with openai/gpt-4o-mini"

tests/clients/test_lm.py

Lines changed: 40 additions & 0 deletions
@@ -377,3 +377,43 @@ async def test_async_lm_call():
 
     assert result == ["answer"]
     mock_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_async_lm_call_with_cache(tmp_path):
+    """Test the async LM call with caching."""
+    original_cache = dspy.cache
+    dspy.clients.configure_cache(
+        enable_disk_cache=True,
+        enable_memory_cache=True,
+        enable_litellm_cache=False,
+        disk_cache_dir=tmp_path / ".disk_cache",
+    )
+    cache = dspy.cache
+
+    lm = dspy.LM(model="openai/gpt-4o-mini")
+
+    with mock.patch("dspy.clients.lm.alitellm_completion") as mock_alitellm_completion:
+        mock_alitellm_completion.return_value = ModelResponse(
+            choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini"
+        )
+        mock_alitellm_completion.__qualname__ = "alitellm_completion"
+        await lm.acall("Query")
+
+        assert len(cache.memory_cache) == 1
+        cache_key = next(iter(cache.memory_cache.keys()))
+        assert cache_key in cache.disk_cache
+        assert mock_alitellm_completion.call_count == 1
+
+        await lm.acall("Query")
+        # Second call should hit the cache, so no new call to LiteLLM is made.
+        assert mock_alitellm_completion.call_count == 1
+
+        # Test that explicitly disabling memory cache works
+        await lm.acall("New query", cache_in_memory=False)
+
+        # There should be a new call to LiteLLM on new query, but the memory cache shouldn't be written to.
+        assert len(cache.memory_cache) == 1
+        assert mock_alitellm_completion.call_count == 2
+
+    dspy.cache = original_cache
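
The cache configuration in this test is the same `configure_cache` call an application would make; a short sketch outside the test harness, with the cache directory chosen purely for illustration:

import dspy

# Enable DSPy's memory and disk caches and leave litellm's cache off,
# mirroring the setup in the test above. The directory path is illustrative.
dspy.clients.configure_cache(
    enable_disk_cache=True,
    enable_memory_cache=True,
    enable_litellm_cache=False,
    disk_cache_dir="/tmp/dspy_disk_cache",
)

lm = dspy.LM(model="openai/gpt-4o-mini")
# Sync and async calls now share this cache: repeated identical requests
# populate dspy.cache once and are served from it afterwards.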
