Skip to content

Commit 3e0c992

Browse files
committed
feat(rust): add relevant classes to TestFileContext
This commit adds the ability to include relevant classes in the TestFileContext of the RustTestService. The relevant classes are looked up based on the element passed to the lookupRelevantClass function. If the element is a RsFunction, the return type and input parameters are extracted and resolved to obtain the corresponding RustClassContext. These relevant classes are then included in the TestFileContext. Additionally, a new private function resolveReferenceTypes is added to resolve the reference types of RsTypeReference.
1 parent 8e92028 commit 3e0c992

File tree

4 files changed

+39
-12
lines changed

4 files changed

+39
-12
lines changed

rust/src/main/kotlin/cc/unitmesh/rust/context/RustMethodContextBuilder.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import cc.unitmesh.devti.context.MethodContext
44
import cc.unitmesh.devti.context.builder.MethodContextBuilder
55
import com.intellij.openapi.application.runReadAction
66
import com.intellij.psi.PsiElement
7+
import com.intellij.psi.util.PsiTreeUtil
78
import org.rust.ide.presentation.presentationInfo
89
import org.rust.lang.core.psi.*
910

@@ -20,15 +21,21 @@ class RustMethodContextBuilder : MethodContextBuilder {
2021
val language = psiElement.language.displayName
2122

2223
val signature = psiElement.presentationInfo?.signatureText
24+
val paramsName = psiElement.valueParameterList?.valueParameterList?.map {
25+
it.text
26+
} ?: emptyList()
27+
28+
val enclosingClass = PsiTreeUtil.getParentOfType(psiElement, RsImplItem::class.java)
29+
2330
return MethodContext(
2431
psiElement,
2532
text,
2633
psiElement.name,
2734
signature.toString(),
28-
null,
35+
enclosingClass,
2936
language,
3037
returnType,
31-
emptyList(),
38+
paramsName,
3239
includeClassContext,
3340
emptyList()
3441
)

rust/src/main/kotlin/cc/unitmesh/rust/provider/RustTestService.kt

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cc.unitmesh.rust.provider
33
import cc.unitmesh.devti.context.ClassContext
44
import cc.unitmesh.devti.provider.WriteTestService
55
import cc.unitmesh.devti.provider.context.TestFileContext
6+
import cc.unitmesh.rust.context.RustClassContextBuilder
67
import cc.unitmesh.rust.context.RustMethodContextBuilder
78
import com.intellij.execution.configurations.RunProfile
89
import com.intellij.openapi.application.runReadAction
@@ -13,6 +14,7 @@ import com.intellij.psi.util.PsiTreeUtil
1314
import org.rust.cargo.runconfig.command.CargoCommandConfiguration
1415
import org.rust.lang.RsLanguage
1516
import org.rust.lang.core.psi.RsFunction
17+
import org.rust.lang.core.psi.RsTypeReference
1618
import org.rust.lang.core.psi.RsUseItem
1719

1820
class RustTestService : WriteTestService() {
@@ -37,10 +39,12 @@ class RustTestService : WriteTestService() {
3739
it.text
3840
}
3941

42+
val relevantClasses = lookupRelevantClass(project, element)
43+
4044
return TestFileContext(
4145
false,
4246
sourceFile.virtualFile,
43-
listOf(),
47+
relevantClasses,
4448
"",
4549
RsLanguage,
4650
currentObject,
@@ -49,7 +53,30 @@ class RustTestService : WriteTestService() {
4953
}
5054

5155
override fun lookupRelevantClass(project: Project, element: PsiElement): List<ClassContext> {
56+
when (element) {
57+
is RsFunction -> {
58+
val returnType = element.retType?.typeReference
59+
val input = element.valueParameterList?.valueParameterList?.map {
60+
it.typeReference
61+
} ?: emptyList()
62+
63+
val refs = (listOf(returnType) + input).filterNotNull()
64+
val types = resolveReferenceTypes(project, refs)
65+
66+
return types.mapNotNull {
67+
RustClassContextBuilder().getClassContext(it, false)
68+
}
69+
}
70+
}
71+
5272
return listOf()
5373
}
5474

75+
private fun resolveReferenceTypes(project: Project, rsTypeReferences: List<RsTypeReference>): List<PsiElement> {
76+
val mapNotNull = rsTypeReferences.mapNotNull {
77+
it.reference?.resolve()
78+
}
79+
80+
return mapNotNull
81+
}
5582
}

src/main/kotlin/cc/unitmesh/devti/intentions/action/task/TestCodeGenTask.kt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import com.intellij.openapi.progress.Task
2222
import com.intellij.openapi.project.DumbService
2323
import com.intellij.openapi.project.Project
2424
import com.intellij.openapi.vfs.VirtualFile
25-
import kotlinx.coroutines.InternalCoroutinesApi
2625
import kotlinx.coroutines.flow.*
2726
import kotlinx.coroutines.runBlocking
2827

@@ -132,7 +131,6 @@ class TestCodeGenTask(val request: TestCodeGenRequest) :
132131
}
133132
}
134133

135-
@OptIn(InternalCoroutinesApi::class)
136134
private suspend fun writeTestToFile(
137135
project: Project,
138136
flow: Flow<String>,
@@ -148,7 +146,8 @@ class TestCodeGenTask(val request: TestCodeGenRequest) :
148146
val modifier = CodeModifierProvider().modifier(context.language)
149147
?: throw IllegalStateException("Unsupported language: ${context.language}")
150148

151-
parseCodeFromString(suggestion.toString()).forEach {
149+
val codeBlocks = parseCodeFromString(suggestion.toString())
150+
codeBlocks.forEach {
152151
modifier.insertTestCode(context.outputFile, project, it)
153152
}
154153
}

src/main/kotlin/cc/unitmesh/devti/util/parser/Markdown.kt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ fun parseCodeFromString(markdown: String): List<String> {
1919
node.accept(visitor)
2020

2121
if (visitor.code.isEmpty()) {
22-
// TODO: we need to add multiple code blocks support
23-
val isJavaMethod = markdown.contains("public ") || markdown.contains("private ") || markdown.contains("protected ")
24-
if (isJavaMethod) {
25-
return listOf(markdown)
26-
}
27-
2822
return listOf(markdown)
2923
}
3024

0 commit comments

Comments
 (0)