From 70117c5fbe94d9c2e91880de268109d772f65c17 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 22 Sep 2022 17:28:04 +0000 Subject: [PATCH 1/2] fix: Don't use StatusAbnormalClosure This is reserved for WASM use, and might be the cause of some weird leaks. --- coderd/workspaceagents.go | 6 +++--- codersdk/workspaceagents.go | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 942f0b184f0a0..6167790fb8bb7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -347,7 +347,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request err = updateConnectionTimes() if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + _ = conn.Close(websocket.StatusGoingAway, err.Error()) return } @@ -380,7 +380,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request } err = updateConnectionTimes() if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + _ = conn.Close(websocket.StatusGoingAway, err.Error()) return } err := ensureLatestBuild() @@ -571,7 +571,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques }) return } - defer conn.Close(websocket.StatusAbnormalClosure, "") + defer conn.Close(websocket.StatusGoingAway, "") var lastReport codersdk.AgentStatsReportResponse latestStat, err := api.Database.GetLatestAgentStat(ctx, workspaceAgent.ID) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 95832fc625e11..46d8ead8d2d6d 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -308,15 +308,15 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg logger.Debug(ctx, "serving coordinator") err = <-errChan if errors.Is(err, context.Canceled) { - _ = ws.Close(websocket.StatusAbnormalClosure, "") + _ = ws.Close(websocket.StatusGoingAway, "") return } if err != nil { logger.Debug(ctx, "error serving coordinator", slog.Error(err)) - _ = ws.Close(websocket.StatusAbnormalClosure, "") + _ = ws.Close(websocket.StatusGoingAway, "") continue } - _ = ws.Close(websocket.StatusAbnormalClosure, "") + _ = ws.Close(websocket.StatusGoingAway, "") } }() err = <-first @@ -446,7 +446,7 @@ func (c *Client) AgentReportStats( var req AgentStatsReportRequest err := wsjson.Read(ctx, conn, &req) if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, "") + _ = conn.Close(websocket.StatusGoingAway, "") return err } @@ -460,7 +460,7 @@ func (c *Client) AgentReportStats( err = wsjson.Write(ctx, conn, resp) if err != nil { - _ = conn.Close(websocket.StatusAbnormalClosure, "") + _ = conn.Close(websocket.StatusGoingAway, "") return err } } From 55dd39ddafeb6fd500e9f2127b16cf0bbbeb72bc Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 22 Sep 2022 18:00:28 +0000 Subject: [PATCH 2/2] Add close to provisioner logs --- agent/agent_test.go | 2 +- cli/cliui/provisionerjob.go | 7 ++++--- cli/cliui/provisionerjob_test.go | 13 +++++++++++-- cli/create.go | 3 ++- cli/templatecreate.go | 3 ++- cli/update.go | 3 ++- cmd/cliui/main.go | 5 +++-- coderd/provisionerjobs_internal_test.go | 3 ++- coderd/provisionerjobs_test.go | 6 ++++-- coderd/templateversions_test.go | 6 ++++-- coderd/workspaceagents_test.go | 2 +- coderd/workspacebuilds_test.go | 3 ++- codersdk/provisionerdaemons.go | 17 ++++++++++++----- codersdk/templateversions.go | 5 +++-- codersdk/workspacebuilds.go | 2 +- 15 files changed, 54 insertions(+), 26 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 08c7918765319..afed644f78e5e 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -490,7 +490,7 @@ func TestAgent(t *testing.T) { require.Eventually(t, func() bool { _, err := conn.Ping() return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) + }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("Speedtest", func(t *testing.T) { diff --git a/cli/cliui/provisionerjob.go b/cli/cliui/provisionerjob.go index 88375c9134772..c781be74f2cc1 100644 --- a/cli/cliui/provisionerjob.go +++ b/cli/cliui/provisionerjob.go @@ -22,7 +22,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie build, err := client.WorkspaceBuild(ctx, build) return build.Job, err }, - Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { + Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { return client.WorkspaceBuildLogsAfter(ctx, build, before) }, }) @@ -31,7 +31,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie type ProvisionerJobOptions struct { Fetch func() (codersdk.ProvisionerJob, error) Cancel func() error - Logs func() (<-chan codersdk.ProvisionerJobLog, error) + Logs func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) FetchInterval time.Duration // Verbose determines whether debug and trace logs will be shown. @@ -132,10 +132,11 @@ func ProvisionerJob(ctx context.Context, writer io.Writer, opts ProvisionerJobOp // The initial stage needs to print after the signal handler has been registered. printStage() - logs, err := opts.Logs() + logs, closer, err := opts.Logs() if err != nil { return xerrors.Errorf("logs: %w", err) } + defer closer.Close() var ( // logOutput is where log output is written diff --git a/cli/cliui/provisionerjob_test.go b/cli/cliui/provisionerjob_test.go index 35b203d249e9c..122ff513dd79e 100644 --- a/cli/cliui/provisionerjob_test.go +++ b/cli/cliui/provisionerjob_test.go @@ -2,6 +2,7 @@ package cliui_test import ( "context" + "io" "os" "runtime" "sync" @@ -136,8 +137,10 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { Cancel: func() error { return nil }, - Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { - return logs, nil + Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { + return logs, closeFunc(func() error { + return nil + }), nil }, }) }, @@ -164,3 +167,9 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { PTY: ptty, } } + +type closeFunc func() error + +func (c closeFunc) Close() error { + return c() +} diff --git a/cli/create.go b/cli/create.go index 55281741fc6bc..e43aa898f358e 100644 --- a/cli/create.go +++ b/cli/create.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "time" "github.com/spf13/cobra" @@ -253,7 +254,7 @@ PromptParamLoop: Cancel: func() error { return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) }, - Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { + Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, after) }, // Don't show log output for the dry-run unless there's an error. diff --git a/cli/templatecreate.go b/cli/templatecreate.go index 9c959f88f7329..ef776754ad409 100644 --- a/cli/templatecreate.go +++ b/cli/templatecreate.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -182,7 +183,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers Cancel: func() error { return client.CancelTemplateVersion(cmd.Context(), version.ID) }, - Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { + Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, before) }, }) diff --git a/cli/update.go b/cli/update.go index f80aaadfca65c..1ccfaad33a0b3 100644 --- a/cli/update.go +++ b/cli/update.go @@ -66,10 +66,11 @@ func update() *cobra.Command { if err != nil { return err } - logs, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before) + logs, closer, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before) if err != nil { return err } + defer closer.Close() for { log, ok := <-logs if !ok { diff --git a/cmd/cliui/main.go b/cmd/cliui/main.go index 45cdb6cb7fefe..7fef956fa7693 100644 --- a/cmd/cliui/main.go +++ b/cmd/cliui/main.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "os" "strings" "time" @@ -100,7 +101,7 @@ func main() { Fetch: func() (codersdk.ProvisionerJob, error) { return job, nil }, - Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { + Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { logs := make(chan codersdk.ProvisionerJobLog) go func() { defer close(logs) @@ -143,7 +144,7 @@ func main() { } } }() - return logs, nil + return logs, io.NopCloser(strings.NewReader("")), nil }, Cancel: func() error { job.Status = codersdk.ProvisionerJobCanceling diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 34bfce841e4e9..a5a760757baf3 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -108,8 +108,9 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { require.NoError(t, err) } - logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now()) + logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now()) require.NoError(t, err) + defer closer.Close() // when the endpoint calls subscribe, we get the listener here. fPubsub.cond.L.Lock() diff --git a/coderd/provisionerjobs_test.go b/coderd/provisionerjobs_test.go index a34ade1e3874b..f23c3672e3a13 100644 --- a/coderd/provisionerjobs_test.go +++ b/coderd/provisionerjobs_test.go @@ -44,8 +44,9 @@ func TestProvisionerJobLogs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) + logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) require.NoError(t, err) + defer closer.Close() for { log, ok := <-logs t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output) @@ -82,8 +83,9 @@ func TestProvisionerJobLogs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) + logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) require.NoError(t, err) + defer closer.Close() for { _, ok := <-logs if !ok { diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index 461460895fed6..b78dd79ae9ba1 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -447,8 +447,9 @@ func TestTemplateVersionLogs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - logs, err := client.TemplateVersionLogsAfter(ctx, version.ID, before) + logs, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, before) require.NoError(t, err) + defer closer.Close() for { _, ok := <-logs if !ok { @@ -618,8 +619,9 @@ func TestTemplateVersionDryRun(t *testing.T) { require.Equal(t, job.ID, newJob.ID) // Stream logs - logs, err := client.TemplateVersionDryRunLogsAfter(ctx, version.ID, job.ID, after) + logs, closer, err := client.TemplateVersionDryRunLogsAfter(ctx, version.ID, job.ID, after) require.NoError(t, err) + defer closer.Close() logsDone := make(chan struct{}) go func() { diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index c4514c1134427..d81e6890d54a1 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -128,7 +128,7 @@ func TestWorkspaceAgentListen(t *testing.T) { require.Eventually(t, func() bool { _, err := conn.Ping() return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) + }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("FailNonLatestBuild", func(t *testing.T) { diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 28c707fcbac60..e710962a9ca96 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -442,8 +442,9 @@ func TestWorkspaceBuildLogs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before.Add(-time.Hour)) + logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before.Add(-time.Hour)) require.NoError(t, err) + defer closer.Close() for { log, ok := <-logs if !ok { diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 296df9b5ac70d..f8307b1adee1d 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "io" "net/http" "net/http/cookiejar" "net/url" @@ -104,18 +105,18 @@ func (c *Client) provisionerJobLogsBefore(ctx context.Context, path string, befo } // provisionerJobLogsAfter streams logs that occurred after a specific time. -func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after time.Time) (<-chan ProvisionerJobLog, error) { +func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) { afterQuery := "" if !after.IsZero() { afterQuery = fmt.Sprintf("&after=%d", after.UTC().UnixMilli()) } followURL, err := c.URL.Parse(fmt.Sprintf("%s?follow%s", path, afterQuery)) if err != nil { - return nil, err + return nil, nil, err } jar, err := cookiejar.New(nil) if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) + return nil, nil, xerrors.Errorf("create cookie jar: %w", err) } jar.SetCookies(followURL, []*http.Cookie{{ Name: SessionTokenKey, @@ -129,11 +130,13 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after CompressionMode: websocket.CompressionDisabled, }) if err != nil { - return nil, readBodyAsError(res) + return nil, nil, readBodyAsError(res) } logs := make(chan ProvisionerJobLog) decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText)) + closed := make(chan struct{}) go func() { + defer close(closed) defer close(logs) defer conn.Close(websocket.StatusGoingAway, "") var log ProvisionerJobLog @@ -149,5 +152,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } } }() - return logs, nil + return logs, closeFunc(func() error { + _ = conn.Close(websocket.StatusNormalClosure, "") + <-closed + return nil + }), nil } diff --git a/codersdk/templateversions.go b/codersdk/templateversions.go index 3562906e0c1ee..5baad0a9e9deb 100644 --- a/codersdk/templateversions.go +++ b/codersdk/templateversions.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "time" @@ -99,7 +100,7 @@ func (c *Client) TemplateVersionLogsBefore(ctx context.Context, version uuid.UUI } // TemplateVersionLogsAfter streams logs for a template version that occurred after a specific time. -func (c *Client) TemplateVersionLogsAfter(ctx context.Context, version uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { +func (c *Client) TemplateVersionLogsAfter(ctx context.Context, version uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) { return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/logs", version), after) } @@ -166,7 +167,7 @@ func (c *Client) TemplateVersionDryRunLogsBefore(ctx context.Context, version, j // TemplateVersionDryRunLogsAfter streams logs for a template version dry-run // that occurred after a specific time. -func (c *Client) TemplateVersionDryRunLogsAfter(ctx context.Context, version, job uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { +func (c *Client) TemplateVersionDryRunLogsAfter(ctx context.Context, version, job uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) { return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/dry-run/%s/logs", version, job), after) } diff --git a/codersdk/workspacebuilds.go b/codersdk/workspacebuilds.go index 1dc75e9239849..cc8cdbf082f74 100644 --- a/codersdk/workspacebuilds.go +++ b/codersdk/workspacebuilds.go @@ -102,7 +102,7 @@ func (c *Client) WorkspaceBuildLogsBefore(ctx context.Context, build uuid.UUID, } // WorkspaceBuildLogsAfter streams logs for a workspace build that occurred after a specific time. -func (c *Client) WorkspaceBuildLogsAfter(ctx context.Context, build uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { +func (c *Client) WorkspaceBuildLogsAfter(ctx context.Context, build uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) { return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/workspacebuilds/%s/logs", build), after) }