Skip to content

Commit 29682c4

Browse files
committed
Updates for new openai API
1 parent deab0bc commit 29682c4

File tree

3 files changed

+98
-70
lines changed

3 files changed

+98
-70
lines changed

llmstack/play/actors/agent.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import importlib
21
import logging
32
import time
43
import uuid
5-
import orjson as json
64
from typing import Any
7-
from jinja2 import Template
85

6+
import orjson as json
97
from asgiref.sync import async_to_sync
8+
from jinja2 import Template
9+
from openai import OpenAI
1010
from pydantic import BaseModel
11+
1112
from llmstack.play.actor import Actor, BookKeepingData
1213
from llmstack.play.actors.output import OutputResponse
1314
from llmstack.play.output_stream import Message, MessageType
@@ -54,6 +55,10 @@ def __init__(self, output_stream, processor_configs, dependencies=[], all_depend
5455
self._input = kwargs.get('input')
5556
self._config = kwargs.get('config', {})
5657

58+
self._openai_client = OpenAI(
59+
api_key=self._env['openai_api_key']
60+
)
61+
5762
self._agent_messages = [{
5863
'role': 'system',
5964
'content': self._config.get('system_message', 'You are a helpful assistant that uses provided tools to perform actions.')
@@ -94,8 +99,6 @@ def _on_error(self, message) -> None:
9499
)
95100

96101
def on_receive(self, message: Message) -> Any:
97-
import openai
98-
importlib.reload(openai)
99102
max_steps = self._config.get('max_steps', 10) + 2
100103

101104
if len(self._agent_messages) > max_steps:
@@ -123,14 +126,12 @@ def on_receive(self, message: Message) -> Any:
123126

124127
model = self._config.get('model', 'gpt-3.5-turbo')
125128

126-
openai.api_key = self._env['openai_api_key']
127-
128129
# Make one call to the model
129130
full_content = ''
130131
function_name = ''
131132
function_args = ''
132133
finish_reason = None
133-
result = openai.ChatCompletion.create(
134+
result = self._openai_client.chat.completions.create(
134135
model=model,
135136
messages=self._agent_messages,
136137
stream=True,
@@ -139,30 +140,31 @@ def on_receive(self, message: Message) -> Any:
139140
agent_message_id = str(uuid.uuid4())
140141

141142
for data in result:
142-
if data.get('object') and data.get('object') == 'chat.completion.chunk' and data.get('choices') and len(data.get('choices')) > 0:
143-
finish_reason = data['choices'][0]['finish_reason']
144-
delta = data['choices'][0]['delta']
145-
function_call = delta.get('function_call')
146-
content = delta.get('content')
147-
148-
if function_call and function_call.get('name'):
149-
function_name += function_call['name']
143+
logger.info(data)
144+
if data.object == 'chat.completion.chunk' and len(data.choices) > 0 and data.choices[0].delta:
145+
finish_reason = data.choices[0].finish_reason
146+
delta = data.choices[0].delta
147+
function_call = delta.function_call
148+
content = delta.content
149+
150+
if function_call and function_call.name:
151+
function_name += function_call.name
150152
async_to_sync(self._output_stream.write)(
151153
AgentOutput(
152154
content=FunctionCall(
153-
name=function_call['name'],
155+
name=function_call.name,
154156
),
155157
id=agent_message_id,
156158
from_id='agent',
157159
type='step',
158160
)
159161
)
160-
elif function_call and function_call.get('arguments'):
161-
function_args += function_call['arguments']
162+
elif function_call and function_call.arguments:
163+
function_args += function_call.arguments
162164
async_to_sync(self._output_stream.write)(
163165
AgentOutput(
164166
content=FunctionCall(
165-
arguments=function_call['arguments'],
167+
arguments=function_call.arguments,
166168
),
167169
id=agent_message_id,
168170
from_id='agent',

llmstack/processors/providers/openai/chat_completions_vision.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
import importlib
2-
import openai
3-
from enum import Enum
42
import logging
3+
from enum import Enum
54
from typing import Annotated, List, Literal, Optional, Union
65

6+
import openai
77
from asgiref.sync import async_to_sync
8+
from openai import OpenAI
89
from pydantic import BaseModel, Field, confloat, conint
910

10-
from llmstack.common.blocks.llm.openai import OpenAIChatCompletionsAPIProcessorConfiguration
11-
from llmstack.processors.providers.api_processor_interface import ApiProcessorInterface, ApiProcessorSchema
12-
from llmstack.processors.providers.api_processor_interface import ApiProcessorInterface, ApiProcessorSchema
11+
from llmstack.common.blocks.llm.openai import (
12+
OpenAIChatCompletionsAPIProcessorConfiguration,
13+
)
14+
from llmstack.processors.providers.api_processor_interface import (
15+
ApiProcessorInterface,
16+
ApiProcessorSchema,
17+
)
1318

1419
logger = logging.getLogger(__name__)
1520

21+
1622
class Role(str, Enum):
1723
SYSTEM = 'system'
1824
USER = 'user'
@@ -21,42 +27,53 @@ class Role(str, Enum):
2127
def __str__(self):
2228
return self.value
2329

30+
2431
class ChatCompletionsVisionModel(str, Enum):
2532
GPT_4_Vision = 'gpt-4-vision-preview'
2633

2734
def __str__(self):
2835
return self.value
29-
36+
37+
3038
class TextMessage(BaseModel):
3139
type: Literal["text"]
32-
40+
3341
text: str = Field(
3442
default='', description='The message text.')
35-
43+
44+
3645
class UrlImageMessage(BaseModel):
3746
type: Literal["image_url"]
38-
47+
3948
image_url: str = Field(
4049
default='', description='The image data URI.')
4150

42-
Message = Annotated[Union[TextMessage, UrlImageMessage], Field(discriminator='type')]
51+
52+
Message = Annotated[Union[TextMessage, UrlImageMessage],
53+
Field(discriminator='type')]
54+
55+
4356
class ChatMessage(ApiProcessorSchema):
4457
role: Optional[Role] = Field(
4558
default=Role.USER, description="The role of the message sender. Can be 'user' or 'assistant' or 'system'.",
4659
)
47-
content: List[Union[TextMessage, UrlImageMessage]] = Field(default=[], description='The message text.')
48-
60+
content: List[Union[TextMessage, UrlImageMessage]] = Field(
61+
default=[], description='The message text.')
62+
63+
4964
class ChatCompletionsVisionInput(ApiProcessorSchema):
5065
system_message: Optional[str] = Field(
5166
default='', description='A message from the system, which will be prepended to the chat history.', widget='textarea',
5267
)
5368
messages: List[Message] = Field(
5469
default=[], description='A list of messages, each with a role and message text.'
5570
)
56-
71+
72+
5773
class ChatCompletionsVisionOutput(ApiProcessorSchema):
5874
result: str = Field(default='', description='The model-generated message.')
59-
75+
76+
6077
class ChatCompletionsVisionConfiguration(OpenAIChatCompletionsAPIProcessorConfiguration, ApiProcessorSchema):
6178
model: ChatCompletionsVisionModel = Field(
6279
default=ChatCompletionsVisionModel.GPT_4_Vision,
@@ -83,6 +100,7 @@ class ChatCompletionsVisionConfiguration(OpenAIChatCompletionsAPIProcessorConfig
83100
default=False, description="Automatically prune chat history. This is only applicable if 'retain_history' is set to 'true'.",
84101
)
85102

103+
86104
class ChatCompletionsVision(ApiProcessorInterface[ChatCompletionsVisionInput, ChatCompletionsVisionOutput, ChatCompletionsVisionConfiguration]):
87105
"""
88106
OpenAI Chat Completions with vision API
@@ -114,32 +132,32 @@ def session_data_to_persist(self) -> dict:
114132
def process(self) -> dict:
115133
importlib.reload(openai)
116134
output_stream = self._output_stream
117-
135+
118136
chat_history = self._chat_history if self._config.retain_history else []
119137
messages = []
120-
messages.append({'role': 'system', 'content': self._input.system_message})
121-
138+
messages.append(
139+
{'role': 'system', 'content': self._input.system_message})
140+
122141
for msg in chat_history:
123142
messages.append(msg)
124-
125-
messages.append({'role': 'user', 'content': [msg.dict() for msg in self._input.messages]})
126-
127-
openai.api_key = self._env['openai_api_key']
128-
result = openai.chat.completions.create(
129-
model=self._config.model,
130-
messages=messages,
131-
temperature=self._config.temperature,
132-
stream=True,
133-
)
134-
135-
143+
144+
messages.append({'role': 'user', 'content': [
145+
msg.dict() for msg in self._input.messages]})
146+
147+
openai_client = OpenAI(api_key=self._env['openai_api_key'])
148+
result = openai_client.chat.completions.create(
149+
model=self._config.model,
150+
messages=messages,
151+
temperature=self._config.temperature,
152+
stream=True,
153+
)
154+
136155
for data in result:
137-
if data.get('object') and data.get('object') == 'chat.completion.chunk' and data.get('choices') and len(data.get('choices')) > 0 and data['choices'][0].get('delta') and data['choices'][0]['delta'].get('content'):
156+
if data.object == 'chat.completion.chunk' and len(data.choices) > 0 and data.choices[0].delta and data.choices[0].delta.content:
138157
async_to_sync(output_stream.write)(
139158
ChatCompletionsVisionOutput(
140-
result=data['choices'][0]['delta']['content']
159+
result=data.choices[0].delta.content
141160
))
142-
143161

144162
output = self._output_stream.finalize()
145163

llmstack/processors/providers/promptly/text_chat.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import logging
44
import uuid
55
from enum import Enum
6-
from typing import Optional
7-
from typing import List
6+
from typing import List, Optional
87

9-
import openai
108
from asgiref.sync import async_to_sync
119
from django import db
10+
from openai import AzureOpenAI, OpenAI
1211
from pydantic import Field
1312

1413
from llmstack.datasources.models import DataSource
1514
from llmstack.datasources.types import DataSourceTypeFactory
16-
from llmstack.processors.providers.api_processor_interface import ApiProcessorInterface, ApiProcessorSchema
15+
from llmstack.processors.providers.api_processor_interface import (
16+
ApiProcessorInterface,
17+
ApiProcessorSchema,
18+
)
1719

1820
logger = logging.getLogger(__name__)
1921

@@ -188,7 +190,6 @@ def fetch_datasource_docs(datasource_uuid):
188190

189191
def process(self) -> dict:
190192
input = self._input.dict()
191-
importlib.reload(openai)
192193
output_stream = self._output_stream
193194
docs = self._search_datasources(input)
194195

@@ -227,10 +228,12 @@ def process(self) -> dict:
227228
)
228229

229230
if self._env['azure_openai_api_key'] and self._config.use_azure_if_available:
230-
openai.api_type = 'azure'
231-
openai.api_key = self._env['azure_openai_api_key']
232-
openai.api_base = f"https://{self._env['azure_openai_endpoint']}.openai.azure.com"
233-
openai.api_version = '2023-03-15-preview'
231+
openai_client = AzureOpenAI(
232+
api_key=self._env['azure_openai_api_key'],
233+
api_version='2023-03-15-preview',
234+
azure_endpoint=self._env['azure_openai_endpoint'],
235+
)
236+
234237
model = self._config.dict().get('model', 'gpt-3.5-turbo')
235238
engine = 'gpt-4'
236239
if model == 'gpt-3.5-turbo':
@@ -240,30 +243,35 @@ def process(self) -> dict:
240243
elif model == 'gpt-4-32k':
241244
engine = 'gpt-4-32k'
242245

243-
result = openai.ChatCompletion.create(
246+
result = openai_client.chat.completions.create(
244247
engine=engine,
245248
messages=[system_message] +
246249
[context_message] + self._chat_history,
247250
temperature=self._config.temperature,
248251
stream=True,
249252
)
250253
elif self._env['localai_base_url'] and self._config.use_localai_if_available:
251-
if self._env['localai_api_key']:
252-
openai.api_key = self._env['localai_api_key']
253-
openai.api_base = self._env['localai_base_url']
254+
openai_client = OpenAI(
255+
api_key=self._env['localai_api_key'],
256+
base_url=self._env['localai_base_url'],
257+
) if self._env['localai_api_key'] else OpenAI(
258+
base_url=self._env['localai_base_url'],
259+
)
254260
model = self._config.dict().get('model', 'gpt-3.5-turbo')
255261

256-
result = openai.ChatCompletion.create(
262+
result = openai_client.chat.completions.create(
257263
model=model,
258264
messages=[system_message] +
259265
[context_message] + self._chat_history,
260266
temperature=self._config.temperature,
261267
stream=True,
262268
)
263269
elif self._env['openai_api_key'] is not None:
264-
openai.api_key = self._env['openai_api_key']
270+
openai_client = OpenAI(
271+
api_key=self._env['openai_api_key'],
272+
)
265273
model = self._config.dict().get('model', 'gpt-3.5-turbo')
266-
result = openai.ChatCompletion.create(
274+
result = openai_client.chat.completions.create(
267275
model=model,
268276
messages=[system_message] +
269277
[context_message] + self._chat_history,
@@ -274,10 +282,10 @@ def process(self) -> dict:
274282
raise Exception('No OpenAI API key provided')
275283

276284
for data in result:
277-
if data.get('object') and data.get('object') == 'chat.completion.chunk' and data.get('choices') and len(data.get('choices')) > 0 and data['choices'][0].get('delta') and data['choices'][0]['delta'].get('content'):
285+
if data.object == 'chat.completion.chunk' and len(data.choices) > 0 and data.choices[0].delta and data.choices[0].delta.content:
278286
async_to_sync(output_stream.write)(
279287
TextChatOutput(
280-
answer=data['choices'][0]['delta']['content']
288+
answer=data.choices[0].delta.content
281289
))
282290

283291
if len(docs) > 0:

0 commit comments

Comments
 (0)