Skip to content

Commit 27b0597

Browse files
committed
feat: add conversions from tailnet to proto
1 parent 8d5a13d commit 27b0597

File tree

3 files changed

+290
-1
lines changed

3 files changed

+290
-1
lines changed

codersdk/deployment_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func TestTimezoneOffsets(t *testing.T) {
223223
// Name: "Central",
224224
// Loc: must(time.LoadLocation("America/Chicago")),
225225
// ExpectedOffset: -5,
226-
//},
226+
// },
227227
//{
228228
// Name: "Ireland",
229229
// Loc: must(time.LoadLocation("Europe/Dublin")),

tailnet/convert.go

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package tailnet
2+
3+
import (
4+
"net/netip"
5+
6+
"golang.org/x/xerrors"
7+
"tailscale.com/tailcfg"
8+
"tailscale.com/types/key"
9+
10+
"github.com/coder/coder/v2/tailnet/proto"
11+
)
12+
13+
func UUIDToByteSlice(u uuid.UUID) []byte {
14+
b := [16]byte(u)
15+
o := make([]byte, 16)
16+
copy(o, b[:]) // copy so that we can't mutate the original
17+
return o
18+
}
19+
20+
func NodeToProto(n *Node) (*proto.Node, error) {
21+
k, err := n.Key.MarshalBinary()
22+
if err != nil {
23+
return nil, err
24+
}
25+
disco, err := n.DiscoKey.MarshalText()
26+
if err != nil {
27+
return nil, err
28+
}
29+
derpForcedWebsocket := make(map[int32]string)
30+
for i, s := range n.DERPForcedWebsocket {
31+
derpForcedWebsocket[int32(i)] = s
32+
}
33+
addresses := make([]string, len(n.Addresses))
34+
for i, prefix := range n.Addresses {
35+
s, err := prefix.MarshalText()
36+
if err != nil {
37+
return nil, err
38+
}
39+
addresses[i] = string(s)
40+
}
41+
allowedIPs := make([]string, len(n.AllowedIPs))
42+
for i, prefix := range n.AllowedIPs {
43+
s, err := prefix.MarshalText()
44+
if err != nil {
45+
return nil, err
46+
}
47+
allowedIPs[i] = string(s)
48+
}
49+
return &proto.Node{
50+
Id: int64(n.ID),
51+
AsOf: timestamppb.New(n.AsOf),
52+
Key: k,
53+
Disco: string(disco),
54+
PreferredDerp: int32(n.PreferredDERP),
55+
DerpLatency: n.DERPLatency,
56+
DerpForcedWebsocket: derpForcedWebsocket,
57+
Addresses: addresses,
58+
AllowedIps: allowedIPs,
59+
Endpoints: n.Endpoints,
60+
}, nil
61+
}
62+
63+
func ProtoToNode(p *proto.Node) (*Node, error) {
64+
k := key.NodePublic{}
65+
err := k.UnmarshalBinary(p.GetKey())
66+
if err != nil {
67+
return nil, err
68+
}
69+
disco := key.DiscoPublic{}
70+
err = disco.UnmarshalText([]byte(p.GetDisco()))
71+
if err != nil {
72+
return nil, err
73+
}
74+
derpForcedWebsocket := make(map[int]string)
75+
for i, s := range p.GetDerpForcedWebsocket() {
76+
derpForcedWebsocket[int(i)] = s
77+
}
78+
addresses := make([]netip.Prefix, len(p.GetAddresses()))
79+
for i, prefix := range p.GetAddresses() {
80+
err = addresses[i].UnmarshalText([]byte(prefix))
81+
if err != nil {
82+
return nil, err
83+
}
84+
}
85+
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
86+
for i, prefix := range p.GetAllowedIps() {
87+
err = allowedIPs[i].UnmarshalText([]byte(prefix))
88+
if err != nil {
89+
return nil, err
90+
}
91+
}
92+
return &Node{
93+
ID: tailcfg.NodeID(p.GetId()),
94+
AsOf: p.GetAsOf().AsTime(),
95+
Key: k,
96+
DiscoKey: disco,
97+
PreferredDERP: int(p.GetPreferredDerp()),
98+
DERPLatency: p.GetDerpLatency(),
99+
DERPForcedWebsocket: derpForcedWebsocket,
100+
Addresses: addresses,
101+
AllowedIPs: allowedIPs,
102+
Endpoints: p.Endpoints,
103+
}, nil
104+
}
105+
106+
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
107+
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
108+
for _, pu := range resp.GetPeerUpdates() {
109+
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
110+
continue
111+
}
112+
n, err := ProtoToNode(pu.Node)
113+
if err != nil {
114+
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
115+
}
116+
nodes = append(nodes, n)
117+
}
118+
return nodes, nil
119+
}
120+
121+
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
122+
p, err := NodeToProto(node)
123+
if err != nil {
124+
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
125+
}
126+
return &proto.CoordinateResponse{
127+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
128+
{
129+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
130+
Uuid: UUIDToByteSlice(id),
131+
Node: p,
132+
Reason: reason,
133+
},
134+
},
135+
}, nil
136+
}

tailnet/convert_test.go

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package tailnet_test
2+
3+
import (
4+
"net/netip"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"tailscale.com/tailcfg"
11+
"tailscale.com/types/key"
12+
13+
"github.com/coder/coder/v2/coderd/database/dbtime"
14+
"github.com/coder/coder/v2/tailnet"
15+
"github.com/coder/coder/v2/tailnet/proto"
16+
)
17+
18+
func TestNode(t *testing.T) {
19+
t.Parallel()
20+
testCases := []struct {
21+
name string
22+
node tailnet.Node
23+
}{
24+
{
25+
name: "Zero",
26+
},
27+
{
28+
name: "AllFields",
29+
node: tailnet.Node{
30+
ID: 33,
31+
AsOf: time.Now(),
32+
Key: key.NewNode().Public(),
33+
DiscoKey: key.NewDisco().Public(),
34+
PreferredDERP: 12,
35+
DERPLatency: map[string]float64{
36+
"1": 0.2,
37+
"12": 0.3,
38+
},
39+
DERPForcedWebsocket: map[int]string{
40+
1: "forced",
41+
},
42+
Addresses: []netip.Prefix{
43+
netip.MustParsePrefix("10.0.0.0/8"),
44+
netip.MustParsePrefix("ff80::aa:1/128"),
45+
},
46+
AllowedIPs: []netip.Prefix{
47+
netip.MustParsePrefix("10.0.0.0/8"),
48+
netip.MustParsePrefix("ff80::aa:1/128"),
49+
},
50+
Endpoints: []string{
51+
"192.168.0.1:3305",
52+
"[ff80::aa:1]:2049",
53+
},
54+
},
55+
},
56+
{
57+
name: "dbtime",
58+
node: tailnet.Node{
59+
AsOf: dbtime.Now(),
60+
},
61+
},
62+
}
63+
64+
for _, tc := range testCases {
65+
tc := tc
66+
t.Run(tc.name, func(t *testing.T) {
67+
t.Parallel()
68+
p, err := tailnet.NodeToProto(&tc.node)
69+
require.NoError(t, err)
70+
71+
inv, err := tailnet.ProtoToNode(p)
72+
require.NoError(t, err)
73+
require.Equal(t, tc.node.ID, inv.ID)
74+
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
75+
require.Equal(t, tc.node.Key, inv.Key)
76+
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
77+
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
78+
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
79+
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
80+
for k, v := range inv.DERPForcedWebsocket {
81+
nv, ok := tc.node.DERPForcedWebsocket[k]
82+
require.True(t, ok)
83+
require.Equal(t, nv, v)
84+
}
85+
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
86+
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
87+
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
88+
})
89+
}
90+
}
91+
92+
func TestUUIDToByteSlice(t *testing.T) {
93+
t.Parallel()
94+
u := uuid.New()
95+
b := tailnet.UUIDToByteSlice(u)
96+
u2, err := uuid.FromBytes(b)
97+
require.NoError(t, err)
98+
require.Equal(t, u, u2)
99+
100+
b = tailnet.UUIDToByteSlice(uuid.Nil)
101+
u2, err = uuid.FromBytes(b)
102+
require.NoError(t, err)
103+
require.Equal(t, uuid.Nil, u2)
104+
}
105+
106+
func TestOnlyNodeUpdates(t *testing.T) {
107+
t.Parallel()
108+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
109+
p, err := tailnet.NodeToProto(node)
110+
require.NoError(t, err)
111+
resp := &proto.CoordinateResponse{
112+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
113+
{
114+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
115+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
116+
Node: p,
117+
},
118+
{
119+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
120+
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
121+
Reason: "disconnected",
122+
},
123+
{
124+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
125+
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
126+
Reason: "disconnected",
127+
},
128+
{
129+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
130+
},
131+
},
132+
}
133+
nodes, err := tailnet.OnlyNodeUpdates(resp)
134+
require.NoError(t, err)
135+
require.Len(t, nodes, 1)
136+
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
137+
}
138+
139+
func TestSingleNodeUpdate(t *testing.T) {
140+
t.Parallel()
141+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
142+
u := uuid.New()
143+
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
144+
require.NoError(t, err)
145+
require.Len(t, resp.PeerUpdates, 1)
146+
up := resp.PeerUpdates[0]
147+
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
148+
u2, err := uuid.FromBytes(up.Uuid)
149+
require.NoError(t, err)
150+
require.Equal(t, u, u2)
151+
require.Equal(t, "unit test", up.Reason)
152+
require.EqualValues(t, 1, up.Node.Id)
153+
}

0 commit comments

Comments
 (0)