Skip to content

Commit 678acaa

Browse files
committed
feat: add conversions from tailnet to proto
1 parent 10fc9ca commit 678acaa

File tree

2 files changed

+290
-0
lines changed

2 files changed

+290
-0
lines changed

tailnet/convert.go

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package tailnet
2+
3+
import (
4+
"net/netip"
5+
"time"
6+
7+
"golang.org/x/xerrors"
8+
9+
coderproto "github.com/coder/coder/v2/proto"
10+
"github.com/coder/coder/v2/tailnet/proto"
11+
"github.com/google/uuid"
12+
"tailscale.com/tailcfg"
13+
"tailscale.com/types/key"
14+
)
15+
16+
func UUIDToByteSlice(u uuid.UUID) []byte {
17+
b := [16]byte(u)
18+
o := make([]byte, 16)
19+
copy(o, b[:]) // copy so that we can't mutate the original
20+
return o
21+
}
22+
23+
func NodeToProto(n *Node) (*proto.Node, error) {
24+
k, err := n.Key.MarshalBinary()
25+
if err != nil {
26+
return nil, err
27+
}
28+
disco, err := n.DiscoKey.MarshalText()
29+
if err != nil {
30+
return nil, err
31+
}
32+
derpForcedWebsocket := make(map[int32]string)
33+
for i, s := range n.DERPForcedWebsocket {
34+
derpForcedWebsocket[int32(i)] = s
35+
}
36+
addresses := make([]string, len(n.Addresses))
37+
for i, prefix := range n.Addresses {
38+
s, err := prefix.MarshalText()
39+
if err != nil {
40+
return nil, err
41+
}
42+
addresses[i] = string(s)
43+
}
44+
allowedIPs := make([]string, len(n.AllowedIPs))
45+
for i, prefix := range n.AllowedIPs {
46+
s, err := prefix.MarshalText()
47+
if err != nil {
48+
return nil, err
49+
}
50+
allowedIPs[i] = string(s)
51+
}
52+
return &proto.Node{
53+
Id: int64(n.ID),
54+
AsOf: &coderproto.UnixTime{Micro: n.AsOf.UnixMicro()},
55+
Key: k,
56+
Disco: string(disco),
57+
PreferredDerp: int32(n.PreferredDERP),
58+
DerpLatency: n.DERPLatency,
59+
DerpForcedWebsocket: derpForcedWebsocket,
60+
Addresses: addresses,
61+
AllowedIps: allowedIPs,
62+
Endpoints: n.Endpoints,
63+
}, nil
64+
}
65+
66+
func ProtoToNode(p *proto.Node) (*Node, error) {
67+
k := key.NodePublic{}
68+
err := k.UnmarshalBinary(p.GetKey())
69+
if err != nil {
70+
return nil, err
71+
}
72+
disco := key.DiscoPublic{}
73+
err = disco.UnmarshalText([]byte(p.GetDisco()))
74+
if err != nil {
75+
return nil, err
76+
}
77+
derpForcedWebsocket := make(map[int]string)
78+
for i, s := range p.GetDerpForcedWebsocket() {
79+
derpForcedWebsocket[int(i)] = s
80+
}
81+
addresses := make([]netip.Prefix, len(p.GetAddresses()))
82+
for i, prefix := range p.GetAddresses() {
83+
err = addresses[i].UnmarshalText([]byte(prefix))
84+
if err != nil {
85+
return nil, err
86+
}
87+
}
88+
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
89+
for i, prefix := range p.GetAllowedIps() {
90+
err = allowedIPs[i].UnmarshalText([]byte(prefix))
91+
if err != nil {
92+
return nil, err
93+
}
94+
}
95+
return &Node{
96+
ID: tailcfg.NodeID(p.GetId()),
97+
AsOf: time.UnixMicro(p.GetAsOf().GetMicro()).UTC(),
98+
Key: k,
99+
DiscoKey: disco,
100+
PreferredDERP: int(p.GetPreferredDerp()),
101+
DERPLatency: p.GetDerpLatency(),
102+
DERPForcedWebsocket: derpForcedWebsocket,
103+
Addresses: addresses,
104+
AllowedIPs: allowedIPs,
105+
Endpoints: p.Endpoints,
106+
}, nil
107+
}
108+
109+
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
110+
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
111+
for _, pu := range resp.GetPeerUpdates() {
112+
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
113+
continue
114+
}
115+
n, err := ProtoToNode(pu.Node)
116+
if err != nil {
117+
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
118+
}
119+
nodes = append(nodes, n)
120+
}
121+
return nodes, nil
122+
}
123+
124+
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
125+
p, err := NodeToProto(node)
126+
if err != nil {
127+
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
128+
}
129+
return &proto.CoordinateResponse{
130+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
131+
{
132+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
133+
Uuid: UUIDToByteSlice(id),
134+
Node: p,
135+
Reason: reason,
136+
},
137+
},
138+
}, nil
139+
}

tailnet/convert_test.go

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package tailnet_test
2+
3+
import (
4+
"net/netip"
5+
"testing"
6+
"time"
7+
8+
"github.com/coder/coder/v2/tailnet/proto"
9+
"tailscale.com/tailcfg"
10+
11+
"github.com/google/uuid"
12+
13+
"github.com/coder/coder/v2/coderd/database/dbtime"
14+
15+
"tailscale.com/types/key"
16+
17+
"github.com/coder/coder/v2/tailnet"
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
func TestNode(t *testing.T) {
22+
testCases := []struct {
23+
name string
24+
node tailnet.Node
25+
}{
26+
{
27+
name: "Zero",
28+
},
29+
{
30+
name: "AllFields",
31+
node: tailnet.Node{
32+
ID: 33,
33+
AsOf: time.Now(),
34+
Key: key.NewNode().Public(),
35+
DiscoKey: key.NewDisco().Public(),
36+
PreferredDERP: 12,
37+
DERPLatency: map[string]float64{
38+
"1": 0.2,
39+
"12": 0.3,
40+
},
41+
DERPForcedWebsocket: map[int]string{
42+
1: "forced",
43+
},
44+
Addresses: []netip.Prefix{
45+
netip.MustParsePrefix("10.0.0.0/8"),
46+
netip.MustParsePrefix("ff80::aa:1/128"),
47+
},
48+
AllowedIPs: []netip.Prefix{
49+
netip.MustParsePrefix("10.0.0.0/8"),
50+
netip.MustParsePrefix("ff80::aa:1/128"),
51+
},
52+
Endpoints: []string{
53+
"192.168.0.1:3305",
54+
"[ff80::aa:1]:2049",
55+
},
56+
},
57+
},
58+
{
59+
name: "dbtime",
60+
node: tailnet.Node{
61+
AsOf: dbtime.Now(),
62+
},
63+
},
64+
}
65+
66+
for _, tc := range testCases {
67+
tc := tc
68+
t.Run(tc.name, func(t *testing.T) {
69+
p, err := tailnet.NodeToProto(&tc.node)
70+
require.NoError(t, err)
71+
72+
inv, err := tailnet.ProtoToNode(p)
73+
require.NoError(t, err)
74+
require.Equal(t, tc.node.ID, inv.ID)
75+
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
76+
require.Equal(t, tc.node.Key, inv.Key)
77+
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
78+
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
79+
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
80+
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
81+
for k, v := range inv.DERPForcedWebsocket {
82+
nv, ok := tc.node.DERPForcedWebsocket[k]
83+
require.True(t, ok)
84+
require.Equal(t, nv, v)
85+
}
86+
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
87+
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
88+
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
89+
})
90+
}
91+
}
92+
93+
func TestUUIDToByteSlice(t *testing.T) {
94+
u := uuid.New()
95+
b := tailnet.UUIDToByteSlice(u)
96+
u2, err := uuid.FromBytes(b)
97+
require.NoError(t, err)
98+
require.Equal(t, u, u2)
99+
100+
b = tailnet.UUIDToByteSlice(uuid.Nil)
101+
u2, err = uuid.FromBytes(b)
102+
require.NoError(t, err)
103+
require.Equal(t, uuid.Nil, u2)
104+
}
105+
106+
func TestOnlyNodeUpdates(t *testing.T) {
107+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
108+
p, err := tailnet.NodeToProto(node)
109+
require.NoError(t, err)
110+
resp := &proto.CoordinateResponse{
111+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
112+
{
113+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
114+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
115+
Node: p,
116+
},
117+
{
118+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
119+
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
120+
Reason: "disconnected",
121+
},
122+
{
123+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
124+
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
125+
Reason: "disconnected",
126+
},
127+
{
128+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
129+
},
130+
},
131+
}
132+
nodes, err := tailnet.OnlyNodeUpdates(resp)
133+
require.NoError(t, err)
134+
require.Len(t, nodes, 1)
135+
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
136+
}
137+
138+
func TestSingleNodeUpdate(t *testing.T) {
139+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
140+
u := uuid.New()
141+
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
142+
require.NoError(t, err)
143+
require.Len(t, resp.PeerUpdates, 1)
144+
up := resp.PeerUpdates[0]
145+
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
146+
u2, err := uuid.FromBytes(up.Uuid)
147+
require.NoError(t, err)
148+
require.Equal(t, u, u2)
149+
require.Equal(t, "unit test", up.Reason)
150+
require.EqualValues(t, 1, up.Node.Id)
151+
}

0 commit comments

Comments
 (0)