Skip to content

Commit 9d02979

Browse files
committed
fix(test): fix batch test generation cancel button
Fixes issue where cancel button in batch test generation progress dialog was not responsive. Added proper cancellation checks throughout the test generation process and replaced executor-based approach with cancellable background tasks. Resolves #407
1 parent 728d341 commit 9d02979

File tree

2 files changed

+101
-22
lines changed

2 files changed

+101
-22
lines changed

core/src/main/kotlin/cc/unitmesh/devti/actions/chat/AutoTestInMenuAction.kt

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import cc.unitmesh.devti.intentions.action.test.TestCodeGenRequest
88
import com.intellij.openapi.actionSystem.*
99
import com.intellij.openapi.actionSystem.ActionPlaces.PROJECT_VIEW_POPUP
1010
import com.intellij.openapi.editor.Editor
11-
import com.intellij.openapi.progress.EmptyProgressIndicator
12-
import com.intellij.openapi.progress.ProgressManager
11+
import com.intellij.openapi.progress.*
1312
import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator
1413
import com.intellij.openapi.project.Project
1514
import com.intellij.openapi.ui.MessageType
@@ -21,9 +20,11 @@ import com.intellij.openapi.wm.WindowManager
2120
import com.intellij.psi.PsiFile
2221
import com.intellij.psi.PsiManager
2322
import com.intellij.psi.impl.file.PsiDirectoryFactory
24-
import java.util.concurrent.Executors
23+
import com.intellij.openapi.diagnostic.logger
2524

2625
class AutoTestInMenuAction : AnAction(AutoDevBundle.message("intentions.chat.code.test.name")) {
26+
private val logger = logger<AutoTestInMenuAction>()
27+
2728
override fun getActionUpdateThread(): ActionUpdateThread = ActionUpdateThread.BGT
2829

2930
fun getActionType(): ChatActionType = ChatActionType.GENERATE_TEST
@@ -60,26 +61,49 @@ class AutoTestInMenuAction : AnAction(AutoDevBundle.message("intentions.chat.cod
6061
}
6162

6263
private fun batchGenerateTests(files: List<PsiFile>, project: Project, editor: Editor?) {
63-
val total = files.size
64-
val executor = Executors.newSingleThreadExecutor()
65-
files.forEachIndexed { index, file ->
66-
val task = TestCodeGenTask(
67-
TestCodeGenRequest(file, file, project, editor),
68-
AutoDevBundle.message("intentions.chat.code.test.name")
69-
)
70-
71-
executor.submit {
72-
val progressMessage = """${index + 1}/${total} Processing file ${file.name} for test generation"""
73-
ProgressManager.getInstance().runProcessWithProgressSynchronously(
74-
{
75-
task.run(object : EmptyProgressIndicator() {})
76-
},
77-
progressMessage, true, project
78-
)
64+
val batchTask = object : Task.Backgroundable(
65+
project,
66+
AutoDevBundle.message("intentions.chat.code.test.name") + " (Batch)",
67+
true
68+
) {
69+
override fun run(indicator: ProgressIndicator) {
70+
val total = files.size
71+
indicator.isIndeterminate = false
72+
indicator.fraction = 0.0
73+
74+
files.forEachIndexed { index, file ->
75+
// Check for cancellation before processing each file
76+
indicator.checkCanceled()
77+
78+
indicator.text = "Processing ${index + 1}/$total: ${file.name}"
79+
indicator.fraction = index.toDouble() / total
80+
81+
val task = TestCodeGenTask(
82+
TestCodeGenRequest(file, file, project, editor),
83+
AutoDevBundle.message("intentions.chat.code.test.name")
84+
)
85+
86+
try {
87+
task.run(indicator)
88+
indicator.fraction = (index + 1).toDouble() / total
89+
} catch (e: ProcessCanceledException) {
90+
// User cancelled, stop processing
91+
indicator.text = "Batch test generation cancelled"
92+
throw e
93+
} catch (e: Exception) {
94+
// Log error but continue with next file
95+
logger.warn("Failed to generate test for file: ${file.name}", e)
96+
indicator.fraction = (index + 1).toDouble() / total
97+
}
98+
}
99+
100+
indicator.fraction = 1.0
101+
indicator.text = "Batch test generation completed"
79102
}
80103
}
81104

82-
executor.shutdown()
105+
ProgressManager.getInstance()
106+
.runProcessWithProgressAsynchronously(batchTask, BackgroundableProcessIndicator(batchTask))
83107
}
84108

85109
private fun isEnabled(e: AnActionEvent): Boolean {

core/src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import com.intellij.openapi.editor.ScrollType
2626
import com.intellij.openapi.fileEditor.FileEditor
2727
import com.intellij.openapi.fileEditor.FileEditorManager
2828
import com.intellij.openapi.progress.ProgressIndicator
29+
import com.intellij.openapi.progress.ProcessCanceledException
2930
import com.intellij.openapi.progress.Task
3031
import com.intellij.openapi.project.DumbService
3132
import com.intellij.openapi.project.Project
@@ -53,10 +54,16 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
5354
indicator.fraction = 0.1
5455
indicator.text = AutoDevBundle.message("intentions.chat.code.test.step.prepare-context")
5556

57+
// Check for cancellation early
58+
indicator.checkCanceled()
59+
5660
AutoDevStatusService.notifyApplication(AutoDevStatus.InProgress)
5761
val testContext = autoTestService.findOrCreateTestFile(request.file, request.project, request.element)
5862
DumbService.getInstance(request.project).waitForSmartMode()
5963

64+
// Check for cancellation after waiting for smart mode
65+
indicator.checkCanceled()
66+
6067
if (testContext == null) {
6168
AutoDevStatusService.notifyApplication(AutoDevStatus.Error)
6269
logger.error("Failed to create test file for: ${request.file}")
@@ -66,16 +73,27 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
6673
indicator.text = AutoDevBundle.message("intentions.chat.code.test.step.collect-context")
6774
indicator.fraction = 0.3
6875

76+
// Check for cancellation before collecting context
77+
indicator.checkCanceled()
78+
6979
val testPromptContext = TestCodeGenContext()
7080

7181
val creationContext =
7282
ChatCreationContext(ChatOrigin.Intention, actionType, request.file, listOf(), element = request.element)
7383

7484
val contextItems: List<ChatContextItem> = runBlocking {
85+
// Check for cancellation in the blocking context
86+
if (indicator.isCanceled) {
87+
throw ProcessCanceledException()
88+
}
7589
ChatContextProvider.collectChatContextList(request.project, creationContext)
7690
}
7791

7892
testPromptContext.frameworkContext = contextItems.joinToString("\n", transform = ChatContextItem::text)
93+
94+
// Check for cancellation before read actions
95+
indicator.checkCanceled()
96+
7997
ReadAction.compute<Unit, Throwable> {
8098
if (testContext.relatedClasses.isNotEmpty()) {
8199
testPromptContext.relatedClasses = testContext.relatedClasses.joinToString("\n") {
@@ -113,6 +131,9 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
113131
testPromptContext.isNewFile = testContext.isNewFile
114132
testPromptContext.extContext = getCustomAgentTestContext(testPromptContext)
115133

134+
// Check for cancellation before template rendering
135+
indicator.checkCanceled()
136+
116137
templateRender.context = testPromptContext
117138
val prompter = templateRender.renderTemplate(template)
118139

@@ -121,6 +142,9 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
121142
indicator.fraction = 0.6
122143
indicator.text = AutoDevBundle.message("intentions.request.background.process.title")
123144

145+
// Check for cancellation before LLM request
146+
indicator.checkCanceled()
147+
124148
val flow: Flow<String> = try {
125149
LlmFactory.create(request.project).stream(prompter, "", false)
126150
} catch (e: Exception) {
@@ -130,13 +154,28 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
130154
}
131155

132156
runBlocking {
133-
writeTestToFile(request.project, flow, testContext)
157+
// Check for cancellation before writing to file
158+
if (indicator.isCanceled) {
159+
throw ProcessCanceledException()
160+
}
161+
162+
writeTestToFile(request.project, flow, testContext, indicator)
163+
164+
// Check for cancellation before verification
165+
if (indicator.isCanceled) {
166+
throw ProcessCanceledException()
167+
}
134168

135169
indicator.fraction = 1.0
136170
indicator.text = AutoDevBundle.message("intentions.chat.code.test.verify")
137171

138172
try {
139173
autoTestService.collectSyntaxError(testContext.outputFile, request.project) {
174+
// Check for cancellation before fixing syntax errors
175+
if (indicator.isCanceled) {
176+
throw ProcessCanceledException()
177+
}
178+
140179
autoTestService.tryFixSyntaxError(testContext.outputFile, request.project, it)
141180

142181
if (it.isNotEmpty()) {
@@ -146,9 +185,15 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
146185
)
147186
indicator.fraction = 1.0
148187
} else {
149-
autoTestService.runFile(request.project, testContext.outputFile, testContext.testElement, false)
188+
// Check for cancellation before running tests
189+
if (!indicator.isCanceled) {
190+
autoTestService.runFile(request.project, testContext.outputFile, testContext.testElement, false)
191+
}
150192
}
151193
}
194+
} catch (e: ProcessCanceledException) {
195+
// Re-throw cancellation exception
196+
throw e
152197
} catch (e: Exception) {
153198
AutoDevStatusService.notifyApplication(AutoDevStatus.Ready)
154199
indicator.fraction = 1.0
@@ -169,6 +214,7 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
169214
project: Project,
170215
flow: Flow<String>,
171216
context: TestFileContext,
217+
indicator: ProgressIndicator
172218
) {
173219
val fileEditorManager = FileEditorManager.getInstance(project)
174220
var editors: Array<FileEditor> = emptyArray()
@@ -183,6 +229,11 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
183229
val editor = fileEditorManager.selectedTextEditor
184230

185231
flow.collect {
232+
// Check for cancellation during flow collection
233+
if (indicator.isCanceled) {
234+
throw ProcessCanceledException()
235+
}
236+
186237
suggestion.append(it)
187238
val codeBlocks = MarkdownCodeHelper.parseCodeFromString(suggestion.toString())
188239
codeBlocks.forEach {
@@ -200,6 +251,10 @@ class TestCodeGenTask(val request: TestCodeGenRequest, displayMessage: String) :
200251

201252
val suggestion = StringBuilder()
202253
flow.collect {
254+
// Check for cancellation during flow collection
255+
if (indicator.isCanceled) {
256+
throw ProcessCanceledException()
257+
}
203258
suggestion.append(it)
204259
}
205260

0 commit comments

Comments
 (0)