Review suggestions
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
dannykopping committed Apr 14, 2025
commit 6f60cbc1398d884a90836a8acbae368cebd6c4cc
107 changes: 38 additions & 69 deletions coderd/tailnet_test.go
@@ -27,6 +27,7 @@ import (
 	"github.com/coder/coder/v2/agent/proto"
 	"github.com/coder/coder/v2/coderd"
 	"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
+	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
 	"github.com/coder/coder/v2/codersdk/workspacesdk"
 	"github.com/coder/coder/v2/tailnet"
@@ -58,7 +59,8 @@ func TestServerTailnet_AgentConn_NoSTUN(t *testing.T) {
 	defer cancel()

 	// Connect through the ServerTailnet
-	agents, serverTailnet := setupServerTailnetAgent(t, 1, withDERPAndStunOptions(tailnettest.DisableSTUN, tailnettest.DERPIsEmbedded))
+	agents, serverTailnet := setupServerTailnetAgent(t, 1,
+		tailnettest.DisableSTUN, tailnettest.DERPIsEmbedded)
 	a := agents[0]

 	conn, release, err := serverTailnet.AgentConn(ctx, a.id)
@@ -341,7 +343,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
 	defer cancel()

-	agents, serverTailnet := setupServerTailnetAgent(t, 1, withDERPAndStunOptions(tailnettest.DisableSTUN))
+	agents, serverTailnet := setupServerTailnetAgent(t, 1, tailnettest.DisableSTUN)
 	a := agents[0]

 	require.True(t, serverTailnet.Conn().GetBlockEndpoints(), "expected BlockEndpoints to be set")
@@ -366,47 +368,42 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
 	})
 }

-type fakePing struct {
-	err error
-}
-
-func (f *fakePing) Ping(context.Context) (time.Duration, error) {
-	return time.Duration(0), f.err
-}
-
-func TestServerTailnet_Healthcheck(t *testing.T) {
+func TestDialFailure(t *testing.T) {
 	t.Parallel()

-	// Verifies that a non-nil healthcheck which returns a non-error response behaves as expected.
-	t.Run("Passing", func(t *testing.T) {
-		t.Parallel()
-
-		ctx := testutil.Context(t, testutil.WaitMedium)
-
-		agents, serverTailnet := setupServerTailnetAgent(t, 1, withHealthChecker(&fakePing{}))
+	// Setup.
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := testutil.Logger(t)

-		a := agents[0]
-		conn, release, err := serverTailnet.AgentConn(ctx, a.id)
-		t.Cleanup(release)
-		require.NoError(t, err)
-		assert.True(t, conn.AwaitReachable(ctx))
-	})
+	// Given: a tailnet coordinator.
+	coord := tailnet.NewCoordinator(logger)
+	t.Cleanup(func() {
+		_ = coord.Close()
+	})
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)

-	// If the healthcheck fails, we have no insight into this at this level.
-	// The dial against the control plane is retried, so we wait for the context to timeout as an indication that the
-	// healthcheck is performing as expected.
-	t.Run("Failing", func(t *testing.T) {
-		t.Parallel()
-
-		ctx := testutil.Context(t, testutil.WaitMedium)
-
-		agents, serverTailnet := setupServerTailnetAgent(t, 1, withHealthChecker(&fakePing{err: xerrors.New("oops")}))
-
-		a := agents[0]
-		_, release, err := serverTailnet.AgentConn(ctx, a.id)
-		require.Nil(t, release)
-		require.ErrorContains(t, err, "agent is unreachable")
-	})
+	// Given: a fake DB healthchecker which will always fail.
+	fch := &failingHealthcheck{}
+
+	// When: dialing the in-memory coordinator.
+	dialer := &coderd.InmemTailnetDialer{
+		CoordPtr:            &coordPtr,
+		Logger:              logger,
+		ClientID:            uuid.UUID{5},
+		DatabaseHealthCheck: fch,
+	}
+	_, err := dialer.Dial(ctx, nil)
+
+	// Then: the error returned reflects that the database failed its healthcheck.
+	require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable)
 }

+type failingHealthcheck struct{}
+
+func (failingHealthcheck) Ping(context.Context) (time.Duration, error) {
+	// Simulate a database connection error.
+	return 0, xerrors.New("oops")
+}
+
 type wrappedListener struct {
@@ -433,36 +430,9 @@ type agentWithID struct {
 	agent.Agent
 }

-type serverOption struct {
-	HealthCheck        coderd.Pinger
-	DERPAndStunOptions []tailnettest.DERPAndStunOption
-}
-
-func withHealthChecker(p coderd.Pinger) serverOption {
-	return serverOption{
-		HealthCheck: p,
-	}
-}
-
-func withDERPAndStunOptions(opts ...tailnettest.DERPAndStunOption) serverOption {
-	return serverOption{
-		DERPAndStunOptions: opts,
-	}
-}
-
-func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...serverOption) ([]agentWithID, *coderd.ServerTailnet) {
+func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...tailnettest.DERPAndStunOption) ([]agentWithID, *coderd.ServerTailnet) {
 	logger := testutil.Logger(t)
-
-	var healthChecker coderd.Pinger
-	var derpAndStunOptions []tailnettest.DERPAndStunOption
-	for _, opt := range opts {
-		derpAndStunOptions = append(derpAndStunOptions, opt.DERPAndStunOptions...)
-		if opt.HealthCheck != nil {
-			healthChecker = opt.HealthCheck
-		}
-	}
-
-	derpMap, derpServer := tailnettest.RunDERPAndSTUN(t, derpAndStunOptions...)
+	derpMap, derpServer := tailnettest.RunDERPAndSTUN(t, opts...)

 	coord := tailnet.NewCoordinator(logger)
 	t.Cleanup(func() {
@@ -502,11 +472,10 @@ func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...serverOption) ([]agentWithID, *coderd.ServerTailnet) {
 	}

 	dialer := &coderd.InmemTailnetDialer{
-		CoordPtr:            &coordPtr,
-		DERPFn:              func() *tailcfg.DERPMap { return derpMap },
-		Logger:              logger,
-		ClientID:            uuid.UUID{5},
-		DatabaseHealthCheck: healthChecker,
+		CoordPtr: &coordPtr,
+		DERPFn:   func() *tailcfg.DERPMap { return derpMap },
+		Logger:   logger,
+		ClientID: uuid.UUID{5},
 	}
 	serverTailnet, err := coderd.NewServerTailnet(
 		context.Background(),
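The new TestDialFailure above exercises a healthcheck gate inside the dialer: the dial fails fast when the control plane's database is unreachable, instead of letting the tailnet controller retry a dial that cannot succeed. A minimal self-contained sketch of that pattern follows; the Pinger interface mirrors the failingHealthcheck fake above, while the dialer type and the sentinel error here are illustrative stand-ins, not the real coderd.InmemTailnetDialer internals.

package example

import (
	"context"
	"time"

	"golang.org/x/xerrors"
)

// Pinger reports database reachability; it mirrors the Ping signature that
// failingHealthcheck (above) exposes.
type Pinger interface {
	Ping(ctx context.Context) (time.Duration, error)
}

// ErrDatabaseNotReachable is a stand-in for codersdk.ErrDatabaseNotReachable.
var ErrDatabaseNotReachable = xerrors.New("database not reachable")

// inmemDialer is a stand-in for coderd.InmemTailnetDialer; only the
// healthcheck gate exercised by TestDialFailure is sketched here.
type inmemDialer struct {
	DatabaseHealthCheck Pinger
}

func (d *inmemDialer) Dial(ctx context.Context) error {
	if d.DatabaseHealthCheck != nil {
		// Fail fast when the database is down rather than letting the
		// tailnet controller retry against a control plane that cannot serve it.
		if _, err := d.DatabaseHealthCheck.Ping(ctx); err != nil {
			return xerrors.Errorf("%w: %v", ErrDatabaseNotReachable, err)
		}
	}
	return nil // the real dialer would proceed to dial the coordinator here
}

Because the sentinel is wrapped with %w, it survives error unwrapping, which is what lets the test assert require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable).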
60 changes: 0 additions & 60 deletions coderd/workspaceagents_test.go
@@ -45,7 +45,6 @@ import (
 	"github.com/coder/coder/v2/coderd/database/dbfake"
 	"github.com/coder/coder/v2/coderd/database/dbgen"
 	"github.com/coder/coder/v2/coderd/database/dbmem"
-	"github.com/coder/coder/v2/coderd/database/dbtestutil"
 	"github.com/coder/coder/v2/coderd/database/dbtime"
 	"github.com/coder/coder/v2/coderd/database/pubsub"
 	"github.com/coder/coder/v2/coderd/externalauth"
@@ -538,46 +537,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
 	require.Equal(t, "test", strings.TrimSpace(string(output)))
 }

-// TestWorkspaceAgentDialFailure validates that the tailnet controller will retry connecting to the control plane until
-// its context times out, when the dialer fails its healthcheck.
-func TestWorkspaceAgentDialFailure(t *testing.T) {
-	t.Parallel()
-
-	store, ps := dbtestutil.NewDB(t)
-
-	// Given: a database which will fail its Ping(ctx) call.
-	// NOTE: The Ping(ctx) call is made by the Dialer.
-	pdb := &pingFailingDB{
-		Store: store,
-	}
-	client := coderdtest.New(t, &coderdtest.Options{
-		Database:                 pdb,
-		Pubsub:                   ps,
-		IncludeProvisionerDaemon: true,
-	})
-	user := coderdtest.CreateFirstUser(t, client)
-
-	// Given: a workspace agent is setup.
-	r := dbfake.WorkspaceBuild(t, pdb, database.WorkspaceTable{
-		OrganizationID: user.OrganizationID,
-		OwnerID:        user.UserID,
-	}).WithAgent().Do()
-	_ = agenttest.New(t, client.URL, r.AgentToken)
-	resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
-	require.Len(t, resources, 1)
-	require.Len(t, resources[0].Agents, 1)
-
-	// When: the db is marked as unhealthy (i.e. will fail its Ping).
-	// This needs to be done *after* the server "starts" otherwise it'll fail straight away when trying to initialize.
-	pdb.MarkUnhealthy()
-
-	// Then: the tailnet controller will continually try to dial the coordination endpoint, exceeding its context timeout.
-	ctx := testutil.Context(t, testutil.WaitMedium)
-	conn, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
-	require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable)
-	require.Nil(t, conn)
-}
-
 func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
 	t.Parallel()
 	client, db := coderdtest.NewWithDatabase(t, nil)
@@ -2632,22 +2591,3 @@ func TestAgentConnectionInfo(t *testing.T) {
 	require.True(t, info.DisableDirectConnections)
 	require.True(t, info.DERPForceWebSockets)
 }
-
-type pingFailingDB struct {
-	database.Store
-
-	unhealthy bool
-}
-
-func (p *pingFailingDB) Ping(context.Context) (time.Duration, error) {
-	if !p.unhealthy {
-		return time.Nanosecond, nil
-	}
-
-	// Simulate a database connection error.
-	return 0, xerrors.New("oops")
-}
-
-func (p *pingFailingDB) MarkUnhealthy() {
-	p.unhealthy = true
-}
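The end-to-end test removed above is superseded by the narrower unit tests added in this commit (TestDialFailure in coderd/tailnet_test.go and TestWorkspaceDialerFailure below). Callers of DialAgent still receive the same sentinel; the following usage sketch is hypothetical (the helper and its error handling are invented for illustration), though DialAgent and codersdk.ErrDatabaseNotReachable both appear in the removed test.

package example

import (
	"context"
	"errors"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
)

// dialAgent shows how a caller can branch on codersdk.ErrDatabaseNotReachable,
// the sentinel the removed test asserted on. The helper itself is hypothetical.
func dialAgent(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) error {
	conn, err := workspacesdk.New(client).DialAgent(ctx, agentID, nil)
	if errors.Is(err, codersdk.ErrDatabaseNotReachable) {
		// The control plane is up but its database is not; retrying would
		// only spin until the context deadline, so fail fast and distinctly.
		return err
	}
	if err != nil {
		return err
	}
	defer conn.Close()
	return nil
}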
35 changes: 35 additions & 0 deletions codersdk/workspacesdk/workspacesdk_test.go
@@ -1,13 +1,21 @@
 package workspacesdk_test

 import (
+	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"testing"

 	"github.com/stretchr/testify/require"
 	"tailscale.com/tailcfg"

+	"github.com/coder/websocket"
+
+	"github.com/coder/coder/v2/coderd/httpapi"
+	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
 	"github.com/coder/coder/v2/codersdk/workspacesdk"
+	"github.com/coder/coder/v2/testutil"
 )

 func TestWorkspaceRewriteDERPMap(t *testing.T) {
@@ -37,3 +45,30 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) {
 	require.Equal(t, "coconuts.org", node.HostName)
 	require.Equal(t, 44558, node.DERPPort)
 }
+
+func TestWorkspaceDialerFailure(t *testing.T) {
+	t.Parallel()
+
+	// Setup.
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := testutil.Logger(t)
+
+	// Given: a mock HTTP server which mimics an unreachable database when calling the coordination endpoint.
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
+			Message: codersdk.DatabaseNotReachable,
+			Detail:  "oops",
+		})
+	}))
+	t.Cleanup(srv.Close)
+
+	u, err := url.Parse(srv.URL)
+	require.NoError(t, err)
+
+	// When: calling the coordination endpoint.
+	dialer := workspacesdk.NewWebsocketDialer(logger, u, &websocket.DialOptions{})
+	_, err = dialer.Dial(ctx, nil)
+
+	// Then: an error indicating a database issue is returned, so the caller can branch on it.
+	require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable)
+}
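TestWorkspaceDialerFailure relies on the websocket dialer translating the coordination endpoint's 500 response into the typed sentinel. One plausible shape of that translation is sketched below; codersdk.Error, codersdk.DatabaseNotReachable, and codersdk.ErrDatabaseNotReachable all appear in this commit, but the helper itself is an illustrative guess, not the actual NewWebsocketDialer internals.

package example

import (
	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/codersdk"
)

// asDatabaseUnreachable narrows a failed dial of the coordination endpoint
// into the typed sentinel that TestWorkspaceDialerFailure asserts on.
func asDatabaseUnreachable(err error) error {
	var sdkErr *codersdk.Error
	if xerrors.As(err, &sdkErr) && sdkErr.Message == codersdk.DatabaseNotReachable {
		// Wrap with %w so callers can branch with errors.Is / require.ErrorIs.
		return xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err)
	}
	return err
}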
10 changes: 5 additions & 5 deletions provisionerd/provisionerd.go
@@ -301,12 +301,12 @@ func (p *Server) acquireLoop() {
 			return
 		}
 		err := p.acquireAndRunOne(client)
-		if err != nil { // Only log if context is not done.
+		if err != nil && ctx.Err() == nil { // Only log if context is not done.
 			// Short-circuit: don't wait for the retry delay to exit, if required.
 			if p.acquireExit() {
 				return
 			}
-			p.opts.Logger.Warn(ctx, "failed to acquire job, retrying...", slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds())), slog.Error(err))
+			p.opts.Logger.Warn(ctx, "failed to acquire job, retrying", slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds())), slog.Error(err))
 		} else {
 			// Reset the retrier after each successful acquisition.
 			retrier.Reset()
@@ -346,7 +346,7 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) error {
 	}
 	if job.JobId == "" {
 		p.opts.Logger.Debug(ctx, "acquire job successfully canceled")
-		return xerrors.New("canceled")
+		return nil
 	}

 	if len(job.TraceMetadata) > 0 {
@@ -401,9 +401,9 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) error {
 			Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error),
 		})
 		if err != nil {
-			p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))
+			p.opts.Logger.Error(ctx, "failed to report provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))
 		}
-		return xerrors.Errorf("provisioner job failed: %w", err)
+		return xerrors.Errorf("failed to report provisioner job failed: %w", err)
 	}

 	p.mutex.Lock()
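The provisionerd changes above tighten two behaviors: a successfully canceled acquire now returns nil instead of a synthetic "canceled" error, and transient acquire failures are only logged while the daemon's context is still live. The shutdown-aware logging guard, as a generic self-contained sketch (the names here are placeholders, not the provisionerd API):

package example

import (
	"context"
	"log"
)

// acquireLoop sketches the guard adopted above: once ctx is canceled,
// in-flight calls fail with context errors that are expected shutdown
// fallout, so they are not logged as warnings.
func acquireLoop(ctx context.Context, acquireAndRunOne func(context.Context) error) {
	for {
		if ctx.Err() != nil {
			return // shutting down; stop acquiring jobs
		}
		if err := acquireAndRunOne(ctx); err != nil && ctx.Err() == nil {
			// Only log while our own context is still live; a canceled
			// context makes the in-flight acquire fail with expected noise.
			log.Printf("failed to acquire job, retrying: %v", err)
		}
	}
}

The second ctx.Err() check closes the race where the context is canceled while acquireAndRunOne is in flight: the call returns a context error, but that is shutdown fallout, not a failure worth warning about.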