Skip to content

Commit ad79829

Browse files
authored
Merge pull request #66 from hotip/master
支持 SSE 式 stream 返回和一次性 JSON 返回
2 parents 439b136 + bfde660 commit ad79829

File tree

7 files changed

+23
-28
lines changed

7 files changed

+23
-28
lines changed

src/main/kotlin/cc/unitmesh/devti/gui/chat/ChatCodingPanel.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp
206206
}.catch {
207207
it.printStackTrace()
208208
}.collect {
209-
println("got message $it")
210209
text += it
211210

212211
// 以下两个 API 设计不合理,如果必须要同时调用,那就只提供一个就好了

src/main/kotlin/cc/unitmesh/devti/gui/chat/MessageView.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ class MessageView(private val message: String, val role: ChatRole, private val d
104104
MessageWorker(content).execute()
105105
}
106106

107-
fun updateSourceContent(source: String?) {
108-
component.text = source
107+
108+
private var answer: String = ""
109+
fun updateSourceContent(source: String) {
110+
answer = source
109111
}
110112

111113
fun scrollToBottom() {
@@ -117,15 +119,14 @@ class MessageView(private val message: String, val role: ChatRole, private val d
117119

118120
fun reRenderAssistantOutput() {
119121
ApplicationManager.getApplication().invokeLater {
120-
val displayText = component.text
121122

122123
centerPanel.remove(component)
123124
centerPanel.updateUI()
124125

125126
centerPanel.add(myNameLabel)
126127
centerPanel.add(createTitlePanel())
127128

128-
val message = SimpleMessage(displayText, displayText, ChatRole.Assistant)
129+
val message = SimpleMessage(answer, answer, ChatRole.Assistant)
129130
renderInPartView(message)
130131

131132
centerPanel.revalidate()

src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
88
import com.intellij.openapi.components.Service
99
import com.intellij.openapi.diagnostic.logger
1010
import com.intellij.openapi.project.Project
11-
import com.nfeld.jsonpathkt.JsonPath
12-
import com.nfeld.jsonpathkt.extension.read
11+
import com.jayway.jsonpath.JsonPath
1312
import com.theokanning.openai.completion.chat.ChatCompletionResult
1413
import com.theokanning.openai.service.SSE
1514
import io.reactivex.BackpressureStrategy
@@ -19,7 +18,6 @@ import kotlinx.coroutines.Dispatchers
1918
import kotlinx.coroutines.ExperimentalCoroutinesApi
2019
import kotlinx.coroutines.channels.awaitClose
2120
import kotlinx.coroutines.flow.Flow
22-
import kotlinx.coroutines.flow.MutableSharedFlow
2321
import kotlinx.coroutines.flow.callbackFlow
2422
import kotlinx.coroutines.withContext
2523
import kotlinx.serialization.Serializable
@@ -49,7 +47,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
4947
private val requestFormat: String
5048
get() = autoDevSettingsState.customEngineRequestFormat
5149
private val responseFormat
52-
get() = autoDevSettingsState.customEngineResponseType
50+
get() = autoDevSettingsState.customEngineResponseFormat
5351

5452
private var client = OkHttpClient()
5553
private val timeout = Duration.ofSeconds(600)
@@ -99,8 +97,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
9997
}
10098

10199

102-
private val _responseFlow = MutableSharedFlow<String>()
103-
104100
@OptIn(ExperimentalCoroutinesApi::class)
105101
private fun streamJson(call: Call): Flow<String> = callbackFlow {
106102
call.enqueue(JSONBodyResponseCallback(responseFormat) {
@@ -114,11 +110,9 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
114110

115111
@OptIn(ExperimentalCoroutinesApi::class)
116112
private fun streamSSE(call: Call): Flow<String> {
117-
val emitDone = false
118-
119113
val sseFlowable = Flowable
120114
.create({ emitter: FlowableEmitter<SSE> ->
121-
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
115+
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, true))
122116
}, BackpressureStrategy.BUFFER)
123117

124118
try {
@@ -132,8 +126,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
132126
.blockingForEach { sse ->
133127
if (responseFormat.isNotEmpty()) {
134128
val chunk: String = JsonPath.parse(sse!!.data)?.read(responseFormat)
135-
?: throw Exception("Failed to parse chunk")
136-
logger.warn("got msg: $chunk")
129+
?: throw Exception("Failed to parse chunk: ${sse.data}")
137130
trySend(chunk)
138131
} else {
139132
val result: ChatCompletionResult =
@@ -215,7 +208,7 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
215208

216209
customRequestJson["customFields"]?.let { customFields ->
217210
customFields.jsonObject.forEach { (key, value) ->
218-
put(key, value.jsonPrimitive.content)
211+
put(key, value.jsonPrimitive)
219212
}
220213
}
221214

src/main/kotlin/cc/unitmesh/devti/llms/custom/JSONBodyResponseCallback.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class JSONBodyResponseCallback(private val responseFormat: String,private val ca
1616
}
1717

1818
override fun onResponse(call: Call, response: Response) {
19+
println("got response ${response.body?.string()}")
1920
val responseContent: String = JsonPath.parse(response.body?.string())?.read(responseFormat) ?: ""
2021

2122
runBlocking() {

src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@ enum class XingHuoApiVersion(val value: Int) {
2424

2525
enum class ResponseType {
2626
SSE, JSON;
27-
28-
companion object {
29-
fun of(str: String): ResponseType = when (str) {
30-
"SSE" -> SSE
31-
"JSON" -> JSON
32-
else -> JSON
33-
}
34-
}
3527
}
3628

3729

src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
4040
private val xingHuoApiKeyParam by LLMParam.creating { Password(settings.xingHuoApiKey) }
4141
private val xingHuoApiSecretParam by LLMParam.creating { Password(settings.xingHuoApiSecrect) }
4242

43-
private val customEngineResponseTypeParam by LLMParam.creating { ComboBox(ResponseType.of(settings.customEngineResponseType).name, ResponseType.values().map { it.name }.toList()) }
43+
private val customEngineResponseTypeParam by LLMParam.creating { ComboBox(settings.customEngineResponseType, ResponseType.values().map { it.name }.toList()) }
4444
private val customEngineResponseFormatParam by LLMParam.creating { Editable(settings.customEngineResponseFormat) }
4545
private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestFormat) }
4646

@@ -224,7 +224,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
224224
customEngineToken = customEngineTokenParam.value
225225
customPrompts = customEnginePrompt.text
226226
openAiModel = openAIModelsParam.value
227-
customEngineResponseType = customEngineResponseFormatParam.value
227+
customEngineResponseFormat = customEngineResponseFormatParam.value
228228
customEngineRequestFormat = customEngineRequestBodyFormatParam.value
229229
delaySeconds = delaySecondsParam.value
230230
}

src/test/kotlin/cc/unitmesh/devti/settings/LLMSettingComponentKtTest.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cc.unitmesh.devti.settings
22

3-
import org.junit.Assert.*
3+
import com.jayway.jsonpath.JsonPath
4+
import org.junit.Assert.assertEquals
45
import org.junit.Test
56

67
class LLMSettingComponentKtTest {
@@ -88,6 +89,14 @@ class LLMSettingComponentKtTest {
8889
s = 2
8990
assertEquals("callback should be called", 1, count)
9091
assertEquals("s should be changed to 2", 2, s)
92+
}
9193

94+
@Test
95+
fun testJsonPath() {
96+
val content = """
97+
{"id":"chatcmpl-8Vf3lDVkktbu4v1SYXGt0LTzCVgdC","object":"chat.completion.chunk","created":1702556693,"model":"gpt-35-turbo-16k","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"abc"},"finish_reason":null}]}
98+
"""
99+
val result = JsonPath.parse(content)?.read<String>("\$.choices[0].delta.content")
100+
println("result is $result")
92101
}
93102
}

0 commit comments

Comments
 (0)