Skip to content

Commit 6882e8e

Browse files
authored
feat: add conversions from tailnet to proto (#10441)
Adds conversions from existing tailnet types to protobuf
1 parent f4026ed commit 6882e8e

File tree

3 files changed

+293
-2
lines changed

3 files changed

+293
-2
lines changed

codersdk/deployment_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,12 @@ func TestTimezoneOffsets(t *testing.T) {
223223
// Name: "Central",
224224
// Loc: must(time.LoadLocation("America/Chicago")),
225225
// ExpectedOffset: -5,
226-
//},
226+
// },
227227
//{
228228
// Name: "Ireland",
229229
// Loc: must(time.LoadLocation("Europe/Dublin")),
230230
// ExpectedOffset: 1,
231-
//},
231+
// },
232232
{
233233
Name: "HalfHourTz",
234234
// This timezone is +6:30, but the function rounds to the nearest hour.

tailnet/convert.go

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package tailnet
2+
3+
import (
4+
"net/netip"
5+
6+
"github.com/google/uuid"
7+
"golang.org/x/xerrors"
8+
"google.golang.org/protobuf/types/known/timestamppb"
9+
"tailscale.com/tailcfg"
10+
"tailscale.com/types/key"
11+
12+
"github.com/coder/coder/v2/tailnet/proto"
13+
)
14+
15+
func UUIDToByteSlice(u uuid.UUID) []byte {
16+
b := [16]byte(u)
17+
o := make([]byte, 16)
18+
copy(o, b[:]) // copy so that we can't mutate the original
19+
return o
20+
}
21+
22+
func NodeToProto(n *Node) (*proto.Node, error) {
23+
k, err := n.Key.MarshalBinary()
24+
if err != nil {
25+
return nil, err
26+
}
27+
disco, err := n.DiscoKey.MarshalText()
28+
if err != nil {
29+
return nil, err
30+
}
31+
derpForcedWebsocket := make(map[int32]string)
32+
for i, s := range n.DERPForcedWebsocket {
33+
derpForcedWebsocket[int32(i)] = s
34+
}
35+
addresses := make([]string, len(n.Addresses))
36+
for i, prefix := range n.Addresses {
37+
s, err := prefix.MarshalText()
38+
if err != nil {
39+
return nil, err
40+
}
41+
addresses[i] = string(s)
42+
}
43+
allowedIPs := make([]string, len(n.AllowedIPs))
44+
for i, prefix := range n.AllowedIPs {
45+
s, err := prefix.MarshalText()
46+
if err != nil {
47+
return nil, err
48+
}
49+
allowedIPs[i] = string(s)
50+
}
51+
return &proto.Node{
52+
Id: int64(n.ID),
53+
AsOf: timestamppb.New(n.AsOf),
54+
Key: k,
55+
Disco: string(disco),
56+
PreferredDerp: int32(n.PreferredDERP),
57+
DerpLatency: n.DERPLatency,
58+
DerpForcedWebsocket: derpForcedWebsocket,
59+
Addresses: addresses,
60+
AllowedIps: allowedIPs,
61+
Endpoints: n.Endpoints,
62+
}, nil
63+
}
64+
65+
func ProtoToNode(p *proto.Node) (*Node, error) {
66+
k := key.NodePublic{}
67+
err := k.UnmarshalBinary(p.GetKey())
68+
if err != nil {
69+
return nil, err
70+
}
71+
disco := key.DiscoPublic{}
72+
err = disco.UnmarshalText([]byte(p.GetDisco()))
73+
if err != nil {
74+
return nil, err
75+
}
76+
derpForcedWebsocket := make(map[int]string)
77+
for i, s := range p.GetDerpForcedWebsocket() {
78+
derpForcedWebsocket[int(i)] = s
79+
}
80+
addresses := make([]netip.Prefix, len(p.GetAddresses()))
81+
for i, prefix := range p.GetAddresses() {
82+
err = addresses[i].UnmarshalText([]byte(prefix))
83+
if err != nil {
84+
return nil, err
85+
}
86+
}
87+
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
88+
for i, prefix := range p.GetAllowedIps() {
89+
err = allowedIPs[i].UnmarshalText([]byte(prefix))
90+
if err != nil {
91+
return nil, err
92+
}
93+
}
94+
return &Node{
95+
ID: tailcfg.NodeID(p.GetId()),
96+
AsOf: p.GetAsOf().AsTime(),
97+
Key: k,
98+
DiscoKey: disco,
99+
PreferredDERP: int(p.GetPreferredDerp()),
100+
DERPLatency: p.GetDerpLatency(),
101+
DERPForcedWebsocket: derpForcedWebsocket,
102+
Addresses: addresses,
103+
AllowedIPs: allowedIPs,
104+
Endpoints: p.Endpoints,
105+
}, nil
106+
}
107+
108+
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
109+
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
110+
for _, pu := range resp.GetPeerUpdates() {
111+
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
112+
continue
113+
}
114+
n, err := ProtoToNode(pu.Node)
115+
if err != nil {
116+
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
117+
}
118+
nodes = append(nodes, n)
119+
}
120+
return nodes, nil
121+
}
122+
123+
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
124+
p, err := NodeToProto(node)
125+
if err != nil {
126+
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
127+
}
128+
return &proto.CoordinateResponse{
129+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
130+
{
131+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
132+
Uuid: UUIDToByteSlice(id),
133+
Node: p,
134+
Reason: reason,
135+
},
136+
},
137+
}, nil
138+
}

tailnet/convert_test.go

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package tailnet_test
2+
3+
import (
4+
"net/netip"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"tailscale.com/tailcfg"
11+
"tailscale.com/types/key"
12+
13+
"github.com/coder/coder/v2/coderd/database/dbtime"
14+
"github.com/coder/coder/v2/tailnet"
15+
"github.com/coder/coder/v2/tailnet/proto"
16+
)
17+
18+
func TestNode(t *testing.T) {
19+
t.Parallel()
20+
testCases := []struct {
21+
name string
22+
node tailnet.Node
23+
}{
24+
{
25+
name: "Zero",
26+
},
27+
{
28+
name: "AllFields",
29+
node: tailnet.Node{
30+
ID: 33,
31+
AsOf: time.Now(),
32+
Key: key.NewNode().Public(),
33+
DiscoKey: key.NewDisco().Public(),
34+
PreferredDERP: 12,
35+
DERPLatency: map[string]float64{
36+
"1": 0.2,
37+
"12": 0.3,
38+
},
39+
DERPForcedWebsocket: map[int]string{
40+
1: "forced",
41+
},
42+
Addresses: []netip.Prefix{
43+
netip.MustParsePrefix("10.0.0.0/8"),
44+
netip.MustParsePrefix("ff80::aa:1/128"),
45+
},
46+
AllowedIPs: []netip.Prefix{
47+
netip.MustParsePrefix("10.0.0.0/8"),
48+
netip.MustParsePrefix("ff80::aa:1/128"),
49+
},
50+
Endpoints: []string{
51+
"192.168.0.1:3305",
52+
"[ff80::aa:1]:2049",
53+
},
54+
},
55+
},
56+
{
57+
name: "dbtime",
58+
node: tailnet.Node{
59+
AsOf: dbtime.Now(),
60+
},
61+
},
62+
}
63+
64+
for _, tc := range testCases {
65+
tc := tc
66+
t.Run(tc.name, func(t *testing.T) {
67+
t.Parallel()
68+
p, err := tailnet.NodeToProto(&tc.node)
69+
require.NoError(t, err)
70+
71+
inv, err := tailnet.ProtoToNode(p)
72+
require.NoError(t, err)
73+
require.Equal(t, tc.node.ID, inv.ID)
74+
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
75+
require.Equal(t, tc.node.Key, inv.Key)
76+
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
77+
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
78+
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
79+
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
80+
for k, v := range inv.DERPForcedWebsocket {
81+
nv, ok := tc.node.DERPForcedWebsocket[k]
82+
require.True(t, ok)
83+
require.Equal(t, nv, v)
84+
}
85+
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
86+
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
87+
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
88+
})
89+
}
90+
}
91+
92+
func TestUUIDToByteSlice(t *testing.T) {
93+
t.Parallel()
94+
u := uuid.New()
95+
b := tailnet.UUIDToByteSlice(u)
96+
u2, err := uuid.FromBytes(b)
97+
require.NoError(t, err)
98+
require.Equal(t, u, u2)
99+
100+
b = tailnet.UUIDToByteSlice(uuid.Nil)
101+
u2, err = uuid.FromBytes(b)
102+
require.NoError(t, err)
103+
require.Equal(t, uuid.Nil, u2)
104+
}
105+
106+
func TestOnlyNodeUpdates(t *testing.T) {
107+
t.Parallel()
108+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
109+
p, err := tailnet.NodeToProto(node)
110+
require.NoError(t, err)
111+
resp := &proto.CoordinateResponse{
112+
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
113+
{
114+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
115+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
116+
Node: p,
117+
},
118+
{
119+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
120+
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
121+
Reason: "disconnected",
122+
},
123+
{
124+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
125+
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
126+
Reason: "disconnected",
127+
},
128+
{
129+
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
130+
},
131+
},
132+
}
133+
nodes, err := tailnet.OnlyNodeUpdates(resp)
134+
require.NoError(t, err)
135+
require.Len(t, nodes, 1)
136+
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
137+
}
138+
139+
func TestSingleNodeUpdate(t *testing.T) {
140+
t.Parallel()
141+
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
142+
u := uuid.New()
143+
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
144+
require.NoError(t, err)
145+
require.Len(t, resp.PeerUpdates, 1)
146+
up := resp.PeerUpdates[0]
147+
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
148+
u2, err := uuid.FromBytes(up.Uuid)
149+
require.NoError(t, err)
150+
require.Equal(t, u, u2)
151+
require.Equal(t, "unit test", up.Reason)
152+
require.EqualValues(t, 1, up.Node.Id)
153+
}

0 commit comments

Comments
 (0)