Skip to content

Commit 2fcab1c

Browse files
committed
feat: enable recording datasets works in local
1 parent f24cc6a commit 2fcab1c

File tree

6 files changed

+78
-1
lines changed

6 files changed

+78
-1
lines changed

src/main/kotlin/cc/unitmesh/devti/llms/azure/AzureOpenAIProvider.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@ package cc.unitmesh.devti.llms.azure
33
import cc.unitmesh.devti.custom.action.CustomPromptConfig
44
import cc.unitmesh.devti.gui.chat.ChatRole
55
import cc.unitmesh.devti.llms.LLMProvider
6+
import cc.unitmesh.devti.recording.EmptyRecording
7+
import cc.unitmesh.devti.recording.JsonlRecording
8+
import cc.unitmesh.devti.recording.Recording
9+
import cc.unitmesh.devti.recording.RecordingInstruction
610
import cc.unitmesh.devti.settings.AutoDevSettingsState
11+
import cc.unitmesh.devti.settings.custom.teamPromptsSettings
712
import com.fasterxml.jackson.databind.ObjectMapper
813
import com.intellij.openapi.components.Service
14+
import com.intellij.openapi.components.service
915
import com.intellij.openapi.diagnostic.logger
1016
import com.intellij.openapi.project.Project
1117
import com.theokanning.openai.completion.chat.ChatCompletionResult
@@ -53,6 +59,14 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
5359
private val maxTokenLength: Int
5460
get() = AutoDevSettingsState.getInstance().fetchMaxTokenLength()
5561

62+
private val recording: Recording
63+
get() {
64+
if (project.teamPromptsSettings.state.recordingInLocal) {
65+
return project.service<JsonlRecording>()
66+
}
67+
return EmptyRecording()
68+
}
69+
5670

5771
init {
5872
val prompts = autoDevSettingsState.customPrompts
@@ -149,6 +163,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
149163
call.enqueue(cc.unitmesh.devti.llms.azure.ResponseBodyCallback(emitter, emitDone))
150164
}, BackpressureStrategy.BUFFER)
151165

166+
var output = ""
167+
152168
return callbackFlow {
153169
sseFlowable
154170
.doOnError(Throwable::printStackTrace)
@@ -161,6 +177,8 @@ class AzureOpenAIProvider(val project: Project) : LLMProvider {
161177
}
162178
}
163179

180+
recording.write(RecordingInstruction(promptText, output))
181+
164182
close()
165183
}
166184
}

src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ package cc.unitmesh.devti.llms.openai
22

33
import cc.unitmesh.devti.gui.chat.ChatRole
44
import cc.unitmesh.devti.llms.LLMProvider
5+
import cc.unitmesh.devti.recording.EmptyRecording
6+
import cc.unitmesh.devti.recording.JsonlRecording
7+
import cc.unitmesh.devti.recording.Recording
8+
import cc.unitmesh.devti.recording.RecordingInstruction
59
import cc.unitmesh.devti.settings.AutoDevSettingsState
10+
import cc.unitmesh.devti.settings.custom.teamPromptsSettings
611
import com.intellij.openapi.components.Service
12+
import com.intellij.openapi.components.service
713
import com.intellij.openapi.diagnostic.Logger
814
import com.intellij.openapi.diagnostic.logger
915
import com.intellij.openapi.project.Project
@@ -63,6 +69,14 @@ class OpenAIProvider(val project: Project) : LLMProvider {
6369
private val messages: MutableList<ChatMessage> = ArrayList()
6470
private var historyMessageLength: Int = 0
6571

72+
private val recording: Recording
73+
get() {
74+
if (project.teamPromptsSettings.state.recordingInLocal) {
75+
return project.service<JsonlRecording>()
76+
}
77+
return EmptyRecording()
78+
}
79+
6680
override fun clearMessage() {
6781
messages.clear()
6882
historyMessageLength = 0
@@ -89,24 +103,27 @@ class OpenAIProvider(val project: Project) : LLMProvider {
89103
clearMessage()
90104
}
91105

106+
var output = ""
92107
val completionRequest = prepareRequest(promptText, systemPrompt)
93108

94109
return callbackFlow {
95110
withContext(Dispatchers.IO) {
96111
service.streamChatCompletion(completionRequest)
97-
.doOnError{ error ->
112+
.doOnError { error ->
98113
logger.error("Error in stream", error)
99114
trySend(error.message ?: "Error occurs")
100115
}
101116
.blockingForEach { response ->
102117
if (response.choices.isNotEmpty()) {
103118
val completion = response.choices[0].message
104119
if (completion != null && completion.content != null) {
120+
output += completion.content
105121
trySend(completion.content)
106122
}
107123
}
108124
}
109125

126+
recording.write(RecordingInstruction(promptText, output))
110127
close()
111128
}
112129
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package cc.unitmesh.devti.recording
2+
3+
class EmptyRecording: Recording {
4+
override fun write(instruction: RecordingInstruction) {
5+
// do nothing
6+
}
7+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package cc.unitmesh.devti.recording
2+
3+
import com.intellij.openapi.components.Service
4+
import com.intellij.openapi.project.Project
5+
import com.intellij.openapi.project.guessProjectDir
6+
import kotlinx.serialization.encodeToString
7+
import kotlinx.serialization.json.Json
8+
import java.nio.file.Path
9+
10+
@Service(Service.Level.PROJECT)
11+
class JsonlRecording(val project: Project) : Recording {
12+
private val recordingPath: Path = Path.of(project.guessProjectDir()!!.path, "recording.jsonl")
13+
override fun write(instruction: RecordingInstruction) {
14+
if (!recordingPath.toFile().exists()) {
15+
recordingPath.toFile().createNewFile()
16+
}
17+
18+
recordingPath.toFile().appendText(Json.encodeToString(instruction) + "\n")
19+
}
20+
}
21+
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package cc.unitmesh.devti.recording
2+
3+
interface Recording {
4+
fun write(instruction: RecordingInstruction)
5+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package cc.unitmesh.devti.recording
2+
3+
import kotlinx.serialization.Serializable
4+
5+
@Serializable
6+
data class RecordingInstruction(
7+
val instruction: String,
8+
val output: String,
9+
)

0 commit comments

Comments
 (0)