Skip to content

Commit 88b4633

Browse files
committed
feat(custom-agent): refactor to use LLMProvider for chat processing #51
The chat service has been refactored to use the LLMProvider abstraction for creating and interacting with the language model. This simplifies the code by removing the need to instantiate a model via LlmFactory at every call site; instead, a single provider instance is created once and reused, which also allows easier integration of custom logic or configuration specific to the project. Additionally, the chat processor now uses a StringBuilder to accumulate the chat response when handling the direct response action, and the accumulated response is appended to the LLMProvider's local message history. This ensures that the chat history is correctly maintained and can be accessed for future reference.
1 parent a82d8ba commit 88b4633

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

src/main/kotlin/cc/unitmesh/devti/counit/CustomAgentChatProcessor.kt

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@ import cc.unitmesh.devti.counit.model.CustomAgentState
66
import cc.unitmesh.devti.counit.model.ResponseAction
77
import cc.unitmesh.devti.gui.chat.ChatCodingPanel
88
import cc.unitmesh.devti.gui.chat.ChatContext
9+
import cc.unitmesh.devti.gui.chat.ChatRole
10+
import cc.unitmesh.devti.llms.LLMProvider
911
import cc.unitmesh.devti.provider.ContextPrompter
1012
import cc.unitmesh.devti.util.LLMCoroutineScope
1113
import com.intellij.openapi.components.Service
1214
import com.intellij.openapi.diagnostic.logger
1315
import com.intellij.openapi.project.Project
14-
import kotlinx.coroutines.flow.collect
1516
import kotlinx.coroutines.launch
1617
import kotlinx.coroutines.runBlocking
1718

1819
@Service(Service.Level.PROJECT)
1920
class CustomAgentChatProcessor(val project: Project) {
2021
private val customAgentExecutor = CustomAgentExecutor(project)
2122

22-
fun handleChat(prompter: ContextPrompter, ui: ChatCodingPanel, context: ChatContext?) {
23+
fun handleChat(prompter: ContextPrompter, ui: ChatCodingPanel, context: ChatContext?, llmProvider: LLMProvider) {
2324
val originPrompt = prompter.requestPrompt()
2425
ui.addMessage(originPrompt, true, originPrompt)
2526

@@ -38,9 +39,13 @@ class CustomAgentChatProcessor(val project: Project) {
3839
when (selectedAgent.responseAction) {
3940
ResponseAction.Direct -> {
4041
val message = ui.addMessage("loading", false, "")
42+
val sb = StringBuilder()
4143
runBlocking {
42-
ui.updateMessage(response)
44+
val result = ui.updateMessage(response)
45+
sb.append(result)
4346
}
47+
48+
llmProvider.appendLocalMessage(sb.toString(), ChatRole.Assistant)
4449
message.reRenderAssistantOutput()
4550
ui.hiddenProgressBar()
4651
ui.updateUI()
@@ -49,7 +54,7 @@ class CustomAgentChatProcessor(val project: Project) {
4954
ResponseAction.Stream -> {
5055
ui.addMessage(AutoDevBundle.message("autodev.loading"))
5156
LLMCoroutineScope.scope(project).launch {
52-
// ui.updateMessage(response)
57+
ui.updateMessage(response)
5358
}
5459
}
5560

@@ -60,8 +65,11 @@ class CustomAgentChatProcessor(val project: Project) {
6065
sb.append(it)
6166
}
6267
}
68+
69+
val content = sb.toString()
70+
llmProvider.appendLocalMessage(content, ChatRole.Assistant)
6371
ui.removeLastMessage()
64-
ui.setInput(sb.toString())
72+
ui.setInput(content)
6573
ui.hiddenProgressBar()
6674
}
6775

@@ -78,6 +86,7 @@ class CustomAgentChatProcessor(val project: Project) {
7886
}
7987
// TODO: add decode support
8088
val content = sb.toString()
89+
llmProvider.appendLocalMessage(content, ChatRole.Assistant)
8190

8291
ui.appendWebView(content, project)
8392
ui.hiddenProgressBar()

src/main/kotlin/cc/unitmesh/devti/gui/chat/ChatCodingService.kt

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,12 @@ import kotlinx.coroutines.flow.Flow
1616
import kotlinx.coroutines.launch
1717

1818
class ChatCodingService(var actionType: ChatActionType, val project: Project) {
19-
private val llmFactory = LlmFactory()
19+
private val llmProvider = LlmFactory().create(project)
2020
private val counitProcessor = project.service<CustomAgentChatProcessor>()
2121

2222
val action = actionType.instruction(project = project)
2323

24-
fun getLabel(): String {
25-
val capitalizedAction = actionType
26-
return "$capitalizedAction Code"
27-
}
24+
fun getLabel(): String = "$actionType Code"
2825

2926
fun handlePromptAndResponse(
3027
ui: ChatCodingPanel,
@@ -37,7 +34,7 @@ class ChatCodingService(var actionType: ChatActionType, val project: Project) {
3734

3835
if (project.customAgentSetting.enableCustomRag && ui.hasSelectedCustomAgent()) {
3936
if (ui.getSelectedCustomAgent().state === CustomAgentState.START) {
40-
counitProcessor.handleChat(prompter, ui, context)
37+
counitProcessor.handleChat(prompter, ui, context, llmProvider)
4138
return
4239
}
4340
}
@@ -74,7 +71,7 @@ class ChatCodingService(var actionType: ChatActionType, val project: Project) {
7471
ui.addMessage(AutoDevBundle.message("autodev.loading"))
7572

7673
ApplicationManager.getApplication().executeOnPooledThread {
77-
val response = llmFactory.create(project).stream(requestPrompt, systemPrompt)
74+
val response = llmProvider.stream(requestPrompt, systemPrompt)
7875

7976
LLMCoroutineScope.scope(project).launch {
8077
ui.updateMessage(response)
@@ -83,7 +80,7 @@ class ChatCodingService(var actionType: ChatActionType, val project: Project) {
8380
}
8481

8582
private fun makeChatBotRequest(requestPrompt: String, newChatContext: Boolean): Flow<String> {
86-
return llmFactory.create(project).stream(requestPrompt, "", keepHistory = !newChatContext)
83+
return llmProvider.stream(requestPrompt, "", keepHistory = !newChatContext)
8784
}
8885

8986
private fun getCodeSection(content: String, prefixText: String, suffixText: String): String {
@@ -96,6 +93,6 @@ class ChatCodingService(var actionType: ChatActionType, val project: Project) {
9693
}
9794

9895
fun clearSession() {
99-
llmFactory.create(project).clearMessage()
96+
llmProvider.clearMessage()
10097
}
10198
}

0 commit comments

Comments
 (0)