Skip to content

Commit 46e2950

Browse files
authored
support string or null Message.ID (#31)
1 parent 4f895b5 commit 46e2950

File tree

7 files changed

+101
-27
lines changed

7 files changed

+101
-27
lines changed

integrationtests/snapshots/rust/hover/struct-type.snap

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,7 @@ pub struct TestStruct {
66
}
77

88

9-
size = 32 (0x20), align = 0x8
9+
size = 32 (0x20), align = 0x8
10+
11+
12+
contain types with destructors (drop glue); doesn't have a destructor

integrationtests/tests/rust/diagnostics/diagnostics_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ func TestDiagnostics(t *testing.T) {
8888
t.Errorf("Expected unreachable code error but got: %s", result)
8989
}
9090

91-
common.SnapshotTest(t, "rust", "diagnostics", "unreachable", result)
91+
t.Skip("Flaky snapshot. If we have diagnostics then it's working, but the format changes often.")
92+
// common.SnapshotTest(t, "rust", "diagnostics", "unreachable", result)
9293
})
9394

9495
// Test file dependency: file A (helper.rs) provides a function,

integrationtests/tests/rust/hover/hover_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,6 @@ func TestHover(t *testing.T) {
131131
t.Fatalf("Failed to open %s: %v", tt.file, err)
132132
}
133133

134-
time.Sleep(3 * time.Second)
135-
136134
// Get hover info
137135
result, err := tools.GetHoverInfo(ctx, suite.Client, filePath, tt.line, tt.column)
138136
if err != nil {

integrationtests/tests/rust/internal/helpers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func GetTestSuite(t *testing.T) *common.TestSuite {
2121
Command: "rust-analyzer",
2222
Args: []string{},
2323
WorkspaceDir: filepath.Join(repoRoot, "integrationtests/workspaces/rust"),
24-
InitializeTimeMs: 5000,
24+
InitializeTimeMs: 3000,
2525
}
2626

2727
// Create a test suite

internal/lsp/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type Client struct {
2626
nextID atomic.Int32
2727

2828
// Response handlers
29-
handlers map[int32]chan *Message
29+
handlers map[string]chan *Message
3030
handlersMu sync.RWMutex
3131

3232
// Server request handlers
@@ -71,7 +71,7 @@ func NewClient(command string, args ...string) (*Client, error) {
7171
stdin: stdin,
7272
stdout: bufio.NewReader(stdout),
7373
stderr: stderr,
74-
handlers: make(map[int32]chan *Message),
74+
handlers: make(map[string]chan *Message),
7575
notificationHandlers: make(map[string]NotificationHandler),
7676
serverRequestHandlers: make(map[string]ServerRequestHandler),
7777
diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),

internal/lsp/protocol.go

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,78 @@ package lsp
22

33
import (
44
"encoding/json"
5+
"fmt"
6+
"strconv"
57
)
68

9+
// MessageID represents a JSON-RPC ID which can be a string, number, or null
10+
// per the JSON-RPC 2.0 specification
11+
type MessageID struct {
12+
Value any
13+
}
14+
15+
// MarshalJSON implements custom JSON marshaling for MessageID
16+
func (id *MessageID) MarshalJSON() ([]byte, error) {
17+
if id == nil || id.Value == nil {
18+
return []byte("null"), nil
19+
}
20+
return json.Marshal(id.Value)
21+
}
22+
23+
// UnmarshalJSON implements custom JSON unmarshaling for MessageID
24+
func (id *MessageID) UnmarshalJSON(data []byte) error {
25+
if string(data) == "null" {
26+
id.Value = nil
27+
return nil
28+
}
29+
30+
var value any
31+
if err := json.Unmarshal(data, &value); err != nil {
32+
return err
33+
}
34+
35+
// Convert float64 (default JSON number type) to int32 for backward compatibility
36+
if num, ok := value.(float64); ok {
37+
id.Value = int32(num)
38+
} else {
39+
id.Value = value
40+
}
41+
42+
return nil
43+
}
44+
45+
// String returns a string representation of the ID
46+
func (id *MessageID) String() string {
47+
if id == nil || id.Value == nil {
48+
return "<null>"
49+
}
50+
51+
switch v := id.Value.(type) {
52+
case int32:
53+
return strconv.FormatInt(int64(v), 10)
54+
case string:
55+
return v
56+
default:
57+
return fmt.Sprintf("%v", v)
58+
}
59+
}
60+
61+
// Equals checks if two MessageIDs are equal
62+
func (id *MessageID) Equals(other *MessageID) bool {
63+
if id == nil || other == nil {
64+
return id == other
65+
}
66+
if id.Value == nil || other.Value == nil {
67+
return id.Value == other.Value
68+
}
69+
70+
return fmt.Sprintf("%v", id.Value) == fmt.Sprintf("%v", other.Value)
71+
}
72+
773
// Message represents a JSON-RPC 2.0 message
874
type Message struct {
975
JSONRPC string `json:"jsonrpc"`
10-
ID int32 `json:"id,omitempty"`
76+
ID *MessageID `json:"id,omitempty"`
1177
Method string `json:"method,omitempty"`
1278
Params json.RawMessage `json:"params,omitempty"`
1379
Result json.RawMessage `json:"result,omitempty"`
@@ -20,15 +86,15 @@ type ResponseError struct {
2086
Message string `json:"message"`
2187
}
2288

23-
func NewRequest(id int32, method string, params any) (*Message, error) {
89+
func NewRequest(id any, method string, params any) (*Message, error) {
2490
paramsJSON, err := json.Marshal(params)
2591
if err != nil {
2692
return nil, err
2793
}
2894

2995
return &Message{
3096
JSONRPC: "2.0",
31-
ID: id,
97+
ID: &MessageID{Value: id},
3298
Method: method,
3399
Params: paramsJSON,
34100
}, nil
@@ -44,5 +110,7 @@ func NewNotification(method string, params any) (*Message, error) {
44110
JSONRPC: "2.0",
45111
Method: method,
46112
Params: paramsJSON,
113+
// Notifications don't have an ID by definition
114+
ID: nil,
47115
}, nil
48116
}

internal/lsp/transport.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
2424
}
2525

2626
// High-level operation log
27-
lspLogger.Debug("Sending message: method=%s id=%d", msg.Method, msg.ID)
27+
lspLogger.Debug("Sending message: method=%s id=%v", msg.Method, msg.ID)
2828

2929
// Wire protocol log (more detailed)
3030
wireLogger.Debug("-> Sending: %s", string(data))
@@ -83,12 +83,12 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
8383
}
8484

8585
// Log higher-level information about the message type
86-
if msg.Method != "" && msg.ID != 0 {
87-
lspLogger.Debug("Received request from server: method=%s id=%d", msg.Method, msg.ID)
86+
if msg.Method != "" && msg.ID != nil && msg.ID.Value != nil {
87+
lspLogger.Debug("Received request from server: method=%s id=%v", msg.Method, msg.ID)
8888
} else if msg.Method != "" {
8989
lspLogger.Debug("Received notification: method=%s", msg.Method)
90-
} else if msg.ID != 0 {
91-
lspLogger.Debug("Received response for ID: %d", msg.ID)
90+
} else if msg.ID != nil && msg.ID.Value != nil {
91+
lspLogger.Debug("Received response for ID: %v", msg.ID)
9292
}
9393

9494
return &msg, nil
@@ -109,7 +109,7 @@ func (c *Client) handleMessages() {
109109
}
110110

111111
// Handle server->client request (has both Method and ID)
112-
if msg.Method != "" && msg.ID != 0 {
112+
if msg.Method != "" && msg.ID != nil && msg.ID.Value != nil {
113113
response := &Message{
114114
JSONRPC: "2.0",
115115
ID: msg.ID,
@@ -121,7 +121,7 @@ func (c *Client) handleMessages() {
121121
c.serverHandlersMu.RUnlock()
122122

123123
if ok {
124-
lspLogger.Debug("Processing server request: method=%s id=%d", msg.Method, msg.ID)
124+
lspLogger.Debug("Processing server request: method=%s id=%v", msg.Method, msg.ID)
125125
result, err := handler(msg.Params)
126126
if err != nil {
127127
lspLogger.Error("Error handling server request %s: %v", msg.Method, err)
@@ -158,7 +158,7 @@ func (c *Client) handleMessages() {
158158
}
159159

160160
// Handle notification (has Method but no ID)
161-
if msg.Method != "" && msg.ID == 0 {
161+
if msg.Method != "" && (msg.ID == nil || msg.ID.Value == nil) {
162162
c.notificationMu.RLock()
163163
handler, ok := c.notificationHandlers[msg.Method]
164164
c.notificationMu.RUnlock()
@@ -173,17 +173,19 @@ func (c *Client) handleMessages() {
173173
}
174174

175175
// Handle response to our request (has ID but no Method)
176-
if msg.ID != 0 && msg.Method == "" {
176+
if msg.ID != nil && msg.ID.Value != nil && msg.Method == "" {
177+
// Convert ID to string for map lookup
178+
idStr := msg.ID.String()
177179
c.handlersMu.RLock()
178-
ch, ok := c.handlers[msg.ID]
180+
ch, ok := c.handlers[idStr]
179181
c.handlersMu.RUnlock()
180182

181183
if ok {
182-
lspLogger.Debug("Sending response for ID %d to handler", msg.ID)
184+
lspLogger.Debug("Sending response for ID %v to handler", msg.ID)
183185
ch <- msg
184186
close(ch)
185187
} else {
186-
lspLogger.Debug("No handler for response ID: %d", msg.ID)
188+
lspLogger.Debug("No handler for response ID: %v", msg.ID)
187189
}
188190
}
189191
}
@@ -193,7 +195,7 @@ func (c *Client) handleMessages() {
193195
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
194196
id := c.nextID.Add(1)
195197

196-
lspLogger.Debug("Making call: method=%s id=%d", method, id)
198+
lspLogger.Debug("Making call: method=%s id=%v", method, id)
197199

198200
msg, err := NewRequest(id, method, params)
199201
if err != nil {
@@ -202,13 +204,15 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
202204

203205
// Create response channel
204206
ch := make(chan *Message, 1)
207+
// Convert ID to string for map lookup
208+
idStr := msg.ID.String()
205209
c.handlersMu.Lock()
206-
c.handlers[id] = ch
210+
c.handlers[idStr] = ch
207211
c.handlersMu.Unlock()
208212

209213
defer func() {
210214
c.handlersMu.Lock()
211-
delete(c.handlers, id)
215+
delete(c.handlers, idStr)
212216
c.handlersMu.Unlock()
213217
}()
214218

@@ -217,12 +221,12 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
217221
return fmt.Errorf("failed to send request: %w", err)
218222
}
219223

220-
lspLogger.Debug("Waiting for response to request ID: %d", id)
224+
lspLogger.Debug("Waiting for response to request ID: %v", msg.ID)
221225

222226
// Wait for response
223227
resp := <-ch
224228

225-
lspLogger.Debug("Received response for request ID: %d", id)
229+
lspLogger.Debug("Received response for request ID: %v", msg.ID)
226230

227231
if resp.Error != nil {
228232
lspLogger.Error("Request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)

0 commit comments

Comments
 (0)