From 926ff1f7339e78ef843060a747ae9c42b3f0604d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 18 Oct 2023 14:32:26 +0400 Subject: [PATCH] feat: add conversions from tailnet to proto --- codersdk/deployment_test.go | 4 +- tailnet/convert.go | 138 ++++++++++++++++++++++++++++++++ tailnet/convert_test.go | 153 ++++++++++++++++++++++++++++++++++++ 3 files changed, 293 insertions(+), 2 deletions(-) create mode 100644 tailnet/convert.go create mode 100644 tailnet/convert_test.go diff --git a/codersdk/deployment_test.go b/codersdk/deployment_test.go index 362cda8e0bd17..5bf9c98543849 100644 --- a/codersdk/deployment_test.go +++ b/codersdk/deployment_test.go @@ -223,12 +223,12 @@ func TestTimezoneOffsets(t *testing.T) { // Name: "Central", // Loc: must(time.LoadLocation("America/Chicago")), // ExpectedOffset: -5, - //}, + // }, //{ // Name: "Ireland", // Loc: must(time.LoadLocation("Europe/Dublin")), // ExpectedOffset: 1, - //}, + // }, { Name: "HalfHourTz", // This timezone is +6:30, but the function rounds to the nearest hour. diff --git a/tailnet/convert.go b/tailnet/convert.go new file mode 100644 index 0000000000000..bbd417532543a --- /dev/null +++ b/tailnet/convert.go @@ -0,0 +1,138 @@ +package tailnet + +import ( + "net/netip" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + + "github.com/coder/coder/v2/tailnet/proto" +) + +func UUIDToByteSlice(u uuid.UUID) []byte { + b := [16]byte(u) + o := make([]byte, 16) + copy(o, b[:]) // copy so that we can't mutate the original + return o +} + +func NodeToProto(n *Node) (*proto.Node, error) { + k, err := n.Key.MarshalBinary() + if err != nil { + return nil, err + } + disco, err := n.DiscoKey.MarshalText() + if err != nil { + return nil, err + } + derpForcedWebsocket := make(map[int32]string) + for i, s := range n.DERPForcedWebsocket { + derpForcedWebsocket[int32(i)] = s + } + addresses := make([]string, len(n.Addresses)) + for i, prefix := range n.Addresses { + s, err := prefix.MarshalText() + if err != nil { + return nil, err + } + addresses[i] = string(s) + } + allowedIPs := make([]string, len(n.AllowedIPs)) + for i, prefix := range n.AllowedIPs { + s, err := prefix.MarshalText() + if err != nil { + return nil, err + } + allowedIPs[i] = string(s) + } + return &proto.Node{ + Id: int64(n.ID), + AsOf: timestamppb.New(n.AsOf), + Key: k, + Disco: string(disco), + PreferredDerp: int32(n.PreferredDERP), + DerpLatency: n.DERPLatency, + DerpForcedWebsocket: derpForcedWebsocket, + Addresses: addresses, + AllowedIps: allowedIPs, + Endpoints: n.Endpoints, + }, nil +} + +func ProtoToNode(p *proto.Node) (*Node, error) { + k := key.NodePublic{} + err := k.UnmarshalBinary(p.GetKey()) + if err != nil { + return nil, err + } + disco := key.DiscoPublic{} + err = disco.UnmarshalText([]byte(p.GetDisco())) + if err != nil { + return nil, err + } + derpForcedWebsocket := make(map[int]string) + for i, s := range p.GetDerpForcedWebsocket() { + derpForcedWebsocket[int(i)] = s + } + addresses := make([]netip.Prefix, len(p.GetAddresses())) + for i, prefix := range p.GetAddresses() { + err = addresses[i].UnmarshalText([]byte(prefix)) + if err != nil { + return nil, err + } + } + allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps())) + for i, prefix := range p.GetAllowedIps() { + err = allowedIPs[i].UnmarshalText([]byte(prefix)) + if err != nil { + return nil, err + } + } + return &Node{ + ID: tailcfg.NodeID(p.GetId()), + AsOf: p.GetAsOf().AsTime(), + Key: k, + DiscoKey: disco, + PreferredDERP: int(p.GetPreferredDerp()), + DERPLatency: p.GetDerpLatency(), + DERPForcedWebsocket: derpForcedWebsocket, + Addresses: addresses, + AllowedIPs: allowedIPs, + Endpoints: p.Endpoints, + }, nil +} + +func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) { + nodes := make([]*Node, 0, len(resp.GetPeerUpdates())) + for _, pu := range resp.GetPeerUpdates() { + if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE { + continue + } + n, err := ProtoToNode(pu.Node) + if err != nil { + return nil, xerrors.Errorf("failed conversion from protobuf: %w", err) + } + nodes = append(nodes, n) + } + return nodes, nil +} + +func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) { + p, err := NodeToProto(node) + if err != nil { + return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err) + } + return &proto.CoordinateResponse{ + PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{ + { + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Uuid: UUIDToByteSlice(id), + Node: p, + Reason: reason, + }, + }, + }, nil +} diff --git a/tailnet/convert_test.go b/tailnet/convert_test.go new file mode 100644 index 0000000000000..0998f516faa02 --- /dev/null +++ b/tailnet/convert_test.go @@ -0,0 +1,153 @@ +package tailnet_test + +import ( + "net/netip" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" +) + +func TestNode(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + node tailnet.Node + }{ + { + name: "Zero", + }, + { + name: "AllFields", + node: tailnet.Node{ + ID: 33, + AsOf: time.Now(), + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + PreferredDERP: 12, + DERPLatency: map[string]float64{ + "1": 0.2, + "12": 0.3, + }, + DERPForcedWebsocket: map[int]string{ + 1: "forced", + }, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("ff80::aa:1/128"), + }, + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("ff80::aa:1/128"), + }, + Endpoints: []string{ + "192.168.0.1:3305", + "[ff80::aa:1]:2049", + }, + }, + }, + { + name: "dbtime", + node: tailnet.Node{ + AsOf: dbtime.Now(), + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p, err := tailnet.NodeToProto(&tc.node) + require.NoError(t, err) + + inv, err := tailnet.ProtoToNode(p) + require.NoError(t, err) + require.Equal(t, tc.node.ID, inv.ID) + require.True(t, tc.node.AsOf.Equal(inv.AsOf)) + require.Equal(t, tc.node.Key, inv.Key) + require.Equal(t, tc.node.DiscoKey, inv.DiscoKey) + require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP) + require.Equal(t, tc.node.DERPLatency, inv.DERPLatency) + require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket)) + for k, v := range inv.DERPForcedWebsocket { + nv, ok := tc.node.DERPForcedWebsocket[k] + require.True(t, ok) + require.Equal(t, nv, v) + } + require.ElementsMatch(t, tc.node.Addresses, inv.Addresses) + require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs) + require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints) + }) + } +} + +func TestUUIDToByteSlice(t *testing.T) { + t.Parallel() + u := uuid.New() + b := tailnet.UUIDToByteSlice(u) + u2, err := uuid.FromBytes(b) + require.NoError(t, err) + require.Equal(t, u, u2) + + b = tailnet.UUIDToByteSlice(uuid.Nil) + u2, err = uuid.FromBytes(b) + require.NoError(t, err) + require.Equal(t, uuid.Nil, u2) +} + +func TestOnlyNodeUpdates(t *testing.T) { + t.Parallel() + node := &tailnet.Node{ID: tailcfg.NodeID(1)} + p, err := tailnet.NodeToProto(node) + require.NoError(t, err) + resp := &proto.CoordinateResponse{ + PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{ + { + Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Node: p, + }, + { + Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, + Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED, + Reason: "disconnected", + }, + { + Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3}, + Kind: proto.CoordinateResponse_PeerUpdate_LOST, + Reason: "disconnected", + }, + { + Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4}, + }, + }, + } + nodes, err := tailnet.OnlyNodeUpdates(resp) + require.NoError(t, err) + require.Len(t, nodes, 1) + require.Equal(t, tailcfg.NodeID(1), nodes[0].ID) +} + +func TestSingleNodeUpdate(t *testing.T) { + t.Parallel() + node := &tailnet.Node{ID: tailcfg.NodeID(1)} + u := uuid.New() + resp, err := tailnet.SingleNodeUpdate(u, node, "unit test") + require.NoError(t, err) + require.Len(t, resp.PeerUpdates, 1) + up := resp.PeerUpdates[0] + require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind) + u2, err := uuid.FromBytes(up.Uuid) + require.NoError(t, err) + require.Equal(t, u, u2) + require.Equal(t, "unit test", up.Reason) + require.EqualValues(t, 1, up.Node.Id) +}