1
1
package cc.unitmesh.devti.llms.custom
2
2
3
- import cc.unitmesh.devti.custom.action.CustomPromptConfig
4
3
import cc.unitmesh.devti.gui.chat.ChatRole
5
4
import cc.unitmesh.devti.llms.LLMProvider
6
5
import cc.unitmesh.devti.settings.AutoDevSettingsState
6
+ import cc.unitmesh.devti.settings.ResponseType
7
7
import com.fasterxml.jackson.databind.ObjectMapper
8
8
import com.intellij.openapi.components.Service
9
9
import com.intellij.openapi.diagnostic.logger
@@ -17,12 +17,15 @@ import io.reactivex.Flowable
17
17
import io.reactivex.FlowableEmitter
18
18
import kotlinx.coroutines.Dispatchers
19
19
import kotlinx.coroutines.ExperimentalCoroutinesApi
20
+ import kotlinx.coroutines.channels.awaitClose
20
21
import kotlinx.coroutines.flow.Flow
22
+ import kotlinx.coroutines.flow.MutableSharedFlow
21
23
import kotlinx.coroutines.flow.callbackFlow
22
24
import kotlinx.coroutines.withContext
23
25
import kotlinx.serialization.Serializable
24
26
import kotlinx.serialization.encodeToString
25
27
import kotlinx.serialization.json.*
28
+ import okhttp3.Call
26
29
import okhttp3.MediaType.Companion.toMediaTypeOrNull
27
30
import okhttp3.OkHttpClient
28
31
import okhttp3.Request
@@ -39,15 +42,15 @@ data class CustomRequest(val messages: List<Message>)
39
42
@Service(Service .Level .PROJECT )
40
43
class CustomLLMProvider (val project : Project ) : LLMProvider {
41
44
private val autoDevSettingsState = AutoDevSettingsState .getInstance()
42
- private val url get() = autoDevSettingsState.customEngineServer
43
- private val key get() = autoDevSettingsState.customEngineToken
44
- private val requestFormat : String get() = autoDevSettingsState.customEngineRequestFormat
45
- private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat
46
- private val customPromptConfig : CustomPromptConfig
47
- get() {
48
- val prompts = autoDevSettingsState.customPrompts
49
- return CustomPromptConfig .tryParse(prompts)
50
- }
45
+ private val url
46
+ get() = autoDevSettingsState.customEngineServer
47
+ private val key
48
+ get() = autoDevSettingsState.customEngineToken
49
+ private val requestFormat : String
50
+ get() = autoDevSettingsState.customEngineRequestFormat
51
+ private val responseFormat
52
+ get() = autoDevSettingsState.customEngineResponseType
53
+
51
54
private var client = OkHttpClient ()
52
55
private val timeout = Duration .ofSeconds(600 )
53
56
private val messages: MutableList <Message > = ArrayList ()
@@ -66,7 +69,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
66
69
return this .prompt(promptText, " " )
67
70
}
68
71
69
- @OptIn(ExperimentalCoroutinesApi ::class )
70
72
override fun stream (promptText : String , systemPrompt : String , keepHistory : Boolean ): Flow <String > {
71
73
if (! keepHistory) {
72
74
clearMessage()
@@ -86,15 +88,32 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
86
88
}
87
89
builder.appendCustomHeaders(requestFormat)
88
90
89
- client = client.newBuilder()
90
- .readTimeout(timeout)
91
- .build()
92
- val request = builder
93
- .url(url)
94
- .post(body)
95
- .build()
91
+ client = client.newBuilder().readTimeout(timeout).build()
92
+ val call = client.newCall(builder.url(url).post(body).build())
93
+
94
+ if (autoDevSettingsState.customEngineResponseType == ResponseType .SSE .name) {
95
+ return streamSSE(call)
96
+ } else {
97
+ return streamJson(call)
98
+ }
99
+ }
100
+
101
+
102
+ private val _responseFlow = MutableSharedFlow <String >()
96
103
97
- val call = client.newCall(request)
104
+ @OptIn(ExperimentalCoroutinesApi ::class )
105
+ private fun streamJson (call : Call ): Flow <String > = callbackFlow {
106
+ call.enqueue(JSONBodyResponseCallback (responseFormat) {
107
+ withContext(Dispatchers .IO ) {
108
+ send(it)
109
+ }
110
+ close()
111
+ })
112
+ awaitClose()
113
+ }
114
+
115
+ @OptIn(ExperimentalCoroutinesApi ::class )
116
+ private fun streamSSE (call : Call ): Flow <String > {
98
117
val emitDone = false
99
118
100
119
val sseFlowable = Flowable
@@ -106,29 +125,29 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
106
125
return callbackFlow {
107
126
withContext(Dispatchers .IO ) {
108
127
sseFlowable
109
- .doOnError{
110
- it.printStackTrace()
111
- close()
112
- }
113
- .blockingForEach { sse ->
114
- if (responseFormat.isNotEmpty()) {
115
- val chunk: String = JsonPath .parse(sse!! .data)?.read(responseFormat)
116
- ? : throw Exception (" Failed to parse chunk" )
117
- logger.warn(" got msg: $chunk " )
118
- trySend(chunk)
119
- } else {
120
- val result: ChatCompletionResult =
121
- ObjectMapper ().readValue(sse!! .data, ChatCompletionResult ::class .java)
128
+ .doOnError {
129
+ it.printStackTrace()
130
+ close()
131
+ }
132
+ .blockingForEach { sse ->
133
+ if (responseFormat.isNotEmpty()) {
134
+ val chunk: String = JsonPath .parse(sse!! .data)?.read(responseFormat)
135
+ ? : throw Exception (" Failed to parse chunk" )
136
+ logger.warn(" got msg: $chunk " )
137
+ trySend(chunk)
138
+ } else {
139
+ val result: ChatCompletionResult =
140
+ ObjectMapper ().readValue(sse!! .data, ChatCompletionResult ::class .java)
122
141
123
142
val completion = result.choices[0 ].message
124
143
if (completion != null && completion.content != null ) {
125
144
trySend(completion.content)
126
145
}
127
146
}
128
147
}
129
-
130
148
close()
131
149
}
150
+ awaitClose()
132
151
}
133
152
} catch (e: Exception ) {
134
153
logger.error(" Failed to stream" , e)
@@ -174,7 +193,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
174
193
fun Request.Builder.appendCustomHeaders (customRequestHeader : String ): Request .Builder = apply {
175
194
runCatching {
176
195
Json .parseToJsonElement(customRequestHeader)
177
- .jsonObject[" customHeaders" ].let { customFields ->
196
+ .jsonObject[" customHeaders" ].let { customFields ->
178
197
customFields?.jsonObject?.forEach { (key, value) ->
179
198
header(key, value.jsonPrimitive.content)
180
199
}
@@ -232,5 +251,5 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
232
251
fun CustomRequest.updateCustomFormat (format : String ): String {
233
252
val requestContentOri = Json .encodeToString<CustomRequest >(this )
234
253
return Json .parseToJsonElement(requestContentOri)
235
- .jsonObject.updateCustomBody(format).toString()
254
+ .jsonObject.updateCustomBody(format).toString()
236
255
}
0 commit comments