Skip to content

Commit eda8ae4

Browse files
authored
Merge pull request #64 from hotip/master
2 parents 8642d8d + 433fb3c commit eda8ae4

File tree

12 files changed

+352
-399
lines changed

12 files changed

+352
-399
lines changed

docs/customize/custom-llm-server.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,58 @@ class ChatInput(BaseModel):
4141
async def chat(msg: ChatInput):
4242
return StreamingResponse(fetch_chat_completion(msg.messages), media_type="text/event-stream")
4343
```
44+
45+
## Custom response format
46+
47+
We use [JsonPathKt](https://github.com/codeniko/JsonPathKt) to parse the response;
48+
currently we only extract the first choice and only the response message.
49+
If your response has this format:
50+
51+
```json
52+
{
53+
"id": "chatcmpl-123",
54+
"object": "chat.completion.chunk",
55+
"created": 1677652288,
56+
"model": "gpt-3.5-turbo",
57+
"choices": [{
58+
"index": 0,
59+
"delta": {
60+
"content": "Hello"
61+
},
62+
"finish_reason": "stop"
63+
}]
64+
}
65+
```
66+
You need to set the `response format` to:
67+
68+
```text
69+
$.choices[0].delta.content
70+
```
71+
72+
## Custom request format
73+
74+
Only a limited set of request parameters, like those of the OpenAI API, is supported.
75+
Only HTTP requests that don't need encryption keys (like WebSocket) are supported.
76+
77+
78+
### Custom Request (header and body)
79+
80+
You can add top-level fields to the request body,
81+
and customize the key names used for `role` and `message`.
82+
83+
```json
84+
{
85+
"customHeaders": { "my header": "my value" },
86+
"customFields": {"user": "userid", "date": "2012"},
87+
"messageKeys": {"role": "role", "content": "message"}
88+
}
89+
```
90+
91+
and the request body will be:
92+
93+
```json
94+
{
95+
"user": "userid",
"date": "2012",
96+
"messages": [{"role": "user", "message": "..."}]
97+
}
98+
```

src/main/kotlin/cc/unitmesh/devti/gui/chat/ChatCodingPanel.kt

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import com.intellij.util.ui.UIUtil
2525
import kotlinx.coroutines.Dispatchers
2626
import kotlinx.coroutines.delay
2727
import kotlinx.coroutines.flow.Flow
28+
import kotlinx.coroutines.flow.catch
2829
import kotlinx.coroutines.flow.collect
2930
import kotlinx.coroutines.withContext
3031
import java.awt.event.ActionListener
@@ -199,13 +200,16 @@ class ChatCodingPanel(private val chatCodingService: ChatCodingService, val disp
199200
val startTime = System.currentTimeMillis() // 记录代码开始执行的时间
200201

201202
var text = ""
202-
runCatching {
203-
content.collect {
204-
text += it
205-
messageView.updateSourceContent(text)
206-
messageView.updateContent(text)
207-
messageView.scrollToBottom()
208-
}
203+
content.catch {
204+
it.printStackTrace()
205+
}.collect {
206+
text += it
207+
208+
// 以下两个 API 设计不合理,如果必须要同时调用,那就只提供一个就好了
209+
messageView.updateSourceContent(text)
210+
messageView.updateContent(text)
211+
212+
messageView.scrollToBottom()
209213
}
210214

211215
if (delaySeconds.isNotEmpty()) {

src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import okhttp3.OkHttpClient
2626
import okhttp3.Request
2727
import okhttp3.RequestBody
2828

29+
2930
@Serializable
3031
data class SimpleOpenAIFormat(val role: String, val content: String) {
3132
companion object {

src/main/kotlin/cc/unitmesh/devti/llms/azure/ResponseBodyCallback.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ import java.io.InputStreamReader
3535
import java.nio.charset.StandardCharsets
3636

3737
/**
 * Runtime exception that carries a numeric status code alongside its message.
 *
 * @param error text handed to [RuntimeException] as the exception message
 * @property statusCode status code supplied by the code that raises the exception
 */
class AutoDevHttpException(error: String, val statusCode: Int) : RuntimeException(error) {
    /** Renders the status code together with the exception message. */
    override fun toString(): String =
        "AutoDevHttpException(statusCode=$statusCode, message=$message)"
}
3942

4043
/**
@@ -80,6 +83,11 @@ class ResponseBodyCallback(private val emitter: FlowableEmitter<SSE>, private va
8083
null
8184
}
8285

86+
line!!.startsWith("{") -> {
87+
logger<ResponseBodyCallback>().warn("msg starts with { $line")
88+
emitter.onNext(SSE(line!!))
89+
null
90+
}
8391
else -> {
8492
throw SSEFormatException("Invalid sse format! $line")
8593
}

src/main/kotlin/cc/unitmesh/devti/llms/custom/CustomLLMProvider.kt

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ import kotlinx.coroutines.flow.callbackFlow
2222
import kotlinx.coroutines.withContext
2323
import kotlinx.serialization.Serializable
2424
import kotlinx.serialization.encodeToString
25-
import kotlinx.serialization.json.Json
25+
import kotlinx.serialization.json.*
2626
import okhttp3.MediaType.Companion.toMediaTypeOrNull
2727
import okhttp3.OkHttpClient
2828
import okhttp3.Request
2929
import okhttp3.RequestBody
30+
import org.jetbrains.annotations.VisibleForTesting
3031
import java.time.Duration
3132

3233
@Serializable
@@ -40,8 +41,9 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
4041
private val autoDevSettingsState = AutoDevSettingsState.getInstance()
4142
private val url get() = autoDevSettingsState.customEngineServer
4243
private val key get() = autoDevSettingsState.customEngineToken
43-
private val engineFormat get() = autoDevSettingsState.customEngineResponseFormat
44-
private val customPromptConfig: CustomPromptConfig?
44+
private val requestFormat: String get() = autoDevSettingsState.customEngineRequestFormat
45+
private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat
46+
private val customPromptConfig: CustomPromptConfig
4547
get() {
4648
val prompts = autoDevSettingsState.customPrompts
4749
return CustomPromptConfig.tryParse(prompts)
@@ -73,18 +75,24 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
7375
messages += Message("user", promptText)
7476

7577
val customRequest = CustomRequest(messages)
76-
val requestContent = Json.encodeToString<CustomRequest>(customRequest)
78+
val requestContent = customRequest.updateCustomFormat(requestFormat)
7779

7880
val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent)
79-
logger.info("Requesting from $body")
8081

8182
val builder = Request.Builder()
8283
if (key.isNotEmpty()) {
8384
builder.addHeader("Authorization", "Bearer $key")
85+
builder.addHeader("Content-Type", "application/json")
8486
}
87+
builder.appendCustomHeaders(requestFormat)
8588

86-
client = client.newBuilder().readTimeout(timeout).build()
87-
val request = builder.url(url).post(body).build()
89+
client = client.newBuilder()
90+
.readTimeout(timeout)
91+
.build()
92+
val request = builder
93+
.url(url)
94+
.post(body)
95+
.build()
8896

8997
val call = client.newCall(request)
9098
val emitDone = false
@@ -95,19 +103,22 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
95103
}, BackpressureStrategy.BUFFER)
96104

97105
try {
98-
logger.info("Starting to stream:")
99106
return callbackFlow {
100107
withContext(Dispatchers.IO) {
101108
sseFlowable
102-
.doOnError(Throwable::printStackTrace)
103-
.blockingForEach { sse ->
104-
if (engineFormat.isNotEmpty()) {
105-
val chunk: String = JsonPath.parse(sse!!.data)?.read<String>(engineFormat)
106-
?: throw Exception("Failed to parse chunk: ${sse.data}, format: $engineFormat")
107-
trySend(chunk)
108-
} else {
109-
val result: ChatCompletionResult =
110-
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
109+
.doOnError{
110+
it.printStackTrace()
111+
close()
112+
}
113+
.blockingForEach { sse ->
114+
if (responseFormat.isNotEmpty()) {
115+
val chunk: String = JsonPath.parse(sse!!.data)?.read(responseFormat)
116+
?: throw Exception("Failed to parse chunk")
117+
logger.warn("got msg: $chunk")
118+
trySend(chunk)
119+
} else {
120+
val result: ChatCompletionResult =
121+
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)
111122

112123
val completion = result.choices[0].message
113124
if (completion != null && completion.content != null) {
@@ -134,7 +145,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
134145

135146
val body = RequestBody.create("application/json; charset=utf-8".toMediaTypeOrNull(), requestContent)
136147

137-
logger.info("Requesting from $body")
148+
logger.info("Requesting form: $requestContent ${body.toString()}")
138149
val builder = Request.Builder()
139150
if (key.isNotEmpty()) {
140151
builder.addHeader("Authorization", "Bearer $key")
@@ -157,4 +168,69 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
157168
return ""
158169
}
159170
}
160-
}
171+
}
172+
173+
/**
 * Applies the entries of the `customHeaders` object found in [customRequestHeader]
 * to this builder.
 *
 * [customRequestHeader] is parsed as JSON, e.g. `{"customHeaders": {"X-Name": "value"}}`.
 * Every entry of `customHeaders` is set through [Request.Builder.header] using the
 * entry's primitive content as the header value. A blank string, or a document without
 * a `customHeaders` entry, leaves the builder untouched.
 *
 * The operation is best-effort: when parsing or applying a header fails, the failure is
 * logged and the builder is returned with whatever headers were applied before the failure.
 *
 * @param customRequestHeader JSON text of the custom request format
 * @return this builder
 */
@VisibleForTesting
fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Builder = apply {
    // Nothing to parse: an empty format is not a failure worth reporting.
    if (customRequestHeader.isBlank()) return@apply

    runCatching {
        Json.parseToJsonElement(customRequestHeader)
            .jsonObject["customHeaders"]
            ?.jsonObject
            ?.forEach { (name, value) -> header(name, value.jsonPrimitive.content) }
    }.onFailure {
        // Report through the IDE logger (as updateCustomBody does) rather than stdout.
        logger<CustomLLMProvider>().warn("Failed to parse custom request header ${it.message}", it)
    }
}
187+
188+
/**
 * Builds a copy of this request body customised according to [customRequest].
 *
 * [customRequest] is parsed as a JSON object that may contain:
 * - `customFields`: an object whose entries are added as top-level fields of the result
 *   (overwriting fields of the same name);
 * - `messageKeys`: an object whose `role` / `content` entries give the key names used
 *   for each message in the output; they default to `"role"` and `"message"`.
 *
 * Every element of this object's `messages` array is rewritten to a two-field object
 * using those key names. The source values are read from the `role` and `message` keys
 * of each element (defaulting to `"user"` and `""`).
 * NOTE(review): this assumes the serialized message type uses a `message` field — confirm
 * against the `Message` declaration.
 *
 * A blank [customRequest] returns this object unchanged. If [customRequest] cannot be
 * parsed or has an unexpected shape, the failure is logged and this object is returned
 * unchanged.
 *
 * @param customRequest JSON text of the custom request format
 * @return the customised body, or this object when no customisation could be applied
 */
@VisibleForTesting
fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
    val original = this

    // No format configured: nothing to customise, and not an error.
    if (customRequest.isBlank()) return original

    return runCatching {
        val customRequestJson = Json.parseToJsonElement(customRequest).jsonObject

        // TODO clean code with magic literals
        val messageKeys = customRequestJson["messageKeys"]?.jsonObject
        val roleKey = messageKeys?.get("role")?.jsonPrimitive?.content ?: "role"
        val contentKey = messageKeys?.get("content")?.jsonPrimitive?.content ?: "message"

        val messages: JsonArray = original["messages"]?.jsonArray ?: buildJsonArray { }

        buildJsonObject {
            // Start from a copy of the original body.
            original.forEach { (name, element) -> put(name, element) }

            // Copy custom fields as-is so numbers, booleans and nested values keep
            // their JSON type instead of being flattened to strings.
            customRequestJson["customFields"]?.jsonObject?.forEach { (name, element) ->
                put(name, element)
            }

            // Replace the messages with entries that use the configured key names.
            put("messages", buildJsonArray {
                messages.forEach { message ->
                    val role: String = message.jsonObject["role"]?.jsonPrimitive?.content ?: "user"
                    val content: String = message.jsonObject["message"]?.jsonPrimitive?.content ?: ""
                    add(buildJsonObject {
                        put(roleKey, role)
                        put(contentKey, content)
                    })
                }
            })
        }
    }.getOrElse {
        logger<CustomLLMProvider>().error("Failed to parse custom request body", it)
        original
    }
}
231+
232+
/**
 * Serializes this request to JSON and rewrites it according to [format].
 *
 * The request is encoded with [Json], parsed back into a JSON object, passed through
 * [updateCustomBody] and rendered as a string.
 *
 * @param format JSON text of the custom request format, forwarded to [updateCustomBody]
 * @return the resulting request body as JSON text
 */
fun CustomRequest.updateCustomFormat(format: String): String {
    val encoded = Json.encodeToString<CustomRequest>(this)
    val originalBody = Json.parseToJsonElement(encoded).jsonObject
    val customizedBody = originalBody.updateCustomBody(format)
    return customizedBody.toString()
}

src/main/kotlin/cc/unitmesh/devti/llms/xianghuo/XingHuoProvider.kt

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package cc.unitmesh.devti.llms.xianghuo
44

55
import cc.unitmesh.devti.llms.LLMProvider
66
import cc.unitmesh.devti.settings.AutoDevSettingsState
7+
import cc.unitmesh.devti.settings.XingHuoApiVersion
78
import com.intellij.openapi.components.Service
89
import com.intellij.openapi.project.Project
910
import kotlinx.coroutines.ExperimentalCoroutinesApi
@@ -26,6 +27,14 @@ class XingHuoProvider(val project: Project) : LLMProvider {
2627
private val secrectKey: String
2728
get() = autoDevSettingsState.xingHuoApiSecrect
2829

30+
private val apiVersion: XingHuoApiVersion
31+
get() = autoDevSettingsState.xingHuoApiVersion
32+
private val XingHuoApiVersion.asGeneralDomain
33+
get() = when (this) {
34+
XingHuoApiVersion.V1 -> ""
35+
XingHuoApiVersion.V2 -> "v2"
36+
else -> "v3"
37+
}
2938

3039
private val appid: String
3140
get() = autoDevSettingsState.xingHuoAppId
@@ -119,7 +128,7 @@ class XingHuoProvider(val project: Project) : LLMProvider {
119128
val header = """
120129
|host: spark-api.xf-yun.com
121130
|date: $date
122-
|GET /v1.1/chat HTTP/1.1
131+
|GET /v${apiVersion.value}.1/chat HTTP/1.1
123132
""".trimMargin()
124133
val signature = hmacsha256.doFinal(header.toByteArray()).encodeBase64()
125134
val authorization =
@@ -130,7 +139,7 @@ class XingHuoProvider(val project: Project) : LLMProvider {
130139
"date" to date,
131140
"host" to "spark-api.xf-yun.com"
132141
)
133-
val urlBuilder = "https://spark-api.xf-yun.com/v1.1/chat".toHttpUrl().newBuilder()
142+
val urlBuilder = "https://spark-api.xf-yun.com/v${apiVersion.value}.1/chat".toHttpUrl().newBuilder()
134143
params.forEach {
135144
urlBuilder.addQueryParameter(it.key, it.value)
136145
}
@@ -147,7 +156,7 @@ class XingHuoProvider(val project: Project) : LLMProvider {
147156
},
148157
"parameter": {
149158
"chat": {
150-
"domain": "general",
159+
"domain": "general${apiVersion.asGeneralDomain}",
151160
"temperature": 0.5,
152161
"max_tokens": 1024
153162
}

0 commit comments

Comments
 (0)