@@ -2,11 +2,13 @@ package cc.unitmesh.devti.llm2
 
 import cc.unitmesh.devti.llms.custom.CustomRequest
 import cc.unitmesh.devti.llms.custom.Message
+import cc.unitmesh.devti.llms.custom.Usage
 import cc.unitmesh.devti.llms.custom.appendCustomHeaders
 import cc.unitmesh.devti.llms.custom.updateCustomFormat
 import cc.unitmesh.devti.settings.AutoDevSettingsState
 import cc.unitmesh.devti.util.AutoDevAppScope
 import cc.unitmesh.devti.util.AutoDevCoroutineScope
+import com.intellij.openapi.application.ApplicationManager
 import com.intellij.openapi.diagnostic.Logger
 import com.intellij.openapi.diagnostic.logger
 import com.intellij.openapi.project.Project
@@ -23,35 +25,12 @@ import okhttp3.OkHttpClient
 import okhttp3.Request
 import okhttp3.RequestBody.Companion.toRequestBody
 import okhttp3.Response
-import okhttp3.sse.EventSource
-import okhttp3.sse.EventSourceListener
-import okhttp3.sse.EventSources
+import okhttp3.sse.*
 import java.time.Duration
 
 /**
- * LLMProvider provide only session-free interfaces
- *
- * It's LLMProvider's responsibility to maintain the network connection But
- * the chat session is maintained by the client
- *
- * [LLMProvider2] provides a factory companion object to create different
- * providers
- *
- * for now, we only support text completion, see [DefaultLLMTextProvider].
- * you can implement your own provider by extending [LLMProvider2] and
- * override [textComplete] method
- *
- * ```kotlin
- * val provider = LLMProvider2()
- * val session = ChatSession("sessionName")
- * // if you don't need to maintain the history, you can ignore the session
- * // stream is default to true
- * provider.request("text", session = session, stream = true).catch {
- *     // handle errors
- * }.collect {
- *     // incoming new message without the original history messages
- * }
- * ```
+ * LLMProvider2 is an abstract class that provides a base implementation for LLM (Large Language Model) providers.
+ * It handles the communication with LLM services and manages the streaming of responses.
  *
  * @property project if not null means this is a project level provider,
  * will be disposed when project closed
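The replacement KDoc drops the inline usage example. For readers of this change, here is that removed example again, lightly reformatted for reference; the session name and the bodies of `catch`/`collect` are illustrative only:

```kotlin
// Reconstructed from the KDoc example removed in this hunk; names are illustrative.
val provider = LLMProvider2()              // created via the factory companion object
val session = ChatSession("sessionName")   // optional: omit if history is not needed

provider.request("text", session = session, stream = true)  // stream defaults to true
    .catch { /* handle errors */ }
    .collect { /* incoming new message, without the original history messages */ }
```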
@@ -89,12 +68,21 @@ abstract class LLMProvider2 protected constructor(
     ) {
         val factory = EventSources.createFactory(client)
         var result = ""
+        var sessionId: String? = null
+
         factory.newEventSource(request, object : EventSourceListener() {
             override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
                 super.onEvent(eventSource, id, type, data)
                 if (data == "[DONE]") {
                     return
                 }
+
+                if (sessionId == null) {
+                    sessionId = tryExtractSessionId(data)
+                }
+
+                tryParseAndNotifyTokenUsage(data, sessionId)
+
                 val chunk: String = runCatching {
                     val result: String? = JsonPath.parse(data)?.read(responseResolver)
                     result ?: ""
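For context, `tryExtractSessionId` and `tryParseAndNotifyTokenUsage` read `$.id`, `$.model`, and `$.usage` from each SSE `data:` payload. A minimal sketch of the chunk shape this assumes, in the OpenAI-compatible streaming style; the exact `choices`/`delta` layout depends on the configured `responseResolver` and is illustrative here:

```kotlin
// Illustrative OpenAI-style SSE chunk; only the top-level "id", "model",
// and trailing "usage" fields are read by the code added in this change.
val exampleChunk = """
    {
      "id": "chatcmpl-123",
      "model": "gpt-4o",
      "choices": [{ "delta": { "content": "Hello" } }],
      "usage": { "prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46 }
    }
""".trimIndent()
```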
@@ -125,6 +113,60 @@ abstract class LLMProvider2 protected constructor(
         })
     }
 
+    /**
+     * Try to parse token usage data from SSE response and notify listeners
+     *
+     * @param data The raw SSE data string
+     * @param sessionId The session ID if available
+     */
+    private fun tryParseAndNotifyTokenUsage(data: String, sessionId: String?) {
+        try {
+            val usageData: Usage? = runCatching {
+                JsonPath.parse(data)?.read<Map<String, Any>>("\$.usage")?.let { usageMap ->
+                    Usage(
+                        promptTokens = (usageMap["prompt_tokens"] as? Number)?.toLong() ?: 0,
+                        completionTokens = (usageMap["completion_tokens"] as? Number)?.toLong() ?: 0,
+                        totalTokens = (usageMap["total_tokens"] as? Number)?.toLong() ?: 0
+                    )
+                }
+            }.getOrNull()
+
+            val model: String? = runCatching {
+                JsonPath.parse(data)?.read<String>("\$.model")
+            }.getOrNull()
+
+            usageData?.let { usage ->
+                val tokenUsageEvent = TokenUsageEvent(
+                    usage = usage,
+                    model = model,
+                    sessionId = sessionId,
+                    timestamp = System.currentTimeMillis()
+                )
+
+                ApplicationManager.getApplication().messageBus
+                    .syncPublisher(TokenUsageListener.TOPIC)
+                    .onTokenUsage(tokenUsageEvent)
+
+                logger.info("Token usage event published: prompt=${usage.promptTokens}, completion=${usage.completionTokens}, total=${usage.totalTokens}")
+            }
+        } catch (e: Exception) {
+            // Silently ignore parsing errors for usage data since it's optional
+            logger.debug("Failed to parse token usage from response data", e)
+        }
+    }
+
+    /**
+     * Try to extract session ID from response data
+     *
+     * @param data The raw SSE data string
+     * @return The session ID if found, null otherwise
+     */
+    private fun tryExtractSessionId(data: String): String? {
+        return runCatching {
+            JsonPath.parse(data)?.read<String>("\$.id")
+        }.getOrNull()
+    }
+
     protected fun directResult(client: OkHttpClient, request: Request): SessionMessageItem<Message> {
         client.newCall(request).execute().use {
             val body = it.body
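On the consuming side, anything interested in these events can subscribe to the same topic on the application message bus. A minimal sketch, assuming `TokenUsageListener` is a listener interface whose only method is the `onTokenUsage(TokenUsageEvent)` call used by the publisher above:

```kotlin
// Hypothetical consumer; assumes TokenUsageListener exposes onTokenUsage(TokenUsageEvent).
val connection = ApplicationManager.getApplication().messageBus.connect()
connection.subscribe(TokenUsageListener.TOPIC, object : TokenUsageListener {
    override fun onTokenUsage(event: TokenUsageEvent) {
        // usage carries promptTokens / completionTokens / totalTokens;
        // model and sessionId are null when the response omits them
        println("${event.model}: ${event.usage.totalTokens} tokens (session=${event.sessionId})")
    }
})
```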