@@ -22,11 +22,12 @@ import kotlinx.coroutines.flow.callbackFlow
22
22
import kotlinx.coroutines.withContext
23
23
import kotlinx.serialization.Serializable
24
24
import kotlinx.serialization.encodeToString
25
- import kotlinx.serialization.json.Json
25
+ import kotlinx.serialization.json.*
26
26
import okhttp3.MediaType.Companion.toMediaTypeOrNull
27
27
import okhttp3.OkHttpClient
28
28
import okhttp3.Request
29
29
import okhttp3.RequestBody
30
+ import org.jetbrains.annotations.VisibleForTesting
30
31
import java.time.Duration
31
32
32
33
@Serializable
@@ -40,8 +41,9 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
40
41
private val autoDevSettingsState = AutoDevSettingsState .getInstance()
41
42
private val url get() = autoDevSettingsState.customEngineServer
42
43
private val key get() = autoDevSettingsState.customEngineToken
43
- private val engineFormat get() = autoDevSettingsState.customEngineResponseFormat
44
- private val customPromptConfig: CustomPromptConfig ?
44
+ private val requestFormat: String get() = autoDevSettingsState.customEngineRequestFormat
45
+ private val responseFormat get() = autoDevSettingsState.customEngineResponseFormat
46
+ private val customPromptConfig: CustomPromptConfig
45
47
get() {
46
48
val prompts = autoDevSettingsState.customPrompts
47
49
return CustomPromptConfig .tryParse(prompts)
@@ -73,18 +75,24 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
73
75
messages + = Message (" user" , promptText)
74
76
75
77
val customRequest = CustomRequest (messages)
76
- val requestContent = Json .encodeToString< CustomRequest >(customRequest )
78
+ val requestContent = customRequest.updateCustomFormat(requestFormat )
77
79
78
80
val body = RequestBody .create(" application/json; charset=utf-8" .toMediaTypeOrNull(), requestContent)
79
- logger.info(" Requesting from $body " )
80
81
81
82
val builder = Request .Builder ()
82
83
if (key.isNotEmpty()) {
83
84
builder.addHeader(" Authorization" , " Bearer $key " )
85
+ builder.addHeader(" Content-Type" , " application/json" )
84
86
}
87
+ builder.appendCustomHeaders(requestFormat)
85
88
86
- client = client.newBuilder().readTimeout(timeout).build()
87
- val request = builder.url(url).post(body).build()
89
+ client = client.newBuilder()
90
+ .readTimeout(timeout)
91
+ .build()
92
+ val request = builder
93
+ .url(url)
94
+ .post(body)
95
+ .build()
88
96
89
97
val call = client.newCall(request)
90
98
val emitDone = false
@@ -95,19 +103,22 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
95
103
}, BackpressureStrategy .BUFFER )
96
104
97
105
try {
98
- logger.info(" Starting to stream:" )
99
106
return callbackFlow {
100
107
withContext(Dispatchers .IO ) {
101
108
sseFlowable
102
- .doOnError(Throwable ::printStackTrace)
103
- .blockingForEach { sse ->
104
- if (engineFormat.isNotEmpty()) {
105
- val chunk: String = JsonPath .parse(sse!! .data)?.read<String >(engineFormat)
106
- ? : throw Exception (" Failed to parse chunk: ${sse.data} , format: $engineFormat " )
107
- trySend(chunk)
108
- } else {
109
- val result: ChatCompletionResult =
110
- ObjectMapper ().readValue(sse!! .data, ChatCompletionResult ::class .java)
109
+ .doOnError{
110
+ it.printStackTrace()
111
+ close()
112
+ }
113
+ .blockingForEach { sse ->
114
+ if (responseFormat.isNotEmpty()) {
115
+ val chunk: String = JsonPath .parse(sse!! .data)?.read(responseFormat)
116
+ ? : throw Exception (" Failed to parse chunk" )
117
+ logger.warn(" got msg: $chunk " )
118
+ trySend(chunk)
119
+ } else {
120
+ val result: ChatCompletionResult =
121
+ ObjectMapper ().readValue(sse!! .data, ChatCompletionResult ::class .java)
111
122
112
123
val completion = result.choices[0 ].message
113
124
if (completion != null && completion.content != null ) {
@@ -134,7 +145,7 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
134
145
135
146
val body = RequestBody .create(" application/json; charset=utf-8" .toMediaTypeOrNull(), requestContent)
136
147
137
- logger.info(" Requesting from $ body" )
148
+ logger.info(" Requesting form: $requestContent ${ body.toString()} " )
138
149
val builder = Request .Builder ()
139
150
if (key.isNotEmpty()) {
140
151
builder.addHeader(" Authorization" , " Bearer $key " )
@@ -157,4 +168,69 @@ class CustomLLMProvider(val project: Project) : LLMProvider {
157
168
return " "
158
169
}
159
170
}
160
- }
171
+ }
172
+
173
+ @VisibleForTesting
174
+ fun Request.Builder.appendCustomHeaders (customRequestHeader : String ): Request .Builder = apply {
175
+ runCatching {
176
+ Json .parseToJsonElement(customRequestHeader)
177
+ .jsonObject[" customHeaders" ].let { customFields ->
178
+ customFields?.jsonObject?.forEach { (key, value) ->
179
+ header(key, value.jsonPrimitive.content)
180
+ }
181
+ }
182
+ }.onFailure {
183
+ // should I warn user?
184
+ println (" Failed to parse custom request header ${it.message} " )
185
+ }
186
+ }
187
+
188
+ @VisibleForTesting
189
+ fun JsonObject.updateCustomBody (customRequest : String ): JsonObject {
190
+ return runCatching {
191
+ buildJsonObject {
192
+ // copy origin object
193
+ this @updateCustomBody.forEach { u, v -> put(u, v) }
194
+
195
+ val customRequestJson = Json .parseToJsonElement(customRequest).jsonObject
196
+
197
+ customRequestJson[" customFields" ]?.let { customFields ->
198
+ customFields.jsonObject.forEach { (key, value) ->
199
+ put(key, value.jsonPrimitive.content)
200
+ }
201
+ }
202
+
203
+
204
+ // TODO clean code with magic literals
205
+ var roleKey = " role"
206
+ var contentKey = " message"
207
+ customRequestJson.jsonObject[" messageKeys" ]?.let {
208
+ roleKey = it.jsonObject[" role" ]?.jsonPrimitive?.content ? : " role"
209
+ contentKey = it.jsonObject[" content" ]?.jsonPrimitive?.content ? : " message"
210
+ }
211
+
212
+ val messages: JsonArray = this @updateCustomBody[" messages" ]?.jsonArray ? : buildJsonArray { }
213
+
214
+
215
+ this .put(" messages" , buildJsonArray {
216
+ messages.forEach { message ->
217
+ val role: String = message.jsonObject[" role" ]?.jsonPrimitive?.content ? : " user"
218
+ val content: String = message.jsonObject[" message" ]?.jsonPrimitive?.content ? : " "
219
+ add(buildJsonObject {
220
+ put(roleKey, role)
221
+ put(contentKey, content)
222
+ })
223
+ }
224
+ })
225
+ }
226
+ }.getOrElse {
227
+ logger<CustomLLMProvider >().error(" Failed to parse custom request body" , it)
228
+ this
229
+ }
230
+ }
231
+
232
+ fun CustomRequest.updateCustomFormat (format : String ): String {
233
+ val requestContentOri = Json .encodeToString<CustomRequest >(this )
234
+ return Json .parseToJsonElement(requestContentOri)
235
+ .jsonObject.updateCustomBody(format).toString()
236
+ }
0 commit comments