aboutsummaryrefslogtreecommitdiff
path: root/gopls/internal/lsp/lsprpc/middleware_test.go
blob: c528eae5c625a1937bb715e81877dea221f0b48a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package lsprpc_test

import (
	"context"
	"errors"
	"fmt"
	"testing"
	"time"

	. "golang.org/x/tools/gopls/internal/lsp/lsprpc"
	jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
)

// noopBinder is a BinderFunc that ignores both of its arguments and returns
// the zero-valued ConnectionOptions.
var noopBinder = BinderFunc(func(_ context.Context, _ *jsonrpc2_v2.Connection) jsonrpc2_v2.ConnectionOptions {
	var opts jsonrpc2_v2.ConnectionOptions
	return opts
})

// TestHandshakeMiddleware serves a listener wrapped in a Handshaker's
// middleware, dials it, and performs a client handshake. It verifies that:
//
//   - before the handshake neither side reports any peers;
//   - after the handshake each side reports exactly one peer, carrying the
//     other side's metadata and with local/remote IDs mirrored;
//   - after the connection is closed both peer lists become empty again.
func TestHandshakeMiddleware(t *testing.T) {
	// Server-side handshaker; the "answer" key is what the client is expected
	// to observe in its view of the server.
	sh := &Handshaker{
		Metadata: Metadata{
			"answer": 42,
		},
	}
	ctx := context.Background()
	env := new(TestEnv)
	defer env.Shutdown(t)
	l, _ := env.serve(ctx, t, sh.Middleware(noopBinder))
	conn := env.dial(ctx, t, l.Dialer(), noopBinder, false)
	// Client-side handshaker; the "question" key is what the server is
	// expected to observe in its view of the client.
	ch := &Handshaker{
		Metadata: Metadata{
			"question": 6 * 9,
		},
	}

	// check returns an error unless both handshakers report the expected
	// number of peers: one each when connected is true, none otherwise. When
	// connected, it additionally verifies the exchanged metadata and that the
	// two peers' IDs mirror each other.
	check := func(connected bool) error {
		clients := sh.Peers() // clients as seen by the server
		servers := ch.Peers() // servers as seen by the client
		want := 0
		if connected {
			want = 1
		}
		if got := len(clients); got != want {
			return fmt.Errorf("got %d clients on the server, want %d", got, want)
		}
		if got := len(servers); got != want {
			return fmt.Errorf("got %d servers on the client, want %d", got, want)
		}
		if !connected {
			return nil
		}
		client := clients[0]
		server := servers[0]
		if _, ok := client.Metadata["question"]; !ok {
			return errors.New("no client metadata")
		}
		if _, ok := server.Metadata["answer"]; !ok {
			return errors.New("no server metadata")
		}
		// The messages below name the fields that are actually compared
		// (RemoteID), so a failure points at the right struct member.
		if client.LocalID != server.RemoteID {
			return fmt.Errorf("client.LocalID == %d, server.RemoteID == %d", client.LocalID, server.RemoteID)
		}
		if client.RemoteID != server.LocalID {
			return fmt.Errorf("client.RemoteID == %d, server.LocalID == %d", client.RemoteID, server.LocalID)
		}
		return nil
	}

	if err := check(false); err != nil {
		t.Fatalf("before handshake: %v", err)
	}
	ch.ClientHandshake(ctx, conn)
	if err := check(true); err != nil {
		t.Fatalf("after handshake: %v", err)
	}
	conn.Close()
	// Wait for up to ~2s for connections to get cleaned up: poll four times,
	// sleeping 25ms, 100ms, 400ms and 1.6s before the successive checks, and
	// fail only if the last check still sees peers.
	delay := 25 * time.Millisecond
	for retries := 3; retries >= 0; retries-- {
		time.Sleep(delay)
		err := check(false)
		if err == nil {
			return
		}
		if retries == 0 {
			t.Fatalf("after closing connection: %v", err)
		}
		delay *= 4
	}
}