
Commit 92d25a7

Merge pull request #1604 from chenmoneygithub/fix-litellm-client
Remove legacy code from client/lm.py and reformat
2 parents cc368f8 + 069808f commit 92d25a7

File tree

1 file changed: +41 -35 lines


dspy/clients/lm.py

Lines changed: 41 additions & 35 deletions
@@ -1,51 +1,45 @@
-import os
-import uuid
-import ujson
 import functools
-from pathlib import Path
+import os
+import uuid
 from datetime import datetime
+from pathlib import Path
 
-try:
-    import warnings
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
-            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
-        import litellm
-        litellm.telemetry = False
+import litellm
+import ujson
+from litellm.caching import Cache
 
-    from litellm.caching import Cache
-    disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
-    litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+litellm.telemetry = False
 
-except ImportError:
-    class LitellmPlaceholder:
-        def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
+if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
 
-    litellm = LitellmPlaceholder()
 
 class LM:
-    def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
         self.model = model
         self.model_type = model_type
         self.cache = cache
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
 
         if "o1-" in model:
-            assert max_tokens >= 5000 and temperature == 1.0, \
-                "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
-
-
+            assert (
+                max_tokens >= 5000 and temperature == 1.0
+            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+
     def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
-        else: completion = cached_litellm_text_completion if cache else litellm_text_completion
+        if self.model_type == "chat":
+            completion = cached_litellm_completion if cache else litellm_completion
+        else:
+            completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -63,8 +57,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
+
         return outputs
-
+
     def inspect_history(self, n: int = 1):
         _inspect_history(self, n)
 
@@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
 def cached_litellm_completion(request):
     return litellm_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
     return litellm.completion(cache=cache, **kwargs)
 
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request):
     return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
 
@@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
     api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
+    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
-                                   api_base=api_base, prompt=prompt, **kwargs)
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        **kwargs,
+    )
 
 
 def _green(text: str, end: str = "\n"):
     return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
 
+
 def _red(text: str, end: str = "\n"):
     return "\x1b[31m" + str(text) + "\x1b[0m" + end
 
+
 def _inspect_history(lm, n: int = 1):
     """Prints the last n prompts and their completions."""
 
     for item in lm.history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
         outputs = item["outputs"]
         timestamp = item.get("timestamp", "Unknown time")
 
         print("\n\n\n")
         print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
-
+
         for msg in messages:
             print(_red(f"{msg['role'].capitalize()} message:"))
-            print(msg['content'].strip())
+            print(msg["content"].strip())
             print("\n")
 
         print(_red("Response:"))
@@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
         if len(outputs) > 1:
             choices_text = f" \t (and {len(outputs)-1} other completions)"
             print(_red(choices_text, end=""))
-
-    print("\n\n\n")
+
+    print("\n\n\n")
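For reference, a minimal usage sketch of the refactored client follows, based only on the signatures visible in this diff. The model name, API key, and cache directory are placeholders, not part of the commit.

import os

# DSPY_CACHEDIR is read when dspy/clients/lm.py is imported, so set it (optionally) before importing dspy.
os.environ["DSPY_CACHEDIR"] = "/tmp/dspy_cache"  # hypothetical cache location
os.environ["OPENAI_API_KEY"] = "sk-..."          # placeholder credential

import dspy

# Defaults mirror the __init__ signature above: model_type="chat",
# temperature=0.0, max_tokens=1000, cache=True.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.0, max_tokens=1000, cache=True)

# __call__ accepts either a plain prompt or a list of chat messages and
# returns a list of completion strings.
outputs = lm("Say hello in one word.")
print(outputs[0])

outputs = lm(messages=[{"role": "user", "content": "Say hello in one word."}])

# Passing cache=False on a call routes it through the non-cached completion path.
fresh = lm("Say hello in one word.", cache=False)

# Every call is appended to lm.history; print the most recent exchange.
lm.inspect_history(n=1)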

0 commit comments
