@@ -1,51 +1,45 @@
-import os
-import uuid
-import ujson
import functools
-from pathlib import Path
+import os
+import uuid
from datetime import datetime
+from pathlib import Path

-try:
-    import warnings
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
-            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
-        import litellm
-        litellm.telemetry = False
+import litellm
+import ujson
+from litellm.caching import Cache

-    from litellm.caching import Cache
-    disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
-    litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+litellm.telemetry = False

-except ImportError:
-    class LitellmPlaceholder:
-        def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
+if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

-    litellm = LitellmPlaceholder()


class LM:
-    def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
        self.model = model
        self.model_type = model_type
        self.cache = cache
        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
        self.history = []

        if "o1-" in model:
-            assert max_tokens >= 5000 and temperature == 1.0, \
-                "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
-
-
+            assert (
+                max_tokens >= 5000 and temperature == 1.0
+            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+
    def __call__(self, prompt=None, messages=None, **kwargs):
        # Build the request.
        cache = kwargs.pop("cache", self.cache)
        messages = messages or [{"role": "user", "content": prompt}]
        kwargs = {**self.kwargs, **kwargs}

        # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
-        else: completion = cached_litellm_text_completion if cache else litellm_text_completion
+        if self.model_type == "chat":
+            completion = cached_litellm_completion if cache else litellm_completion
+        else:
+            completion = cached_litellm_text_completion if cache else litellm_text_completion

        response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
        outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -63,8 +57,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
            model_type=self.model_type,
        )
        self.history.append(entry)
+
        return outputs
-
+
    def inspect_history(self, n: int = 1):
        _inspect_history(self, n)

@@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
def cached_litellm_completion(request):
    return litellm_completion(request, cache={"no-cache": False, "no-store": False})

+
def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
    kwargs = ujson.loads(request)
    return litellm.completion(cache=cache, **kwargs)

+
@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request):
    return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})

+
def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
    kwargs = ujson.loads(request)

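A note on the caching scheme used by the helpers above: `__call__` serializes the request dict with `ujson.dumps`, which makes it hashable so `functools.lru_cache` (visible on the text-completion variant) can memoize it in-process, while the `no-cache`/`no-store` flags control whether LiteLLM's disk cache (configured at import time) may serve and persist the response. The following is a standalone toy sketch of that pattern, not code from this commit; `fake_completion` is a hypothetical stub standing in for `litellm.completion`.

import functools

import ujson


def fake_completion(cache=None, **kwargs):
    # Stand-in for litellm.completion, used only to make the sketch runnable.
    return {"choices": [{"text": f"echo: {kwargs['messages'][-1]['content']}"}]}


@functools.lru_cache(maxsize=None)
def cached_completion(request: str):
    # The serialized request string is hashable, so identical requests hit the
    # in-process LRU cache; setting "no-cache"/"no-store" to False would also
    # allow a configured disk cache to serve and store the response.
    kwargs = ujson.loads(request)
    return fake_completion(cache={"no-cache": False, "no-store": False}, **kwargs)


request = ujson.dumps(dict(model="demo", messages=[{"role": "user", "content": "hi"}]))
assert cached_completion(request) is cached_completion(request)  # second call is an LRU hit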
@@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
-    prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
+    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
-                                   api_base=api_base, prompt=prompt, **kwargs)
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        **kwargs,
+    )


def _green(text: str, end: str = "\n"):
    return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end

+
def _red(text: str, end: str = "\n"):
    return "\x1b[31m" + str(text) + "\x1b[0m" + end

+
def _inspect_history(lm, n: int = 1):
    """Prints the last n prompts and their completions."""

    for item in lm.history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
        outputs = item["outputs"]
        timestamp = item.get("timestamp", "Unknown time")

        print("\n\n\n")
        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
-
+
        for msg in messages:
            print(_red(f"{msg['role'].capitalize()} message:"))
-            print(msg['content'].strip())
+            print(msg["content"].strip())
            print("\n")

        print(_red("Response:"))
@@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
        if len(outputs) > 1:
            choices_text = f" \t (and {len(outputs)-1} other completions)"
            print(_red(choices_text, end=""))
-
-    print("\n\n\n")
+
+    print("\n\n\n")
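Beyond the diff itself, a minimal usage sketch of the refactored client may help. It assumes the class is exposed as `dspy.LM` and that provider credentials (e.g. `OPENAI_API_KEY`) are set in the environment; the model name is illustrative and neither detail is shown in this commit.

import dspy

# Hypothetical model string; any LiteLLM-supported model works the same way.
lm = dspy.LM("openai/gpt-4o-mini", model_type="chat", temperature=0.0, max_tokens=1000)

# A bare prompt is wrapped into a single user message by __call__.
outputs = lm("Say hello in one short sentence.")
print(outputs[0])

# Explicit chat messages are also accepted; cache=False bypasses both the
# in-process LRU cache and LiteLLM's disk cache for this call.
outputs = lm(messages=[{"role": "user", "content": "Say hello in one short sentence."}], cache=False)

# Pretty-print the most recent request/response pair recorded in lm.history.
lm.inspect_history(n=1)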