Skip to content

support string or null Message.ID #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion integrationtests/snapshots/rust/hover/struct-type.snap
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ pub struct TestStruct {
}


size = 32 (0x20), align = 0x8
size = 32 (0x20), align = 0x8


contain types with destructors (drop glue); doesn't have a destructor
3 changes: 2 additions & 1 deletion integrationtests/tests/rust/diagnostics/diagnostics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ func TestDiagnostics(t *testing.T) {
t.Errorf("Expected unreachable code error but got: %s", result)
}

common.SnapshotTest(t, "rust", "diagnostics", "unreachable", result)
t.Skip("Flaky snapshot. If we have diagnostics then it's working, but the format changes often.")
// common.SnapshotTest(t, "rust", "diagnostics", "unreachable", result)
})

// Test file dependency: file A (helper.rs) provides a function,
Expand Down
2 changes: 0 additions & 2 deletions integrationtests/tests/rust/hover/hover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ func TestHover(t *testing.T) {
t.Fatalf("Failed to open %s: %v", tt.file, err)
}

time.Sleep(3 * time.Second)

// Get hover info
result, err := tools.GetHoverInfo(ctx, suite.Client, filePath, tt.line, tt.column)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/tests/rust/internal/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func GetTestSuite(t *testing.T) *common.TestSuite {
Command: "rust-analyzer",
Args: []string{},
WorkspaceDir: filepath.Join(repoRoot, "integrationtests/workspaces/rust"),
InitializeTimeMs: 5000,
InitializeTimeMs: 3000,
}

// Create a test suite
Expand Down
4 changes: 2 additions & 2 deletions internal/lsp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Client struct {
nextID atomic.Int32

// Response handlers
handlers map[int32]chan *Message
handlers map[string]chan *Message
handlersMu sync.RWMutex

// Server request handlers
Expand Down Expand Up @@ -71,7 +71,7 @@ func NewClient(command string, args ...string) (*Client, error) {
stdin: stdin,
stdout: bufio.NewReader(stdout),
stderr: stderr,
handlers: make(map[int32]chan *Message),
handlers: make(map[string]chan *Message),
notificationHandlers: make(map[string]NotificationHandler),
serverRequestHandlers: make(map[string]ServerRequestHandler),
diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),
Expand Down
74 changes: 71 additions & 3 deletions internal/lsp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,78 @@ package lsp

import (
"encoding/json"
"fmt"
"strconv"
)

// MessageID represents a JSON-RPC ID which can be a string, number, or null
// per the JSON-RPC 2.0 specification
type MessageID struct {
Value any
}

// MarshalJSON implements custom JSON marshaling for MessageID
func (id *MessageID) MarshalJSON() ([]byte, error) {
if id == nil || id.Value == nil {
return []byte("null"), nil
}
return json.Marshal(id.Value)
}

// UnmarshalJSON implements custom JSON unmarshaling for MessageID
func (id *MessageID) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
id.Value = nil
return nil
}

var value any
if err := json.Unmarshal(data, &value); err != nil {
return err
}

// Convert float64 (default JSON number type) to int32 for backward compatibility
if num, ok := value.(float64); ok {
id.Value = int32(num)
} else {
id.Value = value
}

return nil
}

// String returns a string representation of the ID
func (id *MessageID) String() string {
if id == nil || id.Value == nil {
return "<null>"
}

switch v := id.Value.(type) {
case int32:
return strconv.FormatInt(int64(v), 10)
case string:
return v
default:
return fmt.Sprintf("%v", v)
}
}

// Equals checks if two MessageIDs are equal
func (id *MessageID) Equals(other *MessageID) bool {
if id == nil || other == nil {
return id == other
}
if id.Value == nil || other.Value == nil {
return id.Value == other.Value
}

return fmt.Sprintf("%v", id.Value) == fmt.Sprintf("%v", other.Value)
}

// Message represents a JSON-RPC 2.0 message
type Message struct {
JSONRPC string `json:"jsonrpc"`
ID int32 `json:"id,omitempty"`
ID *MessageID `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Expand All @@ -20,15 +86,15 @@ type ResponseError struct {
Message string `json:"message"`
}

func NewRequest(id int32, method string, params any) (*Message, error) {
func NewRequest(id any, method string, params any) (*Message, error) {
paramsJSON, err := json.Marshal(params)
if err != nil {
return nil, err
}

return &Message{
JSONRPC: "2.0",
ID: id,
ID: &MessageID{Value: id},
Method: method,
Params: paramsJSON,
}, nil
Expand All @@ -44,5 +110,7 @@ func NewNotification(method string, params any) (*Message, error) {
JSONRPC: "2.0",
Method: method,
Params: paramsJSON,
// Notifications don't have an ID by definition
ID: nil,
}, nil
}
38 changes: 21 additions & 17 deletions internal/lsp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
}

// High-level operation log
lspLogger.Debug("Sending message: method=%s id=%d", msg.Method, msg.ID)
lspLogger.Debug("Sending message: method=%s id=%v", msg.Method, msg.ID)

// Wire protocol log (more detailed)
wireLogger.Debug("-> Sending: %s", string(data))
Expand Down Expand Up @@ -83,12 +83,12 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}

// Log higher-level information about the message type
if msg.Method != "" && msg.ID != 0 {
lspLogger.Debug("Received request from server: method=%s id=%d", msg.Method, msg.ID)
if msg.Method != "" && msg.ID != nil && msg.ID.Value != nil {
lspLogger.Debug("Received request from server: method=%s id=%v", msg.Method, msg.ID)
} else if msg.Method != "" {
lspLogger.Debug("Received notification: method=%s", msg.Method)
} else if msg.ID != 0 {
lspLogger.Debug("Received response for ID: %d", msg.ID)
} else if msg.ID != nil && msg.ID.Value != nil {
lspLogger.Debug("Received response for ID: %v", msg.ID)
}

return &msg, nil
Expand All @@ -109,7 +109,7 @@ func (c *Client) handleMessages() {
}

// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
if msg.Method != "" && msg.ID != nil && msg.ID.Value != nil {
response := &Message{
JSONRPC: "2.0",
ID: msg.ID,
Expand All @@ -121,7 +121,7 @@ func (c *Client) handleMessages() {
c.serverHandlersMu.RUnlock()

if ok {
lspLogger.Debug("Processing server request: method=%s id=%d", msg.Method, msg.ID)
lspLogger.Debug("Processing server request: method=%s id=%v", msg.Method, msg.ID)
result, err := handler(msg.Params)
if err != nil {
lspLogger.Error("Error handling server request %s: %v", msg.Method, err)
Expand Down Expand Up @@ -158,7 +158,7 @@ func (c *Client) handleMessages() {
}

// Handle notification (has Method but no ID)
if msg.Method != "" && msg.ID == 0 {
if msg.Method != "" && (msg.ID == nil || msg.ID.Value == nil) {
c.notificationMu.RLock()
handler, ok := c.notificationHandlers[msg.Method]
c.notificationMu.RUnlock()
Expand All @@ -173,17 +173,19 @@ func (c *Client) handleMessages() {
}

// Handle response to our request (has ID but no Method)
if msg.ID != 0 && msg.Method == "" {
if msg.ID != nil && msg.ID.Value != nil && msg.Method == "" {
// Convert ID to string for map lookup
idStr := msg.ID.String()
c.handlersMu.RLock()
ch, ok := c.handlers[msg.ID]
ch, ok := c.handlers[idStr]
c.handlersMu.RUnlock()

if ok {
lspLogger.Debug("Sending response for ID %d to handler", msg.ID)
lspLogger.Debug("Sending response for ID %v to handler", msg.ID)
ch <- msg
close(ch)
} else {
lspLogger.Debug("No handler for response ID: %d", msg.ID)
lspLogger.Debug("No handler for response ID: %v", msg.ID)
}
}
}
Expand All @@ -193,7 +195,7 @@ func (c *Client) handleMessages() {
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
id := c.nextID.Add(1)

lspLogger.Debug("Making call: method=%s id=%d", method, id)
lspLogger.Debug("Making call: method=%s id=%v", method, id)

msg, err := NewRequest(id, method, params)
if err != nil {
Expand All @@ -202,13 +204,15 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any

// Create response channel
ch := make(chan *Message, 1)
// Convert ID to string for map lookup
idStr := msg.ID.String()
c.handlersMu.Lock()
c.handlers[id] = ch
c.handlers[idStr] = ch
c.handlersMu.Unlock()

defer func() {
c.handlersMu.Lock()
delete(c.handlers, id)
delete(c.handlers, idStr)
c.handlersMu.Unlock()
}()

Expand All @@ -217,12 +221,12 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
return fmt.Errorf("failed to send request: %w", err)
}

lspLogger.Debug("Waiting for response to request ID: %d", id)
lspLogger.Debug("Waiting for response to request ID: %v", msg.ID)

// Wait for response
resp := <-ch

lspLogger.Debug("Received response for request ID: %d", id)
lspLogger.Debug("Received response for request ID: %v", msg.ID)

if resp.Error != nil {
lspLogger.Error("Request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
Expand Down
Loading