Skip to content

Commit a72f085

Browse files
committed
feat: 增加自定义请求。可修改请求 Header 及 reqeust body
1 parent e0cca51 commit a72f085

File tree

8 files changed

+111
-38
lines changed

8 files changed

+111
-38
lines changed

src/main/kotlin/cc/unitmesh/devti/gui/chat/ChatCodingPanel.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import kotlinx.coroutines.delay
2727
import kotlinx.coroutines.flow.Flow
2828
import kotlinx.coroutines.flow.catch
2929
import kotlinx.coroutines.flow.collect
30+
import kotlinx.coroutines.flow.onCompletion
3031
import kotlinx.coroutines.withContext
3132
import java.awt.event.ActionListener
3233
import java.awt.event.MouseAdapter
@@ -152,7 +153,7 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp
152153

153154
suspend fun updateMessage(content: Flow<String>): String {
154155
if (myList.componentCount > 0) {
155-
myList.remove(myList.componentCount - 1)
156+
myList.remove(myList.componentCount - 1)
156157
}
157158

158159
progressBar.isVisible = true
@@ -200,9 +201,12 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp
200201
val startTime = System.currentTimeMillis() // 记录代码开始执行的时间
201202

202203
var text = ""
203-
content.catch {
204+
content.onCompletion {
205+
println("onCompletion ${it?.message}")
206+
}.catch {
204207
it.printStackTrace()
205208
}.collect {
209+
println("got message $it")
206210
text += it
207211

208212
// 以下两个 API 设计不合理,如果必须要同时调用,那就只提供一个就好了

src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package cc.unitmesh.devti.llms.custom
22

3-
import cc.unitmesh.devti.custom.action.CustomPromptConfig
43
import cc.unitmesh.devti.gui.chat.ChatRole
54
import cc.unitmesh.devti.llms.LLMProvider
65
import cc.unitmesh.devti.settings.AutoDevSettingsState
6+
import cc.unitmesh.devti.settings.ResponseType
77
import com.fasterxml.jackson.databind.ObjectMapper
88
import com.intellij.openapi.components.Service
99
import com.intellij.openapi.diagnostic.logger
@@ -17,12 +17,15 @@ import io.reactivex.Flowable
1717
import io.reactivex.FlowableEmitter
1818
import kotlinx.coroutines.Dispatchers
1919
import kotlinx.coroutines.ExperimentalCoroutinesApi
20+
import kotlinx.coroutines.channels.awaitClose
2021
import kotlinx.coroutines.flow.Flow
22+
import kotlinx.coroutines.flow.MutableSharedFlow
2123
import kotlinx.coroutines.flow.callbackFlow
2224
import kotlinx.coroutines.withContext
2325
import kotlinx.serialization.Serializable
2426
import kotlinx.serialization.encodeToString
2527
import kotlinx.serialization.json.*
28+
import okhttp3.Call
2629
import okhttp3.MediaType.Companion.toMediaTypeOrNull
2730
import okhttp3.OkHttpClient
2831
import okhttp3.Request
@@ -39,15 +42,15 @@ data class CustomRequest(val messages: List<Message>)
3942
@Service(Service.Level.PROJECT)
4043
class CustomLLMProvider(val project: Project) : LLMProvider {
4144
private val autoDevSettingsState = AutoDevSettingsState.getInstance()
42-
private val url get() = autoDevSettingsState.customEngineServer
43-
private val key get() = autoDevSettingsState.customEngineToken
44-
private val requestFormat: String get() = autoDevSettingsState.customEngineRequestFormat
45-
private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat
46-
private val customPromptConfig: CustomPromptConfig
47-
get() {
48-
val prompts = autoDevSettingsState.customPrompts
49-
return CustomPromptConfig.tryParse(prompts)
50-
}
45+
private val url
46+
get() = autoDevSettingsState.customEngineServer
47+
private val key
48+
get() = autoDevSettingsState.customEngineToken
49+
private val requestFormat: String
50+
get() = autoDevSettingsState.customEngineRequestFormat
51+
private val responseFormat
52+
get() = autoDevSettingsState.customEngineResponseType
53+
5154
private var client = OkHttpClient()
5255
private val timeout = Duration.ofSeconds(600)
5356
private val messages: MutableList<Message> = ArrayList()
@@ -66,7 +69,6 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
6669
return this.prompt(promptText, "")
6770
}
6871

69-
@OptIn(ExperimentalCoroutinesApi::class)
7072
override fun stream(promptText: String, systemPrompt: String, keepHistory: Boolean): Flow<String> {
7173
if (!keepHistory) {
7274
clearMessage()
@@ -86,15 +88,32 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
8688
}
8789
builder.appendCustomHeaders(requestFormat)
8890

89-
client = client.newBuilder()
90-
.readTimeout(timeout)
91-
.build()
92-
val request = builder
93-
.url(url)
94-
.post(body)
95-
.build()
91+
client = client.newBuilder().readTimeout(timeout).build()
92+
val call = client.newCall(builder.url(url).post(body).build())
93+
94+
if (autoDevSettingsState.customEngineResponseType == ResponseType.SSE.name) {
95+
return streamSSE(call)
96+
} else {
97+
return streamJson(call)
98+
}
99+
}
100+
101+
102+
private val _responseFlow = MutableSharedFlow<String>()
96103

97-
val call = client.newCall(request)
104+
@OptIn(ExperimentalCoroutinesApi::class)
105+
private fun streamJson(call: Call): Flow<String> = callbackFlow {
106+
call.enqueue(JSONBodyResponseCallback(responseFormat) {
107+
withContext(Dispatchers.IO) {
108+
send(it)
109+
}
110+
close()
111+
})
112+
awaitClose()
113+
}
114+
115+
@OptIn(ExperimentalCoroutinesApi::class)
116+
private fun streamSSE(call: Call): Flow<String> {
98117
val emitDone = false
99118

100119
val sseFlowable = Flowable
@@ -106,29 +125,29 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
106125
return callbackFlow {
107126
withContext(Dispatchers.IO) {
108127
sseFlowable
109-
.doOnError{
110-
it.printStackTrace()
111-
close()
112-
}
113-
.blockingForEach { sse ->
114-
if (responseFormat.isNotEmpty()) {
115-
val chunk: String = JsonPath.parse(sse!!.data)?.read(responseFormat)
116-
?: throw Exception("Failed to parse chunk")
117-
logger.warn("got msg: $chunk")
118-
trySend(chunk)
119-
} else {
120-
val result: ChatCompletionResult =
121-
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
128+
.doOnError {
129+
it.printStackTrace()
130+
close()
131+
}
132+
.blockingForEach { sse ->
133+
if (responseFormat.isNotEmpty()) {
134+
val chunk: String = JsonPath.parse(sse!!.data)?.read(responseFormat)
135+
?: throw Exception("Failed to parse chunk")
136+
logger.warn("got msg: $chunk")
137+
trySend(chunk)
138+
} else {
139+
val result: ChatCompletionResult =
140+
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
122141

123142
val completion = result.choices[0].message
124143
if (completion != null && completion.content != null) {
125144
trySend(completion.content)
126145
}
127146
}
128147
}
129-
130148
close()
131149
}
150+
awaitClose()
132151
}
133152
} catch (e: Exception) {
134153
logger.error("Failed to stream", e)
@@ -174,7 +193,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
174193
fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Builder = apply {
175194
runCatching {
176195
Json.parseToJsonElement(customRequestHeader)
177-
.jsonObject["customHeaders"].let { customFields ->
196+
.jsonObject["customHeaders"].let { customFields ->
178197
customFields?.jsonObject?.forEach { (key, value) ->
179198
header(key, value.jsonPrimitive.content)
180199
}
@@ -232,5 +251,5 @@ fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
232251
fun CustomRequest.updateCustomFormat(format: String): String {
233252
val requestContentOri = Json.encodeToString<CustomRequest>(this)
234253
return Json.parseToJsonElement(requestContentOri)
235-
.jsonObject.updateCustomBody(format).toString()
254+
.jsonObject.updateCustomBody(format).toString()
236255
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package cc.unitmesh.devti.llms.custom
2+
3+
import com.nfeld.jsonpathkt.JsonPath
4+
import com.nfeld.jsonpathkt.extension.read
5+
import io.kotest.common.runBlocking
6+
import okhttp3.Call
7+
import okhttp3.Callback
8+
import okhttp3.Response
9+
import java.io.IOException
10+
11+
class JSONBodyResponseCallback(private val responseFormat: String,private val callback: suspend (String)->Unit): Callback {
12+
override fun onFailure(call: Call, e: IOException) {
13+
runBlocking {
14+
callback("error. ${e.message}")
15+
}
16+
}
17+
18+
override fun onResponse(call: Call, response: Response) {
19+
val responseContent: String = JsonPath.parse(response.body?.string())?.read(responseFormat) ?: ""
20+
21+
runBlocking() {
22+
callback(responseContent)
23+
}
24+
25+
}
26+
}

src/main/kotlin/cc/unitmesh/devti/settings/AppSettingsComponent.kt

Whitespace-only changes.

src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ class AutoDevSettingsState : PersistentStateComponent<AutoDevSettingsState> {
2828
var xingHuoApiSecrect = ""
2929
var xingHuoApiKey = ""
3030

31+
32+
/**
33+
* 自定义引擎返回的数据格式是否是 [SSE](https://www.ruanyifeng.com/blog/2017/05/server-sent_events.html) 格式
34+
*/
35+
var customEngineResponseType = ResponseType.SSE.name
3136
/**
3237
* should be a json path
3338
*/

src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ enum class XingHuoApiVersion(val value: Int) {
2222
}
2323
}
2424

25+
enum class ResponseType {
26+
SSE, JSON;
27+
28+
companion object {
29+
fun of(str: String): ResponseType = when (str) {
30+
"SSE" -> SSE
31+
"JSON" -> JSON
32+
else -> JSON
33+
}
34+
}
35+
}
36+
2537

2638
val DEFAULT_AI_ENGINE = AI_ENGINES[0]
2739

src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
3939
private val xingHuoAppIDParam by LLMParam.creating { Editable(settings.xingHuoAppId) }
4040
private val xingHuoApiKeyParam by LLMParam.creating { Password(settings.xingHuoApiKey) }
4141
private val xingHuoApiSecretParam by LLMParam.creating { Password(settings.xingHuoApiSecrect) }
42+
43+
private val customEngineResponseTypeParam by LLMParam.creating { ComboBox(ResponseType.of(settings.customEngineResponseType).name, ResponseType.values().map { it.name }.toList()) }
4244
private val customEngineResponseFormatParam by LLMParam.creating { Editable(settings.customEngineResponseFormat) }
4345
private val customEngineRequestBodyFormatParam by LLMParam.creating { Editable(settings.customEngineRequestFormat) }
4446

@@ -79,6 +81,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
7981
customOpenAIHostParam,
8082
),
8183
AIEngines.Custom to listOf(
84+
customEngineResponseTypeParam,
8285
customEngineServerParam,
8386
customEngineTokenParam,
8487
customEngineResponseFormatParam,
@@ -185,6 +188,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
185188
openAIKeyParam.value = openAiKey
186189
customOpenAIHostParam.value = customOpenAiHost
187190
customEngineServerParam.value = customEngineServer
191+
customEngineResponseTypeParam.value = customEngineResponseType
188192
customEngineTokenParam.value = customEngineToken
189193
openAIModelsParam.value = openAiModel
190194
xingHuoApiVersionParam.value = xingHuoApiVersion.toString()
@@ -216,10 +220,11 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
216220
aiEngine = aiEngineParam.value
217221
language = languageParam.value
218222
customEngineServer = customEngineServerParam.value
223+
customEngineResponseType = customEngineResponseTypeParam.value
219224
customEngineToken = customEngineTokenParam.value
220225
customPrompts = customEnginePrompt.text
221226
openAiModel = openAIModelsParam.value
222-
customEngineResponseFormat = customEngineResponseFormatParam.value
227+
customEngineResponseType = customEngineResponseFormatParam.value
223228
customEngineRequestFormat = customEngineRequestBodyFormatParam.value
224229
delaySeconds = delaySecondsParam.value
225230
}
@@ -239,6 +244,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) {
239244
settings.aiEngine != aiEngineParam.value ||
240245
settings.language != languageParam.value ||
241246
settings.customEngineServer != customEngineServerParam.value ||
247+
settings.customEngineResponseType != customEngineResponseTypeParam.value ||
242248
settings.customEngineToken != customEngineTokenParam.value ||
243249
settings.customPrompts != customEnginePrompt.text ||
244250
settings.openAiModel != openAIModelsParam.value ||

src/main/resources/messages/AutoDevBundle.properties

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ settings.xingHuoApiVersionParam=XingHuo API Version
6666

6767
settings.delaySecondsParam=Quest Delay Seconds
6868
settings.customEngineResponseFormatParam=Custom Response Format (Json Path)
69+
settings.customEngineResponseTypeParam=Custom Response Type
6970
settings.customEngineRequestBodyFormatParam=Custom Request Body Format (Json Path)
7071
settings.customEngineRequestHeaderFormatParam=Custom Request Header Format (Json Path)
7172
settings.external.counit.enable.label=Enable CoUnit (Experimental)

0 commit comments

Comments
 (0)