@@ -45,6 +45,9 @@ abstract class LLMProvider2 protected constructor(
45
45
/* * The job that is sending the request */
46
46
protected var _sendingJob : Job ? = null
47
47
48
+ /* * The current EventSource for SSE streaming */
49
+ protected var _currentEventSource : EventSource ? = null
50
+
48
51
/* *
49
52
* 为会话创建一个 CoroutineScope
50
53
*
@@ -70,7 +73,7 @@ abstract class LLMProvider2 protected constructor(
70
73
var result = " "
71
74
var sessionId: String? = null
72
75
73
- factory.newEventSource(request, object : EventSourceListener () {
76
+ val eventSource = factory.newEventSource(request, object : EventSourceListener () {
74
77
override fun onEvent (eventSource : EventSource , id : String? , type : String? , data : String ) {
75
78
super .onEvent(eventSource, id, type, data)
76
79
if (data == " [DONE]" ) {
@@ -90,18 +93,21 @@ abstract class LLMProvider2 protected constructor(
90
93
logger.warn(IllegalStateException (" cannot parse with responseResolver: ${responseResolver} , ori data: $data " ))
91
94
" "
92
95
}
96
+
93
97
result + = chunk
94
98
onEvent(SessionMessageItem (Message (" system" , result)))
95
99
}
96
100
97
101
override fun onClosed (eventSource : EventSource ) {
102
+ _currentEventSource = null
98
103
if (result.isEmpty()) {
99
104
onFailure(IllegalStateException (" response is empty" ))
100
105
}
101
106
onClosed()
102
107
}
103
108
104
109
override fun onFailure (eventSource : EventSource , t : Throwable ? , response : Response ? ) {
110
+ _currentEventSource = null
105
111
onFailure(
106
112
t ? : RuntimeException (" error: ${response?.code} ${response?.message} ${response?.body?.string()} " )
107
113
)
@@ -111,6 +117,9 @@ abstract class LLMProvider2 protected constructor(
111
117
onOpen()
112
118
}
113
119
})
120
+
121
+ // Store the EventSource reference so we can cancel it later
122
+ _currentEventSource = eventSource
114
123
}
115
124
116
125
/* *
@@ -142,11 +151,9 @@ abstract class LLMProvider2 protected constructor(
142
151
sessionId = sessionId,
143
152
timestamp = System .currentTimeMillis()
144
153
)
145
-
146
154
ApplicationManager .getApplication().messageBus
147
155
.syncPublisher(TokenUsageListener .TOPIC )
148
156
.onTokenUsage(tokenUsageEvent)
149
-
150
157
logger.info(" Token usage event published: prompt=${usage.promptTokens} , completion=${usage.completionTokens} , total=${usage.totalTokens} " )
151
158
}
152
159
} catch (e: Exception ) {
@@ -182,10 +189,12 @@ abstract class LLMProvider2 protected constructor(
182
189
logger.info(" response: $content " )
183
190
val result: String = runCatching<String > {
184
191
val result: String? = JsonPath .parse(content)?.read(responseResolver)
185
- result ? : throw java.lang.IllegalStateException (" cannot parse with responseResolver: ${responseResolver} , ori data: $content " )
192
+ result
193
+ ? : throw java.lang.IllegalStateException (" cannot parse with responseResolver: ${responseResolver} , ori data: $content " )
186
194
}.getOrElse {
187
195
throw IllegalStateException (" cannot parse with responseResolver: ${responseResolver} , ori data: $content " )
188
196
}
197
+
189
198
return SessionMessageItem (Message (" system" , result))
190
199
}
191
200
}
@@ -217,15 +226,18 @@ abstract class LLMProvider2 protected constructor(
217
226
/* * 同步取消当前请求,并将等待请求完成 */
218
227
suspend fun cancelCurrentRequest (session : ChatSession <Message >) {
219
228
_sendingJob ?.cancelAndJoin()
229
+ _currentEventSource ?.cancel()
230
+ _currentEventSource = null
220
231
}
221
232
222
233
/* * 取消当前请求,本 api 不会等待请求完成 */
223
234
fun cancelCurrentRequestSync () {
224
235
_sendingJob ?.cancel()
236
+ _currentEventSource ?.cancel()
237
+ _currentEventSource = null
225
238
}
226
239
227
240
companion object {
228
-
229
241
/* * 返回在配置中设置的 provider */
230
242
operator fun invoke (autoDevSettingsState : AutoDevSettingsState = AutoDevSettingsState .getInstance()): LLMProvider2 =
231
243
LLMProvider2 (
@@ -273,12 +285,12 @@ abstract class LLMProvider2 protected constructor(
273
285
return GithubCopilotProvider (
274
286
responseResolver = if (stream) " \$ .choices[0].delta.content" else " \$ .choices[0].message.content" ,
275
287
requestCustomize = """ {"customFields": {
276
- "model": "$actualModelName ",
277
- "intent": false,
278
- "n": 1,
279
- "temperature": 0.1,
280
- "stream": ${ if (stream) " true" else " false" }
281
- }}
288
+ "model": "$actualModelName ",
289
+ "intent": false,
290
+ "n": 1,
291
+ "temperature": 0.1,
292
+ "stream": ${if (stream) " true" else " false" }
293
+ }}
282
294
""" .trimIndent(),
283
295
project = project,
284
296
)
@@ -319,16 +331,19 @@ private class DefaultLLMTextProvider(
319
331
320
332
override fun textComplete (session : ChatSession <Message >, stream : Boolean ): Flow <SessionMessageItem <Message >> {
321
333
val client = httpClient.newBuilder().readTimeout(Duration .ofSeconds(30 )).build()
334
+
322
335
val requestBuilder = Request .Builder ().apply {
323
336
if (authorizationKey.isNotEmpty()) {
324
337
addHeader(" Authorization" , " Bearer $authorizationKey " )
325
338
}
326
339
appendCustomHeaders(requestCustomize)
327
340
}
341
+
328
342
val customRequest = CustomRequest (session.chatHistory.map {
329
343
val cm = it.chatMessage
330
344
Message (cm.role, cm.content)
331
345
})
346
+
332
347
val requestBodyText = customRequest.updateCustomFormat(requestCustomize)
333
348
val content = requestBodyText.toByteArray()
334
349
val requestBody = content.toRequestBody(" application/json" .toMediaTypeOrNull(), 0 , content.size)
@@ -360,7 +375,8 @@ private class DefaultLLMTextProvider(
360
375
}
361
376
}
362
377
}
378
+
363
379
awaitClose()
364
380
}
365
381
}
366
- }
382
+ }
0 commit comments