Skip to content

feat(tailnet): add alias with username and short alias to DNS #15585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions tailnet/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ func (r *basicResumeTokenRefresher) refresh() {
type tunnelAllWorkspaceUpdatesController struct {
coordCtrl *TunnelSrcCoordController
dnsHostSetter DNSHostsSetter
ownerUsername string
logger slog.Logger
}

Expand All @@ -868,18 +869,30 @@ type workspace struct {
}

// addAllDNSNames adds names for all of its agents to the given map of names
func (w workspace) addAllDNSNames(names map[dnsname.FQDN][]netip.Addr) error {
func (w workspace) addAllDNSNames(names map[dnsname.FQDN][]netip.Addr, owner string) error {
for _, a := range w.agents {
// TODO: technically, DNS labels cannot start with numbers, but the rules are often not
// strictly enforced.
// TODO: support <agent>.<workspace>.<username>.coder
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.%s.me.coder.", a.name, w.name))
if err != nil {
return err
}
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.id)}
fqdn, err = dnsname.ToFQDN(fmt.Sprintf("%s.%s.%s.coder.", a.name, w.name, owner))
if err != nil {
return err
}
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.id)}
}
if len(w.agents) == 1 {
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.coder.", w.name))
if err != nil {
return err
}
for _, a := range w.agents {
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.id)}
}
}
// TODO: Possibly support <workspace>.coder. alias if there is only one agent
return nil
}

Expand All @@ -895,6 +908,7 @@ func (t *tunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient)
logger: t.logger,
coordCtrl: t.coordCtrl,
dnsHostsSetter: t.dnsHostSetter,
ownerUsername: t.ownerUsername,
recvLoopDone: make(chan struct{}),
workspaces: make(map[uuid.UUID]*workspace),
}
Expand All @@ -908,6 +922,7 @@ type tunnelUpdater struct {
client WorkspaceUpdatesClient
coordCtrl *TunnelSrcCoordController
dnsHostsSetter DNSHostsSetter
ownerUsername string
recvLoopDone chan struct{}

// don't need the mutex since only manipulated by the recvLoop
Expand Down Expand Up @@ -1088,7 +1103,7 @@ func (t *tunnelUpdater) allAgentIDs() []uuid.UUID {
func (t *tunnelUpdater) allDNSNames() map[dnsname.FQDN][]netip.Addr {
names := make(map[dnsname.FQDN][]netip.Addr)
for _, w := range t.workspaces {
err := w.addAllDNSNames(names)
err := w.addAllDNSNames(names, t.ownerUsername)
if err != nil {
// This should never happen in production, because converting the FQDN only fails
// if names are too long, and we put strict length limits on agent, workspace, and user
Expand All @@ -1102,13 +1117,28 @@ func (t *tunnelUpdater) allDNSNames() map[dnsname.FQDN][]netip.Addr {
return names
}

type TunnelAllOption func(t *tunnelAllWorkspaceUpdatesController)

// WithDNS configures the tunnelAllWorkspaceUpdatesController to set DNS names for all workspaces
// and agents it learns about.
func WithDNS(d DNSHostsSetter, ownerUsername string) TunnelAllOption {
return func(t *tunnelAllWorkspaceUpdatesController) {
t.dnsHostSetter = d
t.ownerUsername = ownerUsername
}
}

// NewTunnelAllWorkspaceUpdatesController creates a WorkspaceUpdatesController that creates tunnels
// (via the TunnelSrcCoordController) to all agents received over the WorkspaceUpdates RPC. If a
// DNSHostSetter is provided, it also programs DNS hosts based on the agent and workspace names.
func NewTunnelAllWorkspaceUpdatesController(
logger slog.Logger, c *TunnelSrcCoordController, d DNSHostsSetter,
logger slog.Logger, c *TunnelSrcCoordController, opts ...TunnelAllOption,
) WorkspaceUpdatesController {
return &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c, dnsHostSetter: d}
t := &tunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c}
for _, opt := range opts {
opt(t)
}
return t
}

// NewController creates a new Controller without running it
Expand Down
89 changes: 47 additions & 42 deletions tailnet/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,13 +974,13 @@ func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.R
}
select {
case <-f.ctx.Done():
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case f.calls <- call:
// OK
}
select {
case <-f.ctx.Done():
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case err := <-call.errCh:
return nil, err
case resp := <-call.resp:
Expand Down Expand Up @@ -1240,6 +1240,11 @@ func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (t
}, nil
}

// timeoutOnFakeErr is the error we send when fakes fail to send calls or receive responses before
// their context times out. We don't want to send the context error since that often doesn't trigger
// test failures or logging.
var timeoutOnFakeErr = xerrors.New("test timeout")

type fakeCoordinatorClient struct {
ctx context.Context
t testing.TB
Expand All @@ -1253,15 +1258,13 @@ func (f fakeCoordinatorClient) Close() error {
errs := make(chan error)
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send close call")
return f.ctx.Err()
return timeoutOnFakeErr
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: changes like this are because calling t.Error() after the test case finishes (and context expires in t.Cleanup) will panic and create a bunch of extra tests to fail. So, we don't fail the test in these handlers, and instead pass a distinctive error.

Copy link
Member

@ethanndickson ethanndickson Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I assume it's a similar fix for coder/internal#217 too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, plus there was a real product bug in that one.

case f.close <- errs:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for close call response")
return f.ctx.Err()
return timeoutOnFakeErr
case err := <-errs:
return err
}
Expand All @@ -1276,15 +1279,13 @@ func (f fakeCoordinatorClient) Send(request *proto.CoordinateRequest) error {
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send call")
return f.ctx.Err()
return timeoutOnFakeErr
case f.reqs <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for send call response")
return f.ctx.Err()
return timeoutOnFakeErr
case err := <-errs:
return err
}
Expand All @@ -1300,15 +1301,13 @@ func (f fakeCoordinatorClient) Recv() (*proto.CoordinateResponse, error) {
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send Recv() call")
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case f.resps <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for Recv() call response")
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case err := <-errs:
return nil, err
case resp := <-resps:
Expand Down Expand Up @@ -1348,15 +1347,13 @@ func (f *fakeWorkspaceUpdateClient) Close() error {
errs := make(chan error)
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send close call")
return f.ctx.Err()
return timeoutOnFakeErr
case f.close <- errs:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for close call response")
return f.ctx.Err()
return timeoutOnFakeErr
case err := <-errs:
return err
}
Expand All @@ -1372,15 +1369,13 @@ func (f *fakeWorkspaceUpdateClient) Recv() (*proto.WorkspaceUpdate, error) {
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send Recv() call")
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case f.recv <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for Recv() call response")
return nil, f.ctx.Err()
return nil, timeoutOnFakeErr
case err := <-errs:
return nil, err
case resp := <-resps:
Expand Down Expand Up @@ -1440,28 +1435,26 @@ func (f *fakeDNSSetter) SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error {
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send SetDNSHosts() call")
return f.ctx.Err()
return timeoutOnFakeErr
case f.calls <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for SetDNSHosts() call response")
return f.ctx.Err()
return timeoutOnFakeErr
case err := <-errs:
return err
}
}

func setupConnectedAllWorkspaceUpdatesController(
ctx context.Context, t testing.TB, logger slog.Logger, dnsSetter tailnet.DNSHostsSetter,
ctx context.Context, t testing.TB, logger slog.Logger, opts ...tailnet.TunnelAllOption,
) (
*fakeCoordinatorClient, *fakeWorkspaceUpdateClient,
) {
fConn := &fakeCoordinatee{}
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, dnsSetter)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, opts...)

// connect up a coordinator client, to track adding and removing tunnels
coordC := newFakeCoordinatorClient(ctx, t)
Expand Down Expand Up @@ -1496,7 +1489,8 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
logger := testutil.Logger(t)

fDNS := newFakeDNSSetter(ctx, t)
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, fDNS)
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
tailnet.WithDNS(fDNS, "testy"))

// Initial update contains 2 workspaces with 1 & 2 agents, respectively
w1ID := testUUID(1)
Expand Down Expand Up @@ -1532,9 +1526,13 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {

// Also triggers setting DNS hosts
expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w2a1.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0201::")},
"w2a2.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0202::")},
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w2a1.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0201::")},
"w2a2.w2.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0202::")},
"w1a1.w1.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w2a1.w2.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0201::")},
"w2a2.w2.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0202::")},
"w1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
Expand All @@ -1547,7 +1545,8 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
logger := testutil.Logger(t)

fDNS := newFakeDNSSetter(ctx, t)
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, fDNS)
coordC, updateC := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
tailnet.WithDNS(fDNS, "testy"))

w1ID := testUUID(1)
w1a1ID := testUUID(1, 1)
Expand All @@ -1571,7 +1570,9 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {

// DNS for w1a1
expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1a1.w1.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
Expand Down Expand Up @@ -1601,7 +1602,9 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {

// DNS contains only w1a2
expectedDNS = map[dnsname.FQDN][]netip.Addr{
"w1a2.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0102::")},
"w1a2.w1.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0102::")},
"w1a2.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0102::")},
"w1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0102::")},
}
dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
Expand All @@ -1619,7 +1622,9 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
fDNS := newFakeDNSSetter(ctx, t)
fConn := &fakeCoordinatee{}
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, fDNS)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc,
tailnet.WithDNS(fDNS, "testy"),
)

updateC := newFakeWorkspaceUpdateClient(ctx, t)
updateCW := uut.New(updateC)
Expand All @@ -1639,7 +1644,9 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {

// DNS for w1a1
expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1a1.w1.me.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1a1.w1.testy.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
"w1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
Expand Down Expand Up @@ -1746,7 +1753,7 @@ func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) {

fConn := &fakeCoordinatee{}
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, nil)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc)
updateC := newFakeWorkspaceUpdateClient(ctx, t)
updateCW := uut.New(updateC)

Expand Down Expand Up @@ -1780,18 +1787,16 @@ func (f fakeWorkspaceUpdatesController) New(client tailnet.WorkspaceUpdatesClien
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send New call")
cw := newFakeCloserWaiter()
cw.errCh <- f.ctx.Err()
cw.errCh <- timeoutOnFakeErr
return cw
case f.calls <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to get New call response")
cw := newFakeCloserWaiter()
cw.errCh <- f.ctx.Err()
cw.errCh <- timeoutOnFakeErr
return cw
case resp := <-resps:
return resp
Expand Down
Loading