From 1a3816cf80314e7663d70706e9edc57ba8c5d110 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 12 Nov 2024 16:40:36 +0400 Subject: [PATCH] feat: set DNS hostnames in workspace updates controller --- tailnet/configmaps.go | 20 ----- tailnet/configmaps_internal_test.go | 39 +-------- tailnet/conn.go | 16 ---- tailnet/controllers.go | 87 +++++++++++++++---- tailnet/controllers_test.go | 128 ++++++++++++++++++++++++++-- 5 files changed, 199 insertions(+), 91 deletions(-) diff --git a/tailnet/configmaps.go b/tailnet/configmaps.go index 8b85924a711b4..605fe559bffac 100644 --- a/tailnet/configmaps.go +++ b/tailnet/configmaps.go @@ -277,16 +277,6 @@ func (c *configMaps) setAddresses(ips []netip.Prefix) { c.Broadcast() } -func (c *configMaps) addHosts(hosts map[dnsname.FQDN][]netip.Addr) { - c.L.Lock() - defer c.L.Unlock() - for name, addrs := range hosts { - c.hosts[name] = slices.Clone(addrs) - } - c.netmapDirty = true - c.Broadcast() -} - func (c *configMaps) setHosts(hosts map[dnsname.FQDN][]netip.Addr) { c.L.Lock() defer c.L.Unlock() @@ -298,16 +288,6 @@ func (c *configMaps) setHosts(hosts map[dnsname.FQDN][]netip.Addr) { c.Broadcast() } -func (c *configMaps) removeHosts(names []dnsname.FQDN) { - c.L.Lock() - defer c.L.Unlock() - for _, name := range names { - delete(c.hosts, name) - } - c.netmapDirty = true - c.Broadcast() -} - // setBlockEndpoints sets whether we should block configuring endpoints we learn // from peers. It triggers a configuration of the engine if the value changes. 
// nolint: revive diff --git a/tailnet/configmaps_internal_test.go b/tailnet/configmaps_internal_test.go index e64cdb10871d6..ecdc7b146a008 100644 --- a/tailnet/configmaps_internal_test.go +++ b/tailnet/configmaps_internal_test.go @@ -10,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" "tailscale.com/tailcfg" @@ -1177,8 +1176,8 @@ func TestConfigMaps_addRemoveHosts(t *testing.T) { addr3 := CoderServicePrefix.AddrFromUUID(uuid.New()) addr4 := CoderServicePrefix.AddrFromUUID(uuid.New()) - // WHEN: we add two hosts - uut.addHosts(map[dnsname.FQDN][]netip.Addr{ + // WHEN: we set two hosts + uut.setHosts(map[dnsname.FQDN][]netip.Addr{ "agent.myws.me.coder.": { addr1, }, @@ -1207,36 +1206,6 @@ func TestConfigMaps_addRemoveHosts(t *testing.T) { OnlyIPv6: true, }) - // WHEN: we add a new host - newHost := map[dnsname.FQDN][]netip.Addr{ - "agent2.myws.me.coder.": { - addr4, - }, - } - uut.addHosts(newHost) - - // THEN: the engine is reconfigured with both the old and new hosts - _ = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap) - req = testutil.RequireRecvCtx(ctx, t, fEng.reconfig) - require.Equal(t, req.dnsCfg, &dns.Config{ - Routes: map[dnsname.FQDN][]*dnstype.Resolver{ - CoderDNSSuffix: nil, - }, - Hosts: map[dnsname.FQDN][]netip.Addr{ - "agent.myws.me.coder.": { - addr1, - }, - "dev.main.me.coder.": { - addr2, - addr3, - }, - "agent2.myws.me.coder.": { - addr4, - }, - }, - OnlyIPv6: true, - }) - // WHEN: We replace the hosts with a new set uut.setHosts(map[dnsname.FQDN][]netip.Addr{ "newagent.myws.me.coder.": { @@ -1265,8 +1234,8 @@ func TestConfigMaps_addRemoveHosts(t *testing.T) { OnlyIPv6: true, }) - // WHEN: we remove all the hosts, and a bad host - uut.removeHosts(append(maps.Keys(req.dnsCfg.Hosts), "badhostname")) + // WHEN: we remove all the hosts + uut.setHosts(map[dnsname.FQDN][]netip.Addr{}) _ = 
testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap) req = testutil.RequireRecvCtx(ctx, t, fEng.reconfig) diff --git a/tailnet/conn.go b/tailnet/conn.go index 64d7e3171441a..17affa770d5ee 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -451,22 +451,6 @@ func (c *Conn) SetAddresses(ips []netip.Prefix) error { return nil } -func (c *Conn) AddDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error { - if c.dnsConfigurator == nil { - return xerrors.New("no DNSConfigurator set") - } - c.configMaps.addHosts(hosts) - return nil -} - -func (c *Conn) RemoveDNSHosts(names []dnsname.FQDN) error { - if c.dnsConfigurator == nil { - return xerrors.New("no DNSConfigurator set") - } - c.configMaps.removeHosts(names) - return nil -} - // SetDNSHosts replaces the map of DNS hosts for the connection. func (c *Conn) SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error { if c.dnsConfigurator == nil { diff --git a/tailnet/controllers.go b/tailnet/controllers.go index d250422160ea9..5c2c3151ab26e 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -6,6 +6,7 @@ import ( "io" "maps" "math" + "net/netip" "strings" "sync" "time" @@ -15,6 +16,7 @@ import ( "storj.io/drpc" "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" + "tailscale.com/util/dnsname" "cdr.dev/slog" "github.com/coder/coder/v2/codersdk" @@ -104,6 +106,12 @@ type WorkspaceUpdatesController interface { New(WorkspaceUpdatesClient) CloserWaiter } +// DNSHostsSetter is something that you can set a mapping of DNS names to IPs on. It's the subset +// of the tailnet.Conn that we use to configure DNS records. +type DNSHostsSetter interface { + SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error +} + // ControlProtocolClients represents an abstract interface to the tailnet control plane via a set // of protocol clients. The Closer should close all the clients (e.g. by closing the underlying // connection). 
@@ -835,8 +843,9 @@ func (r *basicResumeTokenRefresher) refresh() { } type tunnelAllWorkspaceUpdatesController struct { - coordCtrl *TunnelSrcCoordController - logger slog.Logger + coordCtrl *TunnelSrcCoordController + dnsHostSetter DNSHostsSetter + logger slog.Logger } type workspace struct { @@ -845,6 +854,22 @@ type workspace struct { agents map[uuid.UUID]agent } +// addAllDNSNames adds names for all of its agents to the given map of names +func (w workspace) addAllDNSNames(names map[dnsname.FQDN][]netip.Addr) error { + for _, a := range w.agents { + // TODO: technically, DNS labels cannot start with numbers, but the rules are often not + // strictly enforced. + // TODO: support <agent>.<workspace>.<username>.coder + fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.%s.me.coder.", a.name, w.name)) + if err != nil { + return err + } + names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.id)} + } + // TODO: Possibly support <workspace>.coder. alias if there is only one agent + return nil +} + type agent struct { id uuid.UUID name string @@ -852,23 +877,25 @@ type agent struct { func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter { updater := &tunnelUpdater{ - client: client, - errChan: make(chan error, 1), - logger: t.logger, - coordCtrl: t.coordCtrl, - recvLoopDone: make(chan struct{}), - workspaces: make(map[uuid.UUID]*workspace), + client: client, + errChan: make(chan error, 1), + logger: t.logger, + coordCtrl: t.coordCtrl, + dnsHostsSetter: t.dnsHostSetter, + recvLoopDone: make(chan struct{}), + workspaces: make(map[uuid.UUID]*workspace), } go updater.recvLoop() return updater } type tunnelUpdater struct { - errChan chan error - logger slog.Logger - client WorkspaceUpdatesClient - coordCtrl *TunnelSrcCoordController - recvLoopDone chan struct{} + errChan chan error + logger slog.Logger + client WorkspaceUpdatesClient + coordCtrl *TunnelSrcCoordController + dnsHostsSetter DNSHostsSetter + recvLoopDone chan struct{} // don't need the mutex since only 
manipulated by the recvLoop workspaces map[uuid.UUID]*workspace @@ -991,6 +1018,16 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error { } allAgents := t.allAgentIDs() t.coordCtrl.SyncDestinations(allAgents) + if t.dnsHostsSetter != nil { + t.logger.Debug(context.Background(), "updating dns hosts") + dnsNames := t.allDNSNames() + err := t.dnsHostsSetter.SetDNSHosts(dnsNames) + if err != nil { + return xerrors.Errorf("failed to set DNS hosts: %w", err) + } + } else { + t.logger.Debug(context.Background(), "skipping setting DNS names because we have no setter") + } return nil } @@ -1035,10 +1072,30 @@ func (t *tunnelUpdater) allAgentIDs() []uuid.UUID { return out } +func (t *tunnelUpdater) allDNSNames() map[dnsname.FQDN][]netip.Addr { + names := make(map[dnsname.FQDN][]netip.Addr) + for _, w := range t.workspaces { + err := w.addAllDNSNames(names) + if err != nil { + // This should never happen in production, because converting the FQDN only fails + // if names are too long, and we put strict length limits on agent, workspace, and user + // names. + t.logger.Critical(context.Background(), + "failed to include DNS name(s)", + slog.F("workspace_id", w.id), + slog.Error(err)) + } + } + return names +} + +// NewTunnelAllWorkspaceUpdatesController creates a WorkspaceUpdatesController that creates tunnels +// (via the TunnelSrcCoordController) to all agents received over the WorkspaceUpdates RPC. If a +// DNSHostsSetter is provided, it also programs DNS hosts based on the agent and workspace names. 
func NewTunnelAllWorkspaceUpdatesController( - logger slog.Logger, c *TunnelSrcCoordController, + logger slog.Logger, c *TunnelSrcCoordController, d DNSHostsSetter, ) WorkspaceUpdatesController { - return &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c} + return &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c, dnsHostSetter: d} } // NewController creates a new Controller without running it diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index 26b8286eb3d7e..1b11ebbe16419 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "net/netip" "slices" "sync" "sync/atomic" @@ -23,6 +24,7 @@ import ( "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/dnsname" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -1344,14 +1346,56 @@ func testUUID(b ...byte) uuid.UUID { return o } +type fakeDNSSetter struct { + ctx context.Context + t testing.TB + calls chan *setDNSCall +} + +type setDNSCall struct { + hosts map[dnsname.FQDN][]netip.Addr + err chan<- error +} + +func newFakeDNSSetter(ctx context.Context, t testing.TB) *fakeDNSSetter { + return &fakeDNSSetter{ + ctx: ctx, + t: t, + calls: make(chan *setDNSCall), + } +} + +func (f *fakeDNSSetter) SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error { + f.t.Helper() + errs := make(chan error) + call := &setDNSCall{ + hosts: hosts, + err: errs, + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting to send SetDNSHosts() call") + return f.ctx.Err() + case f.calls <- call: + // OK + } + select { + case <-f.ctx.Done(): + f.t.Error("timed out waiting for SetDNSHosts() call response") + return f.ctx.Err() + case err := <-errs: + return err + } +} + func setupConnectedAllWorkspaceUpdatesController( - ctx context.Context, t testing.TB, logger slog.Logger, + ctx context.Context, t testing.TB, logger slog.Logger, dnsSetter tailnet.DNSHostsSetter, ) ( 
*fakeCoordinatorClient, *fakeWorkspaceUpdateClient, ) { fConn := &fakeCoordinatee{} tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) - uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc) + uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, dnsSetter) // connect up a coordinator client, to track adding and removing tunnels coordC := newFakeCoordinatorClient(ctx, t) @@ -1385,7 +1429,8 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger) + fDNS := newFakeDNSSetter(ctx, t) + coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, fDNS) // Initial update contains 2 workspaces with 1 & 2 agents, respectively w1ID := testUUID(1) @@ -1418,6 +1463,16 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) { require.Contains(t, adds, w1a1ID) require.Contains(t, adds, w2a1ID) require.Contains(t, adds, w2a2ID) + + // Also triggers setting DNS hosts + expectedDNS := map[dnsname.FQDN][]netip.Addr{ + "w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + "w2a1.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0201::")}, + "w2a2.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0202::")}, + } + dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) + require.Equal(t, expectedDNS, dnsCall.hosts) + testutil.RequireSendCtx(ctx, t, dnsCall.err, nil) } func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) { @@ -1425,7 +1480,8 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger) + fDNS := newFakeDNSSetter(ctx, t) + coordC, updateC := 
setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, fDNS) w1ID := testUUID(1) w1a1ID := testUUID(1, 1) @@ -1447,6 +1503,14 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) { require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId()) testutil.RequireSendCtx(ctx, t, coordCall.err, nil) + // DNS for w1a1 + expectedDNS := map[dnsname.FQDN][]netip.Addr{ + "w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + } + dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) + require.Equal(t, expectedDNS, dnsCall.hosts) + testutil.RequireSendCtx(ctx, t, dnsCall.err, nil) + // Send update that removes w1a1 and adds w1a2 agentUpdate := &proto.WorkspaceUpdate{ UpsertedAgents: []*proto.Agent{ @@ -1468,6 +1532,60 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) { coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs) require.Equal(t, w1a1ID[:], coordCall.req.GetRemoveTunnel().GetId()) testutil.RequireSendCtx(ctx, t, coordCall.err, nil) + + // DNS contains only w1a2 + expectedDNS = map[dnsname.FQDN][]netip.Addr{ + "w1a2.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0102::")}, + } + dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls) + require.Equal(t, expectedDNS, dnsCall.hosts) + testutil.RequireSendCtx(ctx, t, dnsCall.err, nil) +} + +func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + dnsError := xerrors.New("a bad thing happened") + logger := slogtest.Make(t, + &slogtest.Options{IgnoredErrorIs: []error{dnsError}}). 
+ Leveled(slog.LevelDebug) + + fDNS := newFakeDNSSetter(ctx, t) + fConn := &fakeCoordinatee{} + tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) + uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, fDNS) + + updateC := newFakeWorkspaceUpdateClient(ctx, t) + updateCW := uut.New(updateC) + + w1ID := testUUID(1) + w1a1ID := testUUID(1, 1) + initUp := &proto.WorkspaceUpdate{ + UpsertedWorkspaces: []*proto.Workspace{ + {Id: w1ID[:], Name: "w1"}, + }, + UpsertedAgents: []*proto.Agent{ + {Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]}, + }, + } + upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv) + testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp) + + // DNS for w1a1 + expectedDNS := map[dnsname.FQDN][]netip.Addr{ + "w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + } + dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) + require.Equal(t, expectedDNS, dnsCall.hosts) + testutil.RequireSendCtx(ctx, t, dnsCall.err, dnsError) + + // should trigger a close on the client + closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close) + testutil.RequireSendCtx(ctx, t, closeCall, io.EOF) + + // error should be our initial DNS error + err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait()) + require.ErrorIs(t, err, dnsError) } func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) { @@ -1562,7 +1680,7 @@ func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) { fConn := &fakeCoordinatee{} tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) - uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc) + uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, nil) updateC := newFakeWorkspaceUpdateClient(ctx, t) updateCW := uut.New(updateC)