Skip to content

Commit 7dc968c

Browse files
committed
Add DERP meshing to arbitrary addresses
1 parent 1883430 commit 7dc968c

File tree

3 files changed

+270
-32
lines changed

3 files changed

+270
-32
lines changed

enterprise/derpmesh/derpmesh.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package derpmesh
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"golang.org/x/xerrors"
8+
"tailscale.com/derp"
9+
"tailscale.com/derp/derphttp"
10+
"tailscale.com/types/key"
11+
12+
"github.com/coder/coder/tailnet"
13+
14+
"cdr.dev/slog"
15+
)
16+
17+
func New(logger slog.Logger, server *derp.Server) *Mesh {
18+
return &Mesh{
19+
logger: logger,
20+
server: server,
21+
ctx: context.Background(),
22+
closed: make(chan struct{}),
23+
active: make(map[string]context.CancelFunc),
24+
}
25+
}
26+
27+
type Mesh struct {
28+
logger slog.Logger
29+
server *derp.Server
30+
ctx context.Context
31+
32+
mutex sync.Mutex
33+
closed chan struct{}
34+
active map[string]context.CancelFunc
35+
}
36+
37+
// SetAddresses performs a diff of the incoming addresses and adds
38+
// or removes DERP clients from the mesh.
39+
func (m *Mesh) SetAddresses(addresses []string) {
40+
total := make(map[string]struct{}, 0)
41+
for _, address := range addresses {
42+
total[address] = struct{}{}
43+
added, err := m.addAddress(address)
44+
if err != nil {
45+
m.logger.Error(m.ctx, "failed to add address", slog.F("address", address), slog.Error(err))
46+
continue
47+
}
48+
if added {
49+
m.logger.Debug(m.ctx, "added mesh address", slog.F("address", address))
50+
}
51+
}
52+
53+
m.mutex.Lock()
54+
for address := range m.active {
55+
_, found := total[address]
56+
if found {
57+
continue
58+
}
59+
removed := m.removeAddress(address)
60+
if removed {
61+
m.logger.Debug(m.ctx, "removed mesh address", slog.F("address", address))
62+
}
63+
}
64+
m.mutex.Unlock()
65+
}
66+
67+
// addAddress begins meshing with a new address.
68+
// It's expected that this is a full HTTP address with a path.
69+
// e.g. http://127.0.0.1:8080/derp
70+
func (m *Mesh) addAddress(address string) (bool, error) {
71+
m.mutex.Lock()
72+
defer m.mutex.Unlock()
73+
_, isActive := m.active[address]
74+
if isActive {
75+
return false, nil
76+
}
77+
client, err := derphttp.NewClient(m.server.PrivateKey(), address, tailnet.Logger(m.logger))
78+
if err != nil {
79+
return false, xerrors.Errorf("create derp client: %w", err)
80+
}
81+
client.MeshKey = m.server.MeshKey()
82+
ctx, cancelFunc := context.WithCancel(m.ctx)
83+
closed := make(chan struct{})
84+
closeFunc := func() {
85+
cancelFunc()
86+
_ = client.Close()
87+
<-closed
88+
}
89+
m.active[address] = closeFunc
90+
go func() {
91+
defer close(closed)
92+
client.RunWatchConnectionLoop(ctx, m.server.PublicKey(), tailnet.Logger(m.logger), func(np key.NodePublic) {
93+
m.server.AddPacketForwarder(np, client)
94+
}, func(np key.NodePublic) {
95+
m.server.RemovePacketForwarder(np, client)
96+
})
97+
}()
98+
return true, nil
99+
}
100+
101+
// removeAddress stops meshing with a given address.
102+
func (m *Mesh) removeAddress(address string) bool {
103+
cancelFunc, isActive := m.active[address]
104+
if isActive {
105+
cancelFunc()
106+
}
107+
return isActive
108+
}
109+
110+
// Close ends all active meshes with the DERP server.
111+
func (m *Mesh) Close() error {
112+
m.mutex.Lock()
113+
defer m.mutex.Unlock()
114+
select {
115+
case <-m.closed:
116+
return nil
117+
default:
118+
}
119+
close(m.closed)
120+
for _, cancelFunc := range m.active {
121+
cancelFunc()
122+
}
123+
return nil
124+
}

enterprise/derpmesh/derpmesh_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package derpmesh_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"go.uber.org/goleak"
13+
"tailscale.com/derp"
14+
"tailscale.com/derp/derphttp"
15+
"tailscale.com/types/key"
16+
17+
"cdr.dev/slog"
18+
"cdr.dev/slog/sloggers/slogtest"
19+
"github.com/coder/coder/enterprise/derpmesh"
20+
"github.com/coder/coder/tailnet"
21+
)
22+
23+
func TestMain(m *testing.M) {
24+
goleak.VerifyTestMain(m)
25+
}
26+
27+
func TestDERPMesh(t *testing.T) {
28+
t.Parallel()
29+
t.Run("ExchangeMessages", func(t *testing.T) {
30+
// This tests messages passing through multiple DERP servers.
31+
t.Parallel()
32+
firstServer, firstServerURL := startDERP(t)
33+
defer firstServer.Close()
34+
secondServer, secondServerURL := startDERP(t)
35+
firstMesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), firstServer)
36+
firstMesh.SetAddresses([]string{secondServerURL})
37+
secondMesh := derpmesh.New(slogtest.Make(t, nil).Named("second").Leveled(slog.LevelDebug), secondServer)
38+
secondMesh.SetAddresses([]string{firstServerURL})
39+
defer firstMesh.Close()
40+
defer secondMesh.Close()
41+
42+
first := key.NewNode()
43+
second := key.NewNode()
44+
firstClient, err := derphttp.NewClient(first, secondServerURL, tailnet.Logger(slogtest.Make(t, nil)))
45+
require.NoError(t, err)
46+
secondClient, err := derphttp.NewClient(second, firstServerURL, tailnet.Logger(slogtest.Make(t, nil)))
47+
require.NoError(t, err)
48+
err = secondClient.Connect(context.Background())
49+
require.NoError(t, err)
50+
51+
sent := []byte("hello world")
52+
err = firstClient.Send(second.Public(), sent)
53+
require.NoError(t, err)
54+
55+
got := recvData(t, secondClient)
56+
require.Equal(t, sent, got)
57+
})
58+
t.Run("RemoveAddress", func(t *testing.T) {
59+
// This tests messages passing through multiple DERP servers.
60+
t.Parallel()
61+
server, serverURL := startDERP(t)
62+
mesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), server)
63+
mesh.SetAddresses([]string{"http://fake.com"})
64+
// This should trigger a removal...
65+
mesh.SetAddresses([]string{})
66+
defer mesh.Close()
67+
68+
first := key.NewNode()
69+
second := key.NewNode()
70+
firstClient, err := derphttp.NewClient(first, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
71+
require.NoError(t, err)
72+
secondClient, err := derphttp.NewClient(second, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
73+
require.NoError(t, err)
74+
err = secondClient.Connect(context.Background())
75+
require.NoError(t, err)
76+
sent := []byte("hello world")
77+
err = firstClient.Send(second.Public(), sent)
78+
require.NoError(t, err)
79+
got := recvData(t, secondClient)
80+
require.Equal(t, sent, got)
81+
})
82+
t.Run("TwentyMeshes", func(t *testing.T) {
83+
t.Parallel()
84+
meshes := make([]*derpmesh.Mesh, 0, 20)
85+
serverURLs := make([]string, 0, 20)
86+
for i := 0; i < 20; i++ {
87+
server, url := startDERP(t)
88+
mesh := derpmesh.New(slogtest.Make(t, nil).Named("mesh").Leveled(slog.LevelDebug), server)
89+
t.Cleanup(func() {
90+
_ = server.Close()
91+
_ = mesh.Close()
92+
})
93+
serverURLs = append(serverURLs, url)
94+
meshes = append(meshes, mesh)
95+
}
96+
for _, mesh := range meshes {
97+
mesh.SetAddresses(serverURLs)
98+
}
99+
100+
first := key.NewNode()
101+
second := key.NewNode()
102+
firstClient, err := derphttp.NewClient(first, serverURLs[9], tailnet.Logger(slogtest.Make(t, nil)))
103+
require.NoError(t, err)
104+
secondClient, err := derphttp.NewClient(second, serverURLs[16], tailnet.Logger(slogtest.Make(t, nil)))
105+
require.NoError(t, err)
106+
err = secondClient.Connect(context.Background())
107+
require.NoError(t, err)
108+
109+
sent := []byte("hello world")
110+
err = firstClient.Send(second.Public(), sent)
111+
require.NoError(t, err)
112+
113+
got := recvData(t, secondClient)
114+
require.Equal(t, sent, got)
115+
})
116+
}
117+
118+
func recvData(t *testing.T, client *derphttp.Client) []byte {
119+
for {
120+
msg, err := client.Recv()
121+
if errors.Is(err, io.EOF) {
122+
return nil
123+
}
124+
assert.NoError(t, err)
125+
t.Logf("derp: %T", msg)
126+
switch msg := msg.(type) {
127+
case derp.ReceivedPacket:
128+
return msg.Data
129+
default:
130+
// Drop all others!
131+
}
132+
}
133+
}
134+
135+
func startDERP(t *testing.T) (*derp.Server, string) {
136+
logf := tailnet.Logger(slogtest.Make(t, nil))
137+
d := derp.NewServer(key.NewNode(), logf)
138+
d.SetMeshKey("some-key")
139+
server := httptest.NewUnstartedServer(derphttp.Handler(d))
140+
server.Start()
141+
t.Cleanup(func() {
142+
_ = d.Close()
143+
})
144+
t.Cleanup(server.Close)
145+
return d, server.URL
146+
}

enterprise/tailmesh/tailmesh.go

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)