Skip to content

Add incoming_calls and outgoing_calls tools #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

---
Name: HelperFunction
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go
- Called By: AnotherConsumer
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • another_consumer.go
- Called By: ConsumerFunction
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

---
Name: FooBar
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go
- Called By: main
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

---
Name: ConsumerFunction
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go
- Calls: GetName
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go
- Calls: HelperFunction
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go
- Calls: Method
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go
- Calls: Println
Detail: fmt • print.go
- Calls: Fprintln
Detail: fmt • print.go
- Calls: Process
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go
- Calls: Printf
Detail: fmt • print.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

---
Name: main
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go
- Calls: FooBar
Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go
- Calls: Println
Detail: fmt • print.go
- Calls: Println
Detail: fmt • print.go
- Calls: Fprintln
Detail: fmt • print.go
61 changes: 61 additions & 0 deletions integrationtests/tests/go/call_hierarchy/incoming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package callhierarchy_test

import (
"context"
"strings"
"testing"
"time"

"github.com/isaacphi/mcp-language-server/integrationtests/tests/common"
"github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal"
"github.com/isaacphi/mcp-language-server/internal/tools"
)

func TestIncomingCalls(t *testing.T) {
suite := internal.GetTestSuite(t)

ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second)
defer cancel()

tests := []struct {
name string
symbolName string
depth int
expectedText string
snapshotName string
}{
{
name: "Function with calls in same file",
symbolName: "FooBar",
expectedText: ": main",
depth: 5,
snapshotName: "incoming-same-file",
},
{
name: "Function with calls in other file",
symbolName: "HelperFunction",
depth: 5,
expectedText: "ConsumerFunction",
snapshotName: "incoming-other-file",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Call the GetIncomingCalls tool
result, err := tools.GetIncomingCalls(ctx, suite.Client, tc.symbolName, tc.depth)
if err != nil {
t.Fatalf("Failed to find incoming calls: %v", err)
}

// Check that the result contains relevant information
if !strings.Contains(result, tc.expectedText) {
t.Errorf("Incoming calls do not contain expected text: %s", tc.expectedText)
}

// Use snapshot testing to verify exact output
common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result)
})
}

}
61 changes: 61 additions & 0 deletions integrationtests/tests/go/call_hierarchy/outgoing_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package callhierarchy_test

import (
"context"
"strings"
"testing"
"time"

"github.com/isaacphi/mcp-language-server/integrationtests/tests/common"
"github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal"
"github.com/isaacphi/mcp-language-server/internal/tools"
)

func TestOutgoingCalls(t *testing.T) {
suite := internal.GetTestSuite(t)

ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second)
defer cancel()

tests := []struct {
name string
symbolName string
depth int
expectedText string
snapshotName string
}{
{
name: "Function with calls in other file",
symbolName: "ConsumerFunction",
depth: 2,
expectedText: "HelperFunction",
snapshotName: "outgoing-other-file",
},
{
name: "Function with calls in same file",
symbolName: "main",
depth: 2,
expectedText: "FooBar",
snapshotName: "outgoing-same-file",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Call the GetOutgoingCalls tool
result, err := tools.GetOutgoingCalls(ctx, suite.Client, tc.symbolName, tc.depth)
if err != nil {
t.Fatalf("Failed to find outgoing calls: %v", err)
}

// Check that the result contains relevant information
if !strings.Contains(result, tc.expectedText) {
t.Errorf("Outgoing calls do not contain expected text: %s", tc.expectedText)
}

// Use snapshot testing to verify exact output
common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result)
})
}

}
176 changes: 176 additions & 0 deletions internal/tools/call_hierarchy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package tools

import (
"context"
"fmt"
"sort"
"strings"

"github.com/isaacphi/mcp-language-server/internal/lsp"
"github.com/isaacphi/mcp-language-server/internal/protocol"
)

func GetIncomingCalls(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) {
return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseIncomingCalls)
}

func GetOutgoingCalls(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) {
return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseOutgoingCalls)
}

func getCallHierarchy(
ctx context.Context, client *lsp.Client, symbolName string, maxDepth int,
recurse func(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int),
) (string, error) {
// First get the symbol location like ReadDefinition does
symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{
Query: symbolName,
})
if err != nil {
return "", fmt.Errorf("failed to fetch symbol: %v", err)
}

results, err := symbolResult.Results()
if err != nil {
return "", fmt.Errorf("failed to parse results: %v", err)
}

// After this point we just return errors instead of erroring out
var result strings.Builder

for _, symbol := range results {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this tool it might be better to accept an exact location instead of a name, sort of like how hover works. The reason I say this is because "definition" is currently pretty imprecise, especailly in larger projects. If it is based on a location then it could be called after a first call to find out where the desired symbol is.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe?

In my initial experiments, I'm not sure adding these calls is actually useful anyways. For some projects the output is really large because the LLM isn't setting depth. I noticed that without this it would use references to basically do the same thing.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean you're not sure if these call_hierarchy tools are useful?

In my experience, it already works really well if you instruct the LLM to recursively look up definitions that are relevant to a specific code flow you're interested in. This way it only looks things up that are relevant.

// Handle different matching strategies based on the search term
if strings.Contains(symbolName, ".") {
// For qualified names like "Type.Method", check for various matches
parts := strings.Split(symbolName, ".")
methodName := parts[len(parts)-1]

// Try matching the unqualified method name for languages that don't use qualified names in symbols
if symbol.GetName() != symbolName && symbol.GetName() != methodName {
continue
}
} else if symbol.GetName() != symbolName {
// For unqualified names, exact match only
continue
}

result.WriteString("\n---\n")

// Get the location of the symbol
loc := symbol.GetLocation()

chParams := protocol.CallHierarchyPrepareParams{
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: loc.URI,
},
Position: loc.Range.Start,
},
}
items, err := client.PrepareCallHierarchy(ctx, chParams)
if err != nil {
result.WriteString(fmt.Sprintf("%s: Error: %v\n", symbol.GetName(), err))
continue
}

for _, item := range items {
recurse(ctx, client, item, &result, 0, maxDepth)
}
}

return result.String(), nil
}

func recurseIncomingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) {

var prefix string
if depth != 0 {
prefix = strings.Repeat(" ", (depth-1)*2+2)

result.WriteString(strings.Repeat(" ", (depth-1)*2))
result.WriteRune('-')
result.WriteString(" Called By: ")
} else {
result.WriteString("Name: ")
}

result.WriteString(item.Name)
result.WriteRune('\n')

result.WriteString(prefix)
result.WriteString("Detail: ")
result.WriteString(item.Detail)
result.WriteRune('\n')

if depth >= maxDepth {
return
}

calls, err := client.IncomingCalls(ctx, protocol.CallHierarchyIncomingCallsParams{
Item: item,
})

if err != nil {
result.WriteString(prefix)
result.WriteString("Error: ")
result.WriteString(err.Error())
result.WriteRune('\n')
return
}

// ensure output is deterministic for tests
sort.Slice(calls, func(i, j int) bool {
return calls[i].From.Name < calls[j].From.Name
})

for _, call := range calls {
recurseIncomingCalls(ctx, client, call.From, result, depth+1, maxDepth)
}
}

func recurseOutgoingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) {

var prefix string
if depth != 0 {
prefix = strings.Repeat(" ", (depth-1)*2+2)

result.WriteString(strings.Repeat(" ", (depth-1)*2))
result.WriteRune('-')
result.WriteString(" Calls: ")
} else {
result.WriteString("Name: ")
}

result.WriteString(item.Name)
result.WriteRune('\n')

result.WriteString(prefix)
result.WriteString("Detail: ")
result.WriteString(item.Detail)
result.WriteRune('\n')

if depth >= maxDepth {
return
}

calls, err := client.OutgoingCalls(ctx, protocol.CallHierarchyOutgoingCallsParams{
Item: item,
})

if err != nil {
result.WriteString(prefix)
result.WriteString("Error: ")
result.WriteString(err.Error())
result.WriteRune('\n')
return
}

// ensure output is deterministic for tests
sort.Slice(calls, func(i, j int) bool {
return calls[i].To.Name < calls[j].To.Name
})

for _, call := range calls {
recurseOutgoingCalls(ctx, client, call.To, result, depth+1, maxDepth)
}
}
Loading
Loading