Skip to content

Commit 2c0a526

Browse files
committed
feat: add conversions from tailnet to proto
1 parent 8d5a13d commit 2c0a526

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed

tailnet/convert.go

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package tailnet
2+
3+
import (
4+
"golang.org/x/xerrors"
5+
"google.golang.org/protobuf/types/known/timestamppb"
6+
"net/netip"
7+
8+
"github.com/coder/coder/v2/tailnet/proto"
9+
"github.com/google/uuid"
10+
"tailscale.com/tailcfg"
11+
"tailscale.com/types/key"
12+
)
13+
14+
func UUIDToByteSlice(u uuid.UUID) []byte {
15+
b := [16]byte(u)
16+
o := make([]byte, 16)
17+
copy(o, b[:]) // copy so that we can't mutate the original
18+
return o
19+
}
20+
21+
func NodeToProto(n *Node) (*proto.Node, error) {
22+
k, err := n.Key.MarshalBinary()
23+
if err != nil {
24+
return nil, err
25+
}
26+
disco, err := n.DiscoKey.MarshalText()
27+
if err != nil {
28+
return nil, err
29+
}
30+
derpForcedWebsocket := make(map[int32]string)
31+
for i, s := range n.DERPForcedWebsocket {
32+
derpForcedWebsocket[int32(i)] = s
33+
}
34+
addresses := make([]string, len(n.Addresses))
35+
for i, prefix := range n.Addresses {
36+
s, err := prefix.MarshalText()
37+
if err != nil {
38+
return nil, err
39+
}
40+
addresses[i] = string(s)
41+
}
42+
allowedIPs := make([]string, len(n.AllowedIPs))
43+
for i, prefix := range n.AllowedIPs {
44+
s, err := prefix.MarshalText()
45+
if err != nil {
46+
return nil, err
47+
}
48+
allowedIPs[i] = string(s)
49+
}
50+
return &proto.Node{
51+
Id: int64(n.ID),
52+
AsOf: timestamppb.New(n.AsOf),
53+
Key: k,
54+
Disco: string(disco),
55+
PreferredDerp: int32(n.PreferredDERP),
56+
DerpLatency: n.DERPLatency,
57+
DerpForcedWebsocket: derpForcedWebsocket,
58+
Addresses: addresses,
59+
AllowedIps: allowedIPs,
60+
Endpoints: n.Endpoints,
61+
}, nil
62+
}
63+
64+
func ProtoToNode(p *proto.Node) (*Node, error) {
65+
k := key.NodePublic{}
66+
err := k.UnmarshalBinary(p.GetKey())
67+
if err != nil {
68+
return nil, err
69+
}
70+
disco := key.DiscoPublic{}
71+
err = disco.UnmarshalText([]byte(p.GetDisco()))
72+
if err != nil {
73+
return nil, err
74+
}
75+
derpForcedWebsocket := make(map[int]string)
76+
for i, s := range p.GetDerpForcedWebsocket() {
77+
derpForcedWebsocket[int(i)] = s
78+
}
79+
addresses := make([]netip.Prefix, len(p.GetAddresses()))
80+
for i, prefix := range p.GetAddresses() {
81+
err = addresses[i].UnmarshalText([]byte(prefix))
82+
if err != nil {
83+
return nil, err
84+
}
85+
}
86+
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
87+
for i, prefix := range p.GetAllowedIps() {
88+
err = allowedIPs[i].UnmarshalText([]byte(prefix))
89+
if err != nil {
90+
return nil, err
91+
}
92+
}
93+
return &Node{
94+
ID: tailcfg.NodeID(p.GetId()),
95+
AsOf: p.GetAsOf().AsTime(),
96+
Key: k,
97+
DiscoKey: disco,
98+
PreferredDERP: int(p.GetPreferredDerp()),
99+
DERPLatency: p.GetDerpLatency(),
100+
DERPForcedWebsocket: derpForcedWebsocket,
101+
Addresses: addresses,
102+
AllowedIPs: allowedIPs,
103+
Endpoints: p.Endpoints,
104+
}, nil
105+
}
106+
107+
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
108+
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
109+
for _, pu := range resp.GetPeerUpdates() {
110+
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
111+
continue
112+
}
113+
n, err := ProtoToNode(pu.Node)
114+
if err != nil {
115+
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
116+
}
117+
nodes = append(nodes, n)
118+
}
119+
return nodes, nil
120+
}
121+
122+
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
123+
p, err := NodeToProto(node)
124+
if err != nil {
125+
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
126+
}
127+
return &proto.CoordinateResponse{
128+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
129+
{
130+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
131+
Uuid: UUIDToByteSlice(id),
132+
Node: p,
133+
Reason: reason,
134+
},
135+
},
136+
}, nil
137+
}

tailnet/convert_test.go

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package tailnet_test
2+
3+
import (
4+
"net/netip"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"tailscale.com/tailcfg"
11+
"tailscale.com/types/key"
12+
13+
"github.com/coder/coder/v2/coderd/database/dbtime"
14+
"github.com/coder/coder/v2/tailnet"
15+
"github.com/coder/coder/v2/tailnet/proto"
16+
)
17+
18+
func TestNode(t *testing.T) {
19+
testCases := []struct {
20+
name string
21+
node tailnet.Node
22+
}{
23+
{
24+
name: "Zero",
25+
},
26+
{
27+
name: "AllFields",
28+
node: tailnet.Node{
29+
ID: 33,
30+
AsOf: time.Now(),
31+
Key: key.NewNode().Public(),
32+
DiscoKey: key.NewDisco().Public(),
33+
PreferredDERP: 12,
34+
DERPLatency: map[string]float64{
35+
"1": 0.2,
36+
"12": 0.3,
37+
},
38+
DERPForcedWebsocket: map[int]string{
39+
1: "forced",
40+
},
41+
Addresses: []netip.Prefix{
42+
netip.MustParsePrefix("10.0.0.0/8"),
43+
netip.MustParsePrefix("ff80::aa:1/128"),
44+
},
45+
AllowedIPs: []netip.Prefix{
46+
netip.MustParsePrefix("10.0.0.0/8"),
47+
netip.MustParsePrefix("ff80::aa:1/128"),
48+
},
49+
Endpoints: []string{
50+
"192.168.0.1:3305",
51+
"[ff80::aa:1]:2049",
52+
},
53+
},
54+
},
55+
{
56+
name: "dbtime",
57+
node: tailnet.Node{
58+
AsOf: dbtime.Now(),
59+
},
60+
},
61+
}
62+
63+
for _, tc := range testCases {
64+
tc := tc
65+
t.Run(tc.name, func(t *testing.T) {
66+
p, err := tailnet.NodeToProto(&tc.node)
67+
require.NoError(t, err)
68+
69+
inv, err := tailnet.ProtoToNode(p)
70+
require.NoError(t, err)
71+
require.Equal(t, tc.node.ID, inv.ID)
72+
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
73+
require.Equal(t, tc.node.Key, inv.Key)
74+
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
75+
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
76+
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
77+
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
78+
for k, v := range inv.DERPForcedWebsocket {
79+
nv, ok := tc.node.DERPForcedWebsocket[k]
80+
require.True(t, ok)
81+
require.Equal(t, nv, v)
82+
}
83+
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
84+
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
85+
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
86+
})
87+
}
88+
}
89+
90+
func TestUUIDToByteSlice(t *testing.T) {
91+
u := uuid.New()
92+
b := tailnet.UUIDToByteSlice(u)
93+
u2, err := uuid.FromBytes(b)
94+
require.NoError(t, err)
95+
require.Equal(t, u, u2)
96+
97+
b = tailnet.UUIDToByteSlice(uuid.Nil)
98+
u2, err = uuid.FromBytes(b)
99+
require.NoError(t, err)
100+
require.Equal(t, uuid.Nil, u2)
101+
}
102+
103+
func TestOnlyNodeUpdates(t *testing.T) {
104+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
105+
p, err := tailnet.NodeToProto(node)
106+
require.NoError(t, err)
107+
resp := &proto.CoordinateResponse{
108+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
109+
{
110+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
111+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
112+
Node: p,
113+
},
114+
{
115+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
116+
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
117+
Reason: "disconnected",
118+
},
119+
{
120+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
121+
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
122+
Reason: "disconnected",
123+
},
124+
{
125+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
126+
},
127+
},
128+
}
129+
nodes, err := tailnet.OnlyNodeUpdates(resp)
130+
require.NoError(t, err)
131+
require.Len(t, nodes, 1)
132+
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
133+
}
134+
135+
func TestSingleNodeUpdate(t *testing.T) {
136+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
137+
u := uuid.New()
138+
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
139+
require.NoError(t, err)
140+
require.Len(t, resp.PeerUpdates, 1)
141+
up := resp.PeerUpdates[0]
142+
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
143+
u2, err := uuid.FromBytes(up.Uuid)
144+
require.NoError(t, err)
145+
require.Equal(t, u, u2)
146+
require.Equal(t, "unit test", up.Reason)
147+
require.EqualValues(t, 1, up.Node.Id)
148+
}

0 commit comments

Comments
 (0)