Skip to content

Commit 93756d3

Browse files
committed
feat(endpoints): add callee lookup for related classes #308
Introduce `lookupCallee` method to find methods called by a given method. This extends the functionality of `RelatedClassesProvider` and integrates it into the endpoint knowledge provider to include callees in the result set
1 parent a1d6c4e commit 93756d3

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

core/src/main/kotlin/cc/unitmesh/devti/provider/RelatedClassesProvider.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cc.unitmesh.devti.provider
22

33
import com.intellij.lang.Language
44
import com.intellij.lang.LanguageExtension
5+
import com.intellij.openapi.project.Project
56
import com.intellij.psi.PsiElement
67
import com.intellij.psi.PsiFile
78

@@ -27,6 +28,8 @@ interface RelatedClassesProvider {
2728
*/
2829
fun lookupIO(element: PsiElement): List<PsiElement>
2930

31+
fun lookupCallee(project: Project, element: PsiElement): List<PsiElement> = emptyList()
32+
3033
fun lookupIO(element: PsiFile): List<PsiElement>
3134

3235
companion object {

exts/ext-endpoints/src/233/main/kotlin/cc/unitmesh/endpoints/bridge/EndpointKnowledgeWebApiProvider.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ class EndpointKnowledgeWebApiProvider : KnowledgeWebApiProvider() {
3434
RelatedClassesProvider.provide(it.language)?.lookupIO(it)
3535
}.flatten()
3636

37-
val allElements = decls + relatedCode
37+
val callees = decls.mapNotNull {
38+
RelatedClassesProvider.provide(it.language)?.lookupCallee(project, it)
39+
}.flatten()
40+
41+
val allElements = decls + relatedCode + callees
3842
future.complete(allElements)
3943
}
4044
}

java/src/main/kotlin/cc/unitmesh/idea/provider/JavaRelatedClassesProvider.kt

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ package cc.unitmesh.idea.provider
33
import cc.unitmesh.devti.provider.RelatedClassesProvider
44
import cc.unitmesh.idea.context.JavaContextCollection
55
import cc.unitmesh.idea.service.JavaTypeUtil.resolveByType
6+
import cc.unitmesh.idea.util.JavaCallHelper.findCallees
67
import com.intellij.openapi.application.ApplicationManager
78
import com.intellij.openapi.application.runReadAction
9+
import com.intellij.openapi.project.Project
810
import com.intellij.openapi.roots.ProjectFileIndex
911
import com.intellij.psi.*
10-
import com.intellij.psi.util.*
12+
import com.intellij.psi.util.PsiTreeUtil
13+
import com.intellij.psi.util.PsiUtil
1114
import com.intellij.testIntegration.TestFinderHelper
1215

1316
class JavaRelatedClassesProvider : RelatedClassesProvider {
@@ -24,6 +27,13 @@ class JavaRelatedClassesProvider : RelatedClassesProvider {
2427
}
2528
}
2629

30+
override fun lookupCallee(project: Project, element: PsiElement): List<PsiElement> {
31+
return when (element) {
32+
is PsiMethod -> findCallees(project, element)
33+
else -> emptyList()
34+
}
35+
}
36+
2737
override fun lookupIO(element: PsiFile): List<PsiElement> {
2838
return when (element) {
2939
is PsiJavaFile -> findRelatedClasses(element.classes.first()) + lookupTestFile(element.classes.first())
@@ -57,7 +67,10 @@ class JavaRelatedClassesProvider : RelatedClassesProvider {
5767
val resolve = (it.type as PsiClassType).resolve() ?: return@mapNotNull null
5868
if (resolve.qualifiedName == qualifiedName) return@mapNotNull null
5969

60-
if (isJavaBuiltin(resolve.qualifiedName) == true || JavaContextCollection.isPopularFramework(resolve.qualifiedName) == true) {
70+
if (isJavaBuiltin(resolve.qualifiedName) == true || JavaContextCollection.isPopularFramework(
71+
resolve.qualifiedName
72+
) == true
73+
) {
6174
return@mapNotNull null
6275
}
6376

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package cc.unitmesh.idea.util
2+
3+
import com.intellij.openapi.progress.ProgressManager
4+
import com.intellij.openapi.progress.util.ProgressIndicatorBase
5+
import com.intellij.openapi.project.Project
6+
import com.intellij.psi.JavaRecursiveElementVisitor
7+
import com.intellij.psi.PsiMethod
8+
import com.intellij.psi.PsiMethodCallExpression
9+
import com.intellij.psi.search.ProjectScope
10+
import com.intellij.psi.search.searches.MethodReferencesSearch
11+
import com.intellij.psi.util.*
12+
13+
object JavaCallHelper {
14+
/**
15+
* Finds all the methods called by the given method.
16+
*
17+
* @param method the method for which callees need to be found
18+
* @return a list of PsiMethod objects representing the methods called by the given method
19+
*/
20+
fun findCallees(project: Project, method: PsiMethod): List<PsiMethod> {
21+
val calledMethods = mutableSetOf<PsiMethod>()
22+
method.accept(object : JavaRecursiveElementVisitor() {
23+
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
24+
super.visitMethodCallExpression(expression)
25+
calledMethods.add(expression.resolveMethod() ?: return)
26+
}
27+
})
28+
29+
return calledMethods
30+
.filter {
31+
val containingClass = it.containingClass ?: return@filter false
32+
if (!ProjectScope.getProjectScope(project).contains(containingClass.containingFile.virtualFile)) {
33+
return@filter false
34+
}
35+
36+
true
37+
}
38+
}
39+
40+
/**
41+
* Finds all the callers of a given method.
42+
*
43+
* @param method the method for which callers need to be found
44+
* @return a list of PsiMethod objects representing the callers of the given method
45+
*/
46+
fun findCallers(project: Project, method: PsiMethod): List<PsiMethod> {
47+
val callers: MutableList<PsiMethod> = ArrayList()
48+
49+
ProgressManager.getInstance().runProcess(Runnable {
50+
val references = MethodReferencesSearch.search(method, method.useScope, true).findAll()
51+
for (reference in references) {
52+
PsiTreeUtil.getParentOfType(reference.element, PsiMethod::class.java)?.let {
53+
callers.add(it)
54+
}
55+
}
56+
}, ProgressIndicatorBase())
57+
58+
return callers.distinct()
59+
}
60+
}

0 commit comments

Comments
 (0)