1
1
package cc.unitmesh.database.actions
2
2
3
3
import cc.unitmesh.devti.AutoDevBundle
4
- import cc.unitmesh.devti.gui.sendToChatWindow
4
+ import cc.unitmesh.devti.gui.chat.ChatCodingPanel
5
+ import cc.unitmesh.devti.gui.sendToChatPanel
5
6
import cc.unitmesh.devti.intentions.action.base.AbstractChatIntention
6
- import cc.unitmesh.devti.provider.ContextPrompter
7
+ import cc.unitmesh.devti.llms.LLMProvider
8
+ import cc.unitmesh.devti.llms.LlmFactory
7
9
import cc.unitmesh.devti.template.TemplateRender
8
10
import com.intellij.database.model.DasTable
9
11
import com.intellij.database.model.ObjectKind
10
12
import com.intellij.database.psi.DbPsiFacade
11
13
import com.intellij.database.util.DasUtil
14
+ import com.intellij.openapi.application.ApplicationManager
15
+ import com.intellij.openapi.application.ReadAction
12
16
import com.intellij.openapi.diagnostic.logger
13
17
import com.intellij.openapi.editor.Editor
18
+ import com.intellij.openapi.progress.ProgressIndicator
19
+ import com.intellij.openapi.progress.ProgressManager
20
+ import com.intellij.openapi.progress.Task
14
21
import com.intellij.openapi.project.Project
15
22
import com.intellij.psi.PsiFile
23
+ import kotlinx.coroutines.runBlocking
16
24
17
25
18
26
class GenSqlScriptBySelection : AbstractChatIntention () {
19
27
override fun priority (): Int = 1001
20
-
21
28
override fun startInWriteAction (): Boolean = false
22
-
23
29
override fun getFamilyName (): String = AutoDevBundle .message(" migration.database.plsql" )
24
-
25
30
override fun getText (): String = AutoDevBundle .message(" migration.database.sql.generate" )
26
31
27
32
override fun isAvailable (project : Project , editor : Editor ? , file : PsiFile ? ): Boolean {
@@ -32,10 +37,12 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
32
37
private val logger = logger<GenSqlScriptBySelection >()
33
38
34
39
override fun invoke (project : Project , editor : Editor ? , file : PsiFile ? ) {
40
+ if (editor == null || file == null ) return
41
+
35
42
val dbPsiFacade = DbPsiFacade .getInstance(project)
36
43
val dataSource = dbPsiFacade.dataSources.firstOrNull() ? : return
37
44
38
- val selectedText = editor? .selectionModel? .selectedText
45
+ val selectedText = editor.selectionModel.selectedText
39
46
40
47
val rawDataSource = dbPsiFacade.getDataSourceManager(dataSource).dataSources.firstOrNull() ? : return
41
48
val databaseVersion = rawDataSource.databaseVersion
@@ -55,16 +62,80 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
55
62
)
56
63
57
64
val actions = DbContextActionProvider (dasTables)
58
- val prompter = generateStepOnePrompt(dbContext, actions)
59
65
60
- sendToChatWindow(project, getActionType()) { panel, service ->
61
- service.handlePromptAndResponse(panel, object : ContextPrompter () {
62
- override fun displayPrompt (): String = prompter
63
- override fun requestPrompt (): String = prompter
64
- }, null , false )
66
+ sendToChatPanel(project) { contentPanel, _ ->
67
+ val llmProvider = LlmFactory ().create(project)
68
+ val prompter = GenSqlFlow (dbContext, actions, contentPanel, llmProvider)
69
+ ApplicationManager .getApplication().invokeLater {
70
+
71
+ ProgressManager .getInstance()
72
+ .run (generateSqlWorkflow(project, contentPanel, prompter))
73
+ }
65
74
}
66
75
}
67
76
77
+ private fun generateSqlWorkflow (
78
+ project : Project ,
79
+ ui : ChatCodingPanel ,
80
+ flow : GenSqlFlow ,
81
+ ) =
82
+ object : Task .Backgroundable (project, " Loading retained test failure" , true ) {
83
+ override fun run (indicator : ProgressIndicator ) {
84
+ indicator.fraction = 0.2
85
+
86
+
87
+ indicator.text = AutoDevBundle .message(" migration.database.sql.generate.clarify" )
88
+ val tables = ReadAction .compute<String , Throwable > {
89
+ flow.clarify()
90
+ }
91
+
92
+ // tables will be list in string format, like: `[table1, table2]`, we need to parse to Lists
93
+ val tableNames = tables.substringAfter(" [" ).substringBefore(" ]" )
94
+ .split(" , " ).map { it.trim() }
95
+
96
+ indicator.fraction = 0.6
97
+ val sqlScript = flow.generate(tableNames)
98
+
99
+ logger.info(" SQL Script: $sqlScript " )
100
+
101
+ indicator.fraction = 1.0
102
+ }
103
+ }
104
+ }
105
+
106
+ class GenSqlFlow (
107
+ val dbContext : DbContext ,
108
+ val actions : DbContextActionProvider ,
109
+ val ui : ChatCodingPanel ,
110
+ val llm : LLMProvider
111
+ ) {
112
+ private val logger = logger<GenSqlFlow >()
113
+
114
+ fun clarify (): String {
115
+ val stepOnePrompt = generateStepOnePrompt(dbContext, actions)
116
+ ui.addMessage(stepOnePrompt, true , stepOnePrompt)
117
+ // for answer
118
+ ui.addMessage(AutoDevBundle .message(" autodev.loading" ))
119
+
120
+ return runBlocking {
121
+ val prompt = llm.stream(stepOnePrompt, " " )
122
+ return @runBlocking ui.updateMessage(prompt)
123
+ }
124
+ }
125
+
126
+ fun generate (tableNames : List <String >): String {
127
+ val stepTwoPrompt = generateStepTwoPrompt(dbContext, actions, tableNames)
128
+ ui.addMessage(stepTwoPrompt, true , stepTwoPrompt)
129
+ // for answer
130
+ ui.addMessage(AutoDevBundle .message(" autodev.loading" ))
131
+
132
+ return runBlocking {
133
+ val prompt = llm.stream(stepTwoPrompt, " " )
134
+ return @runBlocking ui.updateMessage(prompt)
135
+ }
136
+ }
137
+
138
+
68
139
private fun generateStepOnePrompt (context : DbContext , actions : DbContextActionProvider ): String {
69
140
val templateRender = TemplateRender (" genius/sql" )
70
141
val template = templateRender.getTemplate(" sql-gen-clarify.vm" )
@@ -77,15 +148,35 @@ class GenSqlScriptBySelection : AbstractChatIntention() {
77
148
logger.info(" Prompt: $prompter " )
78
149
return prompter
79
150
}
151
+
152
+ private fun generateStepTwoPrompt (
153
+ dbContext : DbContext ,
154
+ actions : DbContextActionProvider ,
155
+ tableInfos : List <String >
156
+ ): String {
157
+ val templateRender = TemplateRender (" genius/sql" )
158
+ val template = templateRender.getTemplate(" sql-gen-generate.vm" )
159
+
160
+ dbContext.tableInfos = actions.getTableColumns(tableInfos)
161
+
162
+ templateRender.context = dbContext
163
+ templateRender.actions = actions
164
+
165
+ val prompter = templateRender.renderTemplate(template)
166
+
167
+ logger.info(" Prompt: $prompter " )
168
+ return prompter
169
+ }
80
170
}
81
171
172
+
82
173
data class DbContext (
83
174
val requirement : String ,
84
175
val databaseVersion : String ,
85
176
val schemaName : String ,
86
177
val tableNames : List <String >,
87
178
// for step 2
88
- val tableInfos : List <String > = emptyList(),
179
+ var tableInfos : List <String > = emptyList(),
89
180
)
90
181
91
182
data class DbContextActionProvider (val dasTables : List <DasTable >) {
0 commit comments