From 7bff778d86f492bfbd9081de4d508dac64d26f0b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 2 Jun 2022 12:22:05 -0500 Subject: [PATCH 1/2] fix: ensure listen websocket isn't opened for non-latest agents Exponential backoff is only enabled if the websocket fails to open. If the websocket is opened but immediately killed, the agent will try to immediately reconnect. This is desireable in cases where coderd is being replaced or network conditions cause the connection to die, but not for permanent errors. --- coderd/workspaceagents.go | 68 +++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index cfcdea0404683..90396235993cb 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -143,16 +143,49 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { defer api.websocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) - conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ - CompressionMode: websocket.CompressionDisabled, - }) + resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("accept websocket: %s", err), + Message: fmt.Sprintf("get workspace resource: %s", err), }) return } - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + + build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("get workspace build job: %s", err), + }) + return + } + // Ensure the resource is still valid! + // We only accept agents for resources on the latest build. + ensureLatestBuild := func() error { + latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) + if err != nil { + return err + } + if build.ID != latestBuild.ID { + return xerrors.New("build is outdated") + } + return nil + } + + err = ensureLatestBuild() + if err != nil { + api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built", + slog.F("resource", resource), + slog.F("agent", workspaceAgent), + ) + httpapi.Write(rw, http.StatusForbidden, httpapi.Response{ + Message: fmt.Sprintf("ensure latest build: %s", err), + }) + return + } + + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ Message: fmt.Sprintf("accept websocket: %s", err), @@ -163,6 +196,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() + config := yamux.DefaultConfig() config.LogOutput = io.Discard session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) @@ -170,6 +204,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } + closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{ ChannelID: workspaceAgent.ID.String(), Pubsub: api.Pubsub, @@ -180,6 +215,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { return } defer closer.Close() + firstConnectedAt := workspaceAgent.FirstConnectedAt if !firstConnectedAt.Valid { firstConnectedAt = sql.NullTime{ @@ -204,23 +240,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { } return nil } - build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) - if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) - return - } - // Ensure the resource is still valid! - // We only accept agents for resources on the latest build. - ensureLatestBuild := func() error { - latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } defer func() { disconnectedAt = sql.NullTime{ @@ -230,11 +249,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { _ = updateConnectionTimes() }() - err = ensureLatestBuild() - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, "") - return - } err = updateConnectionTimes() if err != nil { _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) From 7355d57a07a777f1fe7af19186eca3986b9b0751 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 2 Jun 2022 14:34:06 -0500 Subject: [PATCH 2/2] test case --- coderd/workspaceagents_test.go | 160 ++++++++++++++++++++++++--------- 1 file changed, 119 insertions(+), 41 deletions(-) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index b14ac43bac4ce..360abb3431156 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -68,52 +68,130 @@ func TestWorkspaceAgent(t *testing.T) { func TestWorkspaceAgentListen(t *testing.T) { t.Parallel() - client, coderAPI := coderdtest.NewWithAPI(t, nil) - user := coderdtest.CreateFirstUser(t, client) - daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) - authToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "example", - Type: "aws_instance", - Agents: []*proto.Agent{{ - Id: uuid.NewString(), - Auth: &proto.Agent_Token{ - Token: authToken, - }, + + t.Run("Connect", func(t *testing.T) { + t.Parallel() + + client, coderAPI := coderdtest.NewWithAPI(t, nil) + user := coderdtest.CreateFirstUser(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, }}, - }}, + }, }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - daemonCloser.Close() + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + daemonCloser.Close() - agentClient := codersdk.New(client.URL) - agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ - Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), - }) - t.Cleanup(func() { - _ = agentCloser.Close() + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ + Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug), + }) + t.Cleanup(func() { + _ = agentCloser.Close() + }) + resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + _, err = conn.Ping() + require.NoError(t, err) }) - resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = conn.Close() + + t.Run("FailNonLatestBuild", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client, coderAPI := coderdtest.NewWithAPI(t, nil) + user := coderdtest.CreateFirstUser(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI) + defer daemonCloser.Close() + + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + version = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: uuid.NewString(), + }, + }}, + }}, + }, + }, + }}, + }, template.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + + stopBuild, err := client.CreateWorkspaceBuild(context.Background(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: version.ID, + Transition: codersdk.WorkspaceTransitionStop, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJob(t, client, stopBuild.ID) + + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + + _, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil)) + require.Error(t, err) + require.ErrorContains(t, err, "build is outdated") }) - _, err = conn.Ping() - require.NoError(t, err) } func TestWorkspaceAgentTURN(t *testing.T) {