Skip to content

feat: agent uses Tailnet v2 API for DERPMap updates #11698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ type Options struct {
type Client interface {
Manifest(ctx context.Context) (agentsdk.Manifest, error)
Listen(ctx context.Context) (drpc.Conn, error)
DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
Expand Down Expand Up @@ -822,10 +821,22 @@ func (a *agent) run(ctx context.Context) error {
network.SetBlockEndpoints(manifest.DisableDirectConnections)
}

// Listen returns the dRPC connection we use for both Coordinator and DERPMap updates
conn, err := a.client.Listen(ctx)
if err != nil {
return err
}
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()

eg, egCtx := errgroup.WithContext(ctx)
eg.Go(func() error {
a.logger.Debug(egCtx, "running tailnet connection coordinator")
err := a.runCoordinator(egCtx, network)
err := a.runCoordinator(egCtx, conn, network)
if err != nil {
return xerrors.Errorf("run coordinator: %w", err)
}
Expand All @@ -834,7 +845,7 @@ func (a *agent) run(ctx context.Context) error {

eg.Go(func() error {
a.logger.Debug(egCtx, "running derp map subscriber")
err := a.runDERPMapSubscriber(egCtx, network)
err := a.runDERPMapSubscriber(egCtx, conn, network)
if err != nil {
return xerrors.Errorf("run derp map subscriber: %w", err)
}
Expand Down Expand Up @@ -1056,21 +1067,8 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t

// runCoordinator runs a coordinator over the given dRPC connection; the
// returned error indicates whether a reconnect should occur.
func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

conn, err := a.client.Listen(ctx)
if err != nil {
return err
}
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()

func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error {
defer a.logger.Debug(ctx, "disconnected from coordination RPC")
tClient := tailnetproto.NewDRPCTailnetClient(conn)
coordinate, err := tClient.Coordinate(ctx)
if err != nil {
Expand All @@ -1082,7 +1080,7 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(err))
}
}()
a.logger.Info(ctx, "connected to coordination endpoint")
a.logger.Info(ctx, "connected to coordination RPC")
coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
select {
case <-ctx.Done():
Expand All @@ -1093,30 +1091,29 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
}

// runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur.
func (a *agent) runDERPMapSubscriber(ctx context.Context, network *tailnet.Conn) error {
func (a *agent) runDERPMapSubscriber(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error {
defer a.logger.Debug(ctx, "disconnected from derp map RPC")
ctx, cancel := context.WithCancel(ctx)
defer cancel()

updates, closer, err := a.client.DERPMapUpdates(ctx)
tClient := tailnetproto.NewDRPCTailnetClient(conn)
stream, err := tClient.StreamDERPMaps(ctx, &tailnetproto.StreamDERPMapsRequest{})
if err != nil {
return err
return xerrors.Errorf("stream DERP Maps: %w", err)
}
defer closer.Close()

a.logger.Info(ctx, "connected to derp map endpoint")
defer func() {
cErr := stream.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing DERPMap stream", slog.Error(err))
}
}()
a.logger.Info(ctx, "connected to derp map RPC")
for {
select {
case <-ctx.Done():
return ctx.Err()
case update := <-updates:
if update.Err != nil {
return update.Err
}
if update.DERPMap != nil && !tailnet.CompareDERPMaps(network.DERPMap(), update.DERPMap) {
a.logger.Info(ctx, "updating derp map due to detected changes")
network.SetDERPMap(update.DERPMap)
}
dmp, err := stream.Recv()
if err != nil {
return xerrors.Errorf("recv DERPMap error: %w", err)
}
dm := tailnet.DERPMapFromProto(dmp)
network.SetDERPMap(dm)
}
}

Expand Down
16 changes: 12 additions & 4 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,7 @@ func TestAgent_Lifecycle(t *testing.T) {
make(chan *agentsdk.Stats, 50),
tailnet.NewCoordinator(logger),
)
defer client.Close()

fs := afero.NewMemMapFs()
agent := agent.New(agent.Options{
Expand Down Expand Up @@ -1683,13 +1684,18 @@ func TestAgent_UpdatedDERP(t *testing.T) {
statsCh,
coordinator,
)
t.Cleanup(func() {
t.Log("closing client")
client.Close()
})
uut := agent.New(agent.Options{
Client: client,
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: time.Minute,
})
t.Cleanup(func() {
t.Log("closing agent")
_ = uut.Close()
})

Expand Down Expand Up @@ -1718,6 +1724,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error())
}
t.Logf("closed coordination %s", name)
})
// Force DERP.
conn.SetBlockEndpoints(true)
Expand Down Expand Up @@ -1753,11 +1760,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
}

// Push a new DERP map to the agent.
err := client.PushDERPMapUpdate(agentsdk.DERPMapUpdate{
DERPMap: newDerpMap,
})
err := client.PushDERPMapUpdate(newDerpMap)
require.NoError(t, err)
t.Logf("client Pushed DERPMap update")
t.Logf("pushed DERPMap update to agent")

require.Eventually(t, func() bool {
conn := uut.TailnetConn()
Expand Down Expand Up @@ -1826,6 +1831,7 @@ func TestAgent_Reconnect(t *testing.T) {
statsCh,
coordinator,
)
defer client.Close()
initialized := atomic.Int32{}
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -1862,6 +1868,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
make(chan *agentsdk.Stats, 50),
coordinator,
)
defer client.Close()
filesystem := afero.NewMemMapFs()
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -2039,6 +2046,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
t.Cleanup(c.Close)

options := agent.Options{
Client: c,
Expand Down
36 changes: 15 additions & 21 deletions agent/agenttest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func NewClient(t testing.TB,
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coordinator)
mux := drpcmux.New()
derpMapUpdates := make(chan *tailcfg.DERPMap)
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
CoordPtr: &coordPtr,
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err)
Expand All @@ -64,7 +64,7 @@ func NewClient(t testing.TB,
statsChan: statsChan,
coordinator: coordinator,
server: server,
derpMapUpdates: make(chan agentsdk.DERPMapUpdate),
derpMapUpdates: derpMapUpdates,
}
}

Expand All @@ -85,23 +85,26 @@ type Client struct {
lifecycleStates []codersdk.WorkspaceAgentLifecycle
startup agentsdk.PostStartupRequest
logs []agentsdk.Log
derpMapUpdates chan agentsdk.DERPMapUpdate
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
}

// Close closes the derpMapUpdates channel. The close runs through
// derpMapOnce, so calling Close more than once does not close the channel
// twice.
func (c *Client) Close() {
	closeUpdates := func() {
		close(c.derpMapUpdates)
	}
	c.derpMapOnce.Do(closeUpdates)
}

// Manifest returns the manifest stored on the client. The context is ignored
// and the returned error is always nil.
func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
	manifest := c.manifest
	return manifest, nil
}

func (c *Client) Listen(_ context.Context) (drpc.Conn, error) {
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{})
c.LastWorkspaceAgent = func() {
_ = conn.Close()
_ = lis.Close()
<-closed
}
c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(context.Background())
serveCtx, cancel := context.WithCancel(ctx)
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Expand All @@ -112,7 +115,6 @@ func (c *Client) Listen(_ context.Context) (drpc.Conn, error) {
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
_ = c.server.Serve(serveCtx, lis)
close(closed)
}()
return conn, nil
}
Expand Down Expand Up @@ -235,7 +237,7 @@ func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerCo
return codersdk.ServiceBannerConfig{}, nil
}

func (c *Client) PushDERPMapUpdate(update agentsdk.DERPMapUpdate) error {
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
timer := time.NewTimer(testutil.WaitShort)
defer timer.Stop()
select {
Expand All @@ -247,14 +249,6 @@ func (c *Client) PushDERPMapUpdate(update agentsdk.DERPMapUpdate) error {
return nil
}

func (c *Client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) {
closed := make(chan struct{})
return c.derpMapUpdates, closeFunc(func() error {
close(closed)
return nil
}), nil
}

type closeFunc func() error

func (c closeFunc) Close() error {
Expand Down
1 change: 1 addition & 0 deletions coderd/tailnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
})

c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
t.Cleanup(c.Close)

options := agent.Options{
Client: c,
Expand Down
59 changes: 23 additions & 36 deletions coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
_ = coordinator.Close()
})
manifest.AgentID = uuid.New()
aC := &client{
t: t,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
derpMapUpdates: make(chan *tailcfg.DERPMap),
}
t.Cleanup(aC.close)
closer := agent.New(agent.Options{
Client: &client{
t: t,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
},
Client: aC,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
Expand Down Expand Up @@ -230,52 +233,37 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
}

type client struct {
t *testing.T
agentID uuid.UUID
manifest agentsdk.Manifest
coordinator tailnet.Coordinator
}

func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil
t *testing.T
agentID uuid.UUID
manifest agentsdk.Manifest
coordinator tailnet.Coordinator
closeOnce sync.Once
derpMapUpdates chan *tailcfg.DERPMap
}

type closer struct {
closeFunc func() error
// close closes the derpMapUpdates channel. The close runs through closeOnce,
// so repeated calls do not close the channel twice.
func (c *client) close() {
	closeUpdates := func() {
		close(c.derpMapUpdates)
	}
	c.closeOnce.Do(closeUpdates)
}

func (c *closer) Close() error {
return c.closeFunc()
}

func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) {
closed := make(chan struct{})
return make(<-chan agentsdk.DERPMapUpdate), &closer{
closeFunc: func() error {
close(closed)
return nil
},
}, nil
// Manifest returns the manifest stored on the client. The context is ignored
// and the returned error is always nil.
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
	manifest := c.manifest
	return manifest, nil
}

func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{})
c.t.Cleanup(func() {
_ = conn.Close()
_ = lis.Close()
<-closed
})
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&c.coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
CoordPtr: &coordPtr,
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-c.derpMapUpdates },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
if err != nil {
Expand All @@ -302,7 +290,6 @@ func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
server.Serve(serveCtx, lis)
close(closed)
}()
return conn, nil
}
Expand Down
Loading