Skip to content

Commit 728d341

Browse files
committed
feat(llm): add EventSource cancellation support
- Store EventSource reference for proper cleanup - Add cancellation logic in cancelCurrentRequest methods - Fix JSON formatting and code style improvements - Enable proper streaming request termination
1 parent 8fd2da8 commit 728d341

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

core/src/main/kotlin/cc/unitmesh/devti/llm2/LLMProvider2.kt

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ abstract class LLMProvider2 protected constructor(
4545
/** The job that is sending the request */
4646
protected var _sendingJob: Job? = null
4747

48+
/** The current EventSource for SSE streaming */
49+
protected var _currentEventSource: EventSource? = null
50+
4851
/**
4952
* 为会话创建一个 CoroutineScope
5053
*
@@ -70,7 +73,7 @@ abstract class LLMProvider2 protected constructor(
7073
var result = ""
7174
var sessionId: String? = null
7275

73-
factory.newEventSource(request, object : EventSourceListener() {
76+
val eventSource = factory.newEventSource(request, object : EventSourceListener() {
7477
override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
7578
super.onEvent(eventSource, id, type, data)
7679
if (data == "[DONE]") {
@@ -90,18 +93,21 @@ abstract class LLMProvider2 protected constructor(
9093
logger.warn(IllegalStateException("cannot parse with responseResolver: ${responseResolver}, ori data: $data"))
9194
""
9295
}
96+
9397
result += chunk
9498
onEvent(SessionMessageItem(Message("system", result)))
9599
}
96100

97101
override fun onClosed(eventSource: EventSource) {
102+
_currentEventSource = null
98103
if (result.isEmpty()) {
99104
onFailure(IllegalStateException("response is empty"))
100105
}
101106
onClosed()
102107
}
103108

104109
override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
110+
_currentEventSource = null
105111
onFailure(
106112
t ?: RuntimeException("error: ${response?.code} ${response?.message} ${response?.body?.string()}")
107113
)
@@ -111,6 +117,9 @@ abstract class LLMProvider2 protected constructor(
111117
onOpen()
112118
}
113119
})
120+
121+
// Store the EventSource reference so we can cancel it later
122+
_currentEventSource = eventSource
114123
}
115124

116125
/**
@@ -142,11 +151,9 @@ abstract class LLMProvider2 protected constructor(
142151
sessionId = sessionId,
143152
timestamp = System.currentTimeMillis()
144153
)
145-
146154
ApplicationManager.getApplication().messageBus
147155
.syncPublisher(TokenUsageListener.TOPIC)
148156
.onTokenUsage(tokenUsageEvent)
149-
150157
logger.info("Token usage event published: prompt=${usage.promptTokens}, completion=${usage.completionTokens}, total=${usage.totalTokens}")
151158
}
152159
} catch (e: Exception) {
@@ -182,10 +189,12 @@ abstract class LLMProvider2 protected constructor(
182189
logger.info("response: $content")
183190
val result: String = runCatching<String> {
184191
val result: String? = JsonPath.parse(content)?.read(responseResolver)
185-
result ?: throw java.lang.IllegalStateException("cannot parse with responseResolver: ${responseResolver}, ori data: $content")
192+
result
193+
?: throw java.lang.IllegalStateException("cannot parse with responseResolver: ${responseResolver}, ori data: $content")
186194
}.getOrElse {
187195
throw IllegalStateException("cannot parse with responseResolver: ${responseResolver}, ori data: $content")
188196
}
197+
189198
return SessionMessageItem(Message("system", result))
190199
}
191200
}
@@ -217,15 +226,18 @@ abstract class LLMProvider2 protected constructor(
217226
/** 同步取消当前请求,并将等待请求完成 */
218227
suspend fun cancelCurrentRequest(session: ChatSession<Message>) {
219228
_sendingJob?.cancelAndJoin()
229+
_currentEventSource?.cancel()
230+
_currentEventSource = null
220231
}
221232

222233
/** 取消当前请求,本 api 不会等待请求完成 */
223234
fun cancelCurrentRequestSync() {
224235
_sendingJob?.cancel()
236+
_currentEventSource?.cancel()
237+
_currentEventSource = null
225238
}
226239

227240
companion object {
228-
229241
/** 返回在配置中设置的 provider */
230242
operator fun invoke(autoDevSettingsState: AutoDevSettingsState = AutoDevSettingsState.getInstance()): LLMProvider2 =
231243
LLMProvider2(
@@ -273,12 +285,12 @@ abstract class LLMProvider2 protected constructor(
273285
return GithubCopilotProvider(
274286
responseResolver = if (stream) "\$.choices[0].delta.content" else "\$.choices[0].message.content",
275287
requestCustomize = """{"customFields": {
276-
"model": "$actualModelName",
277-
"intent": false,
278-
"n": 1,
279-
"temperature": 0.1,
280-
"stream": ${ if (stream) "true" else "false" }
281-
}}
288+
"model": "$actualModelName",
289+
"intent": false,
290+
"n": 1,
291+
"temperature": 0.1,
292+
"stream": ${if (stream) "true" else "false"}
293+
}}
282294
""".trimIndent(),
283295
project = project,
284296
)
@@ -319,16 +331,19 @@ private class DefaultLLMTextProvider(
319331

320332
override fun textComplete(session: ChatSession<Message>, stream: Boolean): Flow<SessionMessageItem<Message>> {
321333
val client = httpClient.newBuilder().readTimeout(Duration.ofSeconds(30)).build()
334+
322335
val requestBuilder = Request.Builder().apply {
323336
if (authorizationKey.isNotEmpty()) {
324337
addHeader("Authorization", "Bearer $authorizationKey")
325338
}
326339
appendCustomHeaders(requestCustomize)
327340
}
341+
328342
val customRequest = CustomRequest(session.chatHistory.map {
329343
val cm = it.chatMessage
330344
Message(cm.role, cm.content)
331345
})
346+
332347
val requestBodyText = customRequest.updateCustomFormat(requestCustomize)
333348
val content = requestBodyText.toByteArray()
334349
val requestBody = content.toRequestBody("application/json".toMediaTypeOrNull(), 0, content.size)
@@ -360,7 +375,8 @@ private class DefaultLLMTextProvider(
360375
}
361376
}
362377
}
378+
363379
awaitClose()
364380
}
365381
}
366-
}
382+
}

0 commit comments

Comments
 (0)