@@ -54,7 +54,7 @@ def __init__(
            max_tokens: The maximum number of tokens to generate per response.
            cache: Whether to cache the model responses for reuse to improve performance
                and reduce costs.
-           cache_in_memory: To enable additional caching with LRU in memory.
+           cache_in_memory (deprecated): To enable additional caching with LRU in memory.
            callbacks: A list of callback functions to run before and after each request.
            num_retries: The number of times to retry a request if it fails transiently due to
                network error, rate limiting, etc. Requests are retried with exponential
@@ -92,44 +92,69 @@ def __init__(
        else:
            self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

+    def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):
+        ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
+        if cache and enable_memory_cache:
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+            )(completion_fn)
+        elif cache:
+            completion_fn = request_cache(
+                cache_arg_name="request",
+                ignored_args_for_cache_key=ignored_args_for_cache_key,
+                enable_memory_cache=False,
+            )(completion_fn)
+        else:
+            completion_fn = completion_fn
+
+        if not cache or litellm.cache is None:
+            litellm_cache_args = {"no-cache": True, "no-store": True}
+        else:
+            litellm_cache_args = {"no-cache": False, "no-store": False}
+
+        return completion_fn, litellm_cache_args
+
    def forward(self, prompt=None, messages=None, **kwargs):
        # Build the request.
        cache = kwargs.pop("cache", self.cache)
-        # disable cache will also disable in memory cache
-        cache_in_memory = cache and kwargs.pop("cache_in_memory", self.cache_in_memory)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
+
        messages = messages or [{"role": "user", "content": prompt}]
        kwargs = {**self.kwargs, **kwargs}

-        # Make the request and handle LRU & disk caching.
-        if cache_in_memory:
-            completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
-
-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-            )
-        else:
-            completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

-            results = completion(
-                request=dict(model=self.model, messages=messages, **kwargs),
-                num_retries=self.num_retries,
-                # only leverage LiteLLM cache in this case
-                cache={"no-cache": not cache, "no-store": not cache},
-            )
+        results = completion(
+            request=dict(model=self.model, messages=messages, **kwargs),
+            num_retries=self.num_retries,
+            cache=litellm_cache_args,
+        )

        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
            settings.usage_tracker.add_usage(self.model, dict(results.usage))
        return results

    async def aforward(self, prompt=None, messages=None, **kwargs):
-        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        # Build the request.
+        cache = kwargs.pop("cache", self.cache)
+        enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)

        messages = messages or [{"role": "user", "content": prompt}]
+        kwargs = {**self.kwargs, **kwargs}
+
+        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
+
        results = await completion(
            request=dict(model=self.model, messages=messages, **kwargs),
            num_retries=self.num_retries,
+            cache=litellm_cache_args,
        )
+
+        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
+            settings.usage_tracker.add_usage(self.model, dict(results.usage))
        return results

    def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
@@ -206,22 +231,6 @@ def dump_state(self):
        return {key: getattr(self, key) for key in state_keys} | self.kwargs


-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_completion(
-        request,
-        cache=litellm_cache_args,
-        num_retries=num_retries,
-    )
-
-
def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
@@ -267,22 +276,6 @@ async def stream_completion():
        return stream_completion()


-@request_cache(cache_arg_name="request", ignored_args_for_cache_key=["api_key", "api_base", "base_url"])
-def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
-    import litellm
-
-    if litellm.cache:
-        litellm_cache_args = {"no-cache": False, "no-store": False}
-    else:
-        litellm_cache_args = {"no-cache": True, "no-store": True}
-
-    return litellm_text_completion(
-        request,
-        num_retries=num_retries,
-        cache=litellm_cache_args,
-    )
-
-
def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    # Extract the provider and model from the model string.
    # TODO: Not all the models are in the format of "provider/model"
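Below is a minimal, self-contained sketch of the cache wiring that the new _get_cached_completion_fn helper centralizes. toy_request_cache, get_cached_completion_fn, and fake_completion are illustrative stand-ins (not the real dspy request_cache decorator, helper method, or LiteLLM completion call), and the real helper additionally checks whether litellm.cache is configured before enabling LiteLLM-side caching.

from functools import wraps


def toy_request_cache(enable_memory_cache=True):
    """Stand-in for dspy's request_cache: memoizes results keyed on the request dict."""
    store = {}

    def decorator(fn):
        @wraps(fn)
        def wrapper(request, **kwargs):
            key = repr(sorted(request.items()))
            if enable_memory_cache and key in store:
                return store[key]
            result = fn(request, **kwargs)
            if enable_memory_cache:
                store[key] = result
            return result

        return wrapper

    return decorator


def get_cached_completion_fn(completion_fn, cache, enable_memory_cache):
    # Mirrors the decision table of _get_cached_completion_fn in the diff:
    #   cache on, memory cache on  -> wrap with the caching decorator
    #   cache on, memory cache off -> wrap, but skip the in-memory layer
    #                                 (the real request_cache still serves its disk cache here)
    #   cache off                  -> leave completion_fn untouched
    if cache and enable_memory_cache:
        completion_fn = toy_request_cache(enable_memory_cache=True)(completion_fn)
    elif cache:
        completion_fn = toy_request_cache(enable_memory_cache=False)(completion_fn)

    # When caching is disabled (or, in the real helper, when litellm.cache is None),
    # the request is sent with "no-cache"/"no-store" so LiteLLM neither reuses nor
    # stores a cached response.
    if not cache:
        litellm_cache_args = {"no-cache": True, "no-store": True}
    else:
        litellm_cache_args = {"no-cache": False, "no-store": False}
    return completion_fn, litellm_cache_args


def fake_completion(request, num_retries=0):
    """Stand-in for litellm_completion; just records that a model call would happen."""
    print("calling the model for:", request["messages"][0]["content"])
    return {"choices": ["ok"]}


fn, cache_args = get_cached_completion_fn(fake_completion, cache=True, enable_memory_cache=True)
request = {"model": "demo", "messages": [{"role": "user", "content": "hi"}]}
fn(request=request)   # triggers the (fake) model call
fn(request=request)   # served from the toy in-memory cache, no second call
print(cache_args)     # {'no-cache': False, 'no-store': False}

Centralizing the wiring this way lets forward and aforward share a single code path for cache selection and LiteLLM cache flags, which is what allows the duplicated cached_litellm_completion and cached_litellm_text_completion wrappers to be deleted further down in the diff.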