Skip to content

Commit 5aa76c2

Browse files
committed
feat(rust): add RustClassContextBuilderTest and modify RustClassContextBuilder
- Add RustClassContextBuilderTest to test the functionality of the RustClassContextBuilder class. - Modify RustClassContextBuilder to include functions in the ClassContext and remove unnecessary code. - The test case in RustClassContextBuilderTest checks if the struct is formatted correctly and returns the expected result.
1 parent 5c7ad6f commit 5aa76c2

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

prompts/templates/Test.kt.vm

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
// Here is the Test template code.
2-
// for test intellij plugin
1+
// Here is the Test template code, please copy and paste to the test file
2+
// and replace the class name and the test method name
3+
//
34
//import com.intellij.testFramework.LightPlatformTestCase
45
//class /*TestClassName*/Test : LightPlatformTestCase() {
56
// // the Intellij test should start with test

rust/src/main/kotlin/cc/unitmesh/rust/context/RustClassContextBuilder.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package cc.unitmesh.rust.context
33
import cc.unitmesh.devti.context.ClassContext
44
import cc.unitmesh.devti.context.builder.ClassContextBuilder
55
import com.intellij.psi.PsiElement
6+
import com.intellij.psi.util.PsiTreeUtil
67
import org.rust.lang.core.psi.RsEnumItem
8+
import org.rust.lang.core.psi.RsFunction
9+
import org.rust.lang.core.psi.RsImplItem
710
import org.rust.lang.core.psi.RsStructItem
811
import org.rust.lang.core.psi.ext.RsStructOrEnumItemElement
912
import org.rust.lang.core.psi.ext.fields
@@ -15,13 +18,17 @@ class RustClassContextBuilder : ClassContextBuilder {
1518
when (psiElement) {
1619
is RsStructItem -> {
1720
val fields: List<PsiElement> = psiElement.fields.map {
18-
it.typeReference?.reference?.resolve() ?: it
21+
it
1922
}
23+
val impls = PsiTreeUtil.getChildrenOfTypeAsList(psiElement.containingFile, RsImplItem::class.java)
24+
val functions = impls.filter { it.name == psiElement.name }
25+
.flatMap { PsiTreeUtil.getChildrenOfTypeAsList(it, RsFunction::class.java) }
26+
2027
return ClassContext(
2128
psiElement,
2229
psiElement.text,
2330
psiElement.name,
24-
emptyList(),
31+
functions,
2532
fields,
2633
emptyList(),
2734
emptyList(),
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package cc.unitmesh.rust.context;
2+
3+
import cc.unitmesh.devti.context.ClassContext
4+
import cc.unitmesh.devti.context.builder.ClassContextBuilder
5+
import com.intellij.psi.PsiElement
6+
import com.intellij.psi.PsiFileFactory
7+
import com.intellij.psi.util.PsiTreeUtil
8+
import com.intellij.testFramework.fixtures.BasePlatformTestCase
9+
import com.jetbrains.cidr.lang.psi.OCDeclaration
10+
import org.junit.Assert.assertEquals
11+
import org.junit.Assert.assertNull
12+
import org.junit.Test
13+
import org.rust.lang.core.psi.RsEnumItem
14+
import org.rust.lang.core.psi.RsStructItem
15+
import org.rust.lang.core.psi.ext.RsStructOrEnumItemElement
16+
import org.rust.lang.core.psi.ext.fields
17+
18+
class RustClassContextBuilderTest: BasePlatformTestCase() {
19+
20+
fun testShouldFormatStruct() {
21+
// given
22+
val code = myFixture.configureByText("test.rs", """
23+
use crate::embedding::Embedding;
24+
use crate::similarity::{CosineSimilarity, RelevanceScore};
25+
26+
#[derive(Debug, Clone)]
27+
pub struct Entry {
28+
id: String,
29+
embedding: Embedding,
30+
embedded: Document,
31+
}
32+
33+
impl Entry {
34+
fn new(id: String, embedding: Embedding, embedded: Document) -> Self {
35+
Entry { id, embedding, embedded }
36+
}
37+
}
38+
""".trimIndent())
39+
40+
// when
41+
val decl = PsiTreeUtil.getChildrenOfTypeAsList(code, RsStructItem::class.java).first()
42+
43+
// then
44+
val result = RustClassContextBuilder().getClassContext(decl, false)!!
45+
assertEquals("Entry", result.name)
46+
assertEquals(result.format(), """
47+
'package: Entry
48+
class Entry {
49+
50+
51+
}
52+
""".trimIndent())
53+
}
54+
}

0 commit comments

Comments
 (0)