diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d11748a51f8b8..be3166cc4d32a 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -20,6 +20,7 @@ import ( "regexp" "runtime" "strings" + "sync" "testing" "time" @@ -1318,9 +1319,6 @@ func TestSSH(t *testing.T) { tmpdir := tempDirUnixSocket(t) localSock := filepath.Join(tmpdir, "local.sock") - l, err := net.Listen("unix", localSock) - require.NoError(t, err) - defer l.Close() remoteSock := filepath.Join(tmpdir, "remote.sock") inv, root := clitest.New(t, @@ -1332,23 +1330,62 @@ func TestSSH(t *testing.T) { clitest.SetupConfig(t, client, root) pty := ptytest.New(t).Attach(inv) inv.Stderr = pty.Output() - cmdDone := tGo(t, func() { - err := inv.WithContext(ctx).Run() - assert.NoError(t, err, "ssh command failed") - }) - // Wait for the prompt or any output really to indicate the command has - // started and accepting input on stdin. + w := clitest.StartWithWaiter(t, inv.WithContext(ctx)) + defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly). + + // Since something was output, it should be safe to write input. + // This could show a prompt or "running startup scripts", so it's + // not indicative of the SSH connection being ready. _ = pty.Peek(ctx, 1) - // This needs to support most shells on Linux or macOS - // We can't include exactly what's expected in the input, as that will always be matched - pty.WriteLine(fmt.Sprintf(`echo "results: $(netstat -an | grep %s | wc -l | tr -d ' ')"`, remoteSock)) - pty.ExpectMatchContext(ctx, "results: 1") + // Ensure the SSH connection is ready by testing the shell + // input/output. + pty.WriteLine("echo ping' 'pong") + pty.ExpectMatchContext(ctx, "ping pong") + + // Start the listener on the "local machine". + l, err := net.Listen("unix", localSock) + require.NoError(t, err) + defer l.Close() + testutil.Go(t, func() { + var wg sync.WaitGroup + defer wg.Wait() + for { + fd, err := l.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + assert.NoError(t, err, "listener accept failed") + } + return + } + + wg.Add(1) + go func() { + defer wg.Done() + defer fd.Close() + agentssh.Bicopy(ctx, fd, fd) + }() + } + }) + + // Dial the forwarded socket on the "remote machine". + d := &net.Dialer{} + fd, err := d.DialContext(ctx, "unix", remoteSock) + require.NoError(t, err) + defer fd.Close() + + // Ping / pong to ensure the socket is working. + _, err = fd.Write([]byte("hello world")) + require.NoError(t, err) + + buf := make([]byte, 11) + _, err = fd.Read(buf) + require.NoError(t, err) + require.Equal(t, "hello world", string(buf)) // And we're done. pty.WriteLine("exit") - <-cmdDone }) // Test that we can forward a local unix socket to a remote unix socket and @@ -1377,6 +1414,8 @@ func TestSSH(t *testing.T) { require.NoError(t, err) defer l.Close() testutil.Go(t, func() { + var wg sync.WaitGroup + defer wg.Wait() for { fd, err := l.Accept() if err != nil { @@ -1386,10 +1425,12 @@ func TestSSH(t *testing.T) { return } - testutil.Go(t, func() { + wg.Add(1) + go func() { + defer wg.Done() defer fd.Close() agentssh.Bicopy(ctx, fd, fd) - }) + }() } }) @@ -1522,6 +1563,8 @@ func TestSSH(t *testing.T) { require.NoError(t, err) defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice. testutil.Go(t, func() { + var wg sync.WaitGroup + defer wg.Wait() for { fd, err := l.Accept() if err != nil { @@ -1531,10 +1574,12 @@ func TestSSH(t *testing.T) { return } - testutil.Go(t, func() { + wg.Add(1) + go func() { + defer wg.Done() defer fd.Close() agentssh.Bicopy(ctx, fd, fd) - }) + }() } }) diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index d7bc29aca3044..938fdf1774008 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -1995,6 +1995,37 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro sidebarAppID = uuid.NullUUID{UUID: id, Valid: true} } + // This is a hacky workaround for the issue with tasks 'disappearing' on stop: + // reuse has_ai_task and sidebar_app_id from the previous build. + // This workaround should be removed as soon as possible. + if workspaceBuild.Transition == database.WorkspaceTransitionStop && workspaceBuild.BuildNumber > 1 { + if prevBuild, err := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + WorkspaceID: workspaceBuild.WorkspaceID, + BuildNumber: workspaceBuild.BuildNumber - 1, + }); err == nil { + hasAITask = prevBuild.HasAITask.Bool + sidebarAppID = prevBuild.AITaskSidebarAppID + warnUnknownSidebarAppID = false + s.Logger.Debug(ctx, "task workaround: reused has_ai_task and sidebar_app_id from previous build to keep track of task", + slog.F("job_id", job.ID.String()), + slog.F("build_number", prevBuild.BuildNumber), + slog.F("workspace_id", workspace.ID), + slog.F("workspace_build_id", workspaceBuild.ID), + slog.F("transition", string(workspaceBuild.Transition)), + slog.F("sidebar_app_id", sidebarAppID.UUID), + slog.F("has_ai_task", hasAITask), + ) + } else { + s.Logger.Error(ctx, "task workaround: tracking via has_ai_task and sidebar_app from previous build failed", + slog.Error(err), + slog.F("job_id", job.ID.String()), + slog.F("workspace_id", workspace.ID), + slog.F("workspace_build_id", workspaceBuild.ID), + slog.F("transition", string(workspaceBuild.Transition)), + ) + } + } + if warnUnknownSidebarAppID { // Ref: https://github.com/coder/coder/issues/18776 // This can happen for a number of reasons: diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 8baa7c99c30b9..98af0bb86a73f 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -2842,9 +2842,12 @@ func TestCompleteJob(t *testing.T) { // has_ai_task has a default value of nil, but once the workspace build completes it will have a value; // it is set to "true" if the related template has any coder_ai_task resources defined, and its sidebar app ID // will be set as well in that case. + // HACK(johnstcn): we also set it to "true" if any _previous_ workspace builds ever had it set to "true". + // This is to avoid tasks "disappearing" when you stop them. t.Run("WorkspaceBuild", func(t *testing.T) { type testcase struct { name string + seedFunc func(context.Context, testing.TB, database.Store) error // If you need to insert other resources transition database.WorkspaceTransition input *proto.CompletedJob_WorkspaceBuild expectHasAiTask bool @@ -2944,6 +2947,17 @@ func TestCompleteJob(t *testing.T) { expectHasAiTask: true, expectUsageEvent: false, }, + { + name: "current build does not have ai task but previous build did", + seedFunc: seedPreviousWorkspaceStartWithAITask, + transition: database.WorkspaceTransitionStop, + input: &proto.CompletedJob_WorkspaceBuild{ + AiTasks: []*sdkproto.AITask{}, + Resources: []*sdkproto.Resource{}, + }, + expectHasAiTask: true, + expectUsageEvent: false, + }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -2980,6 +2994,9 @@ func TestCompleteJob(t *testing.T) { }) ctx := testutil.Context(t, testutil.WaitShort) + if tc.seedFunc != nil { + require.NoError(t, tc.seedFunc(ctx, t, db)) + } buildJobID := uuid.New() wsBuildID := uuid.New() @@ -2999,8 +3016,13 @@ func TestCompleteJob(t *testing.T) { Tags: pd.Tags, }) require.NoError(t, err) + var buildNum int32 + if latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceTable.ID); err == nil { + buildNum = latestBuild.BuildNumber + } build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ ID: wsBuildID, + BuildNumber: buildNum + 1, JobID: buildJobID, WorkspaceID: workspaceTable.ID, TemplateVersionID: version.ID, @@ -3038,7 +3060,7 @@ func TestCompleteJob(t *testing.T) { require.True(t, build.HasAITask.Valid) // We ALWAYS expect a value to be set, therefore not nil, i.e. valid = true. require.Equal(t, tc.expectHasAiTask, build.HasAITask.Bool) - if tc.expectHasAiTask { + if tc.expectHasAiTask && build.Transition != database.WorkspaceTransitionStop { require.Equal(t, sidebarAppID, build.AITaskSidebarAppID.UUID.String()) } @@ -4244,3 +4266,63 @@ func (f *fakeUsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ databa f.collectedEvents = append(f.collectedEvents, event) return nil } + +func seedPreviousWorkspaceStartWithAITask(ctx context.Context, t testing.TB, db database.Store) error { + t.Helper() + // If the below looks slightly convoluted, that's because it is. + // The workspace doesn't yet have a latest build, so querying all + // workspaces will fail. + tpls, err := db.GetTemplates(ctx) + if err != nil { + return xerrors.Errorf("seedFunc: get template: %w", err) + } + if len(tpls) != 1 { + return xerrors.Errorf("seedFunc: expected exactly one template, got %d", len(tpls)) + } + ws, err := db.GetWorkspacesByTemplateID(ctx, tpls[0].ID) + if err != nil { + return xerrors.Errorf("seedFunc: get workspaces: %w", err) + } + if len(ws) != 1 { + return xerrors.Errorf("seedFunc: expected exactly one workspace, got %d", len(ws)) + } + w := ws[0] + prevJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + OrganizationID: w.OrganizationID, + InitiatorID: w.OwnerID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + tvs, err := db.GetTemplateVersionsByTemplateID(ctx, database.GetTemplateVersionsByTemplateIDParams{ + TemplateID: tpls[0].ID, + }) + if err != nil { + return xerrors.Errorf("seedFunc: get template version: %w", err) + } + if len(tvs) != 1 { + return xerrors.Errorf("seedFunc: expected exactly one template version, got %d", len(tvs)) + } + if tpls[0].ActiveVersionID == uuid.Nil { + return xerrors.Errorf("seedFunc: active version id is nil") + } + res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + JobID: prevJob.ID, + }) + agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: res.ID, + }) + wa := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{ + AgentID: agt.ID, + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + BuildNumber: 1, + HasAITask: sql.NullBool{Valid: true, Bool: true}, + AITaskSidebarAppID: uuid.NullUUID{Valid: true, UUID: wa.ID}, + ID: w.ID, + InitiatorID: w.OwnerID, + JobID: prevJob.ID, + TemplateVersionID: tvs[0].ID, + Transition: database.WorkspaceTransitionStart, + WorkspaceID: w.ID, + }) + return nil +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6f28b12af5ae0..6a817966f4ff5 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -1610,63 +1610,77 @@ func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { ) for _, tc := range []struct { - name string - devcontainerID string - setupDevcontainers []codersdk.WorkspaceAgentDevcontainer - setupMock func(mccli *acmock.MockContainerCLI, mdccli *acmock.MockDevcontainerCLI) (status int) + name string + devcontainerID string + devcontainers []codersdk.WorkspaceAgentDevcontainer + containers []codersdk.WorkspaceAgentContainer + expectRecreate bool + expectErrorCode int }{ { - name: "Recreate", - devcontainerID: devcontainerID.String(), - setupDevcontainers: []codersdk.WorkspaceAgentDevcontainer{devcontainer}, - setupMock: func(mccli *acmock.MockContainerCLI, mdccli *acmock.MockDevcontainerCLI) int { - mccli.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ - Containers: []codersdk.WorkspaceAgentContainer{devContainer}, - }, nil).AnyTimes() - // DetectArchitecture always returns "" for this test to disable agent injection. - mccli.EXPECT().DetectArchitecture(gomock.Any(), devContainer.ID).Return("", nil).AnyTimes() - mdccli.EXPECT().ReadConfig(gomock.Any(), workspaceFolder, configFile, gomock.Any()).Return(agentcontainers.DevcontainerConfig{}, nil).AnyTimes() - mdccli.EXPECT().Up(gomock.Any(), workspaceFolder, configFile, gomock.Any()).Return("someid", nil).Times(1) - return 0 - }, + name: "Recreate", + devcontainerID: devcontainerID.String(), + devcontainers: []codersdk.WorkspaceAgentDevcontainer{devcontainer}, + containers: []codersdk.WorkspaceAgentContainer{devContainer}, + expectRecreate: true, }, { - name: "Devcontainer does not exist", - devcontainerID: uuid.NewString(), - setupDevcontainers: nil, - setupMock: func(mccli *acmock.MockContainerCLI, mdccli *acmock.MockDevcontainerCLI) int { - mccli.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{}, nil).AnyTimes() - return http.StatusNotFound - }, + name: "Devcontainer does not exist", + devcontainerID: uuid.NewString(), + devcontainers: nil, + containers: []codersdk.WorkspaceAgentContainer{}, + expectErrorCode: http.StatusNotFound, }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - mccli := acmock.NewMockContainerCLI(ctrl) - mdccli := acmock.NewMockDevcontainerCLI(ctrl) - wantStatus := tc.setupMock(mccli, mdccli) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ - Logger: &logger, - }) - user := coderdtest.CreateFirstUser(t, client) - r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { - return agents - }).Do() + var ( + ctx = testutil.Context(t, testutil.WaitLong) + mCtrl = gomock.NewController(t) + mCCLI = acmock.NewMockContainerCLI(mCtrl) + mDCCLI = acmock.NewMockDevcontainerCLI(mCtrl) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{ + Logger: &logger, + }) + user = coderdtest.CreateFirstUser(t, client) + r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { + return agents + }).Do() + ) + + mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ + Containers: tc.containers, + }, nil).AnyTimes() + + var upCalled chan struct{} + + if tc.expectRecreate { + upCalled = make(chan struct{}) + + // DetectArchitecture always returns "" for this test to disable agent injection. + mCCLI.EXPECT().DetectArchitecture(gomock.Any(), devContainer.ID).Return("", nil).AnyTimes() + mDCCLI.EXPECT().ReadConfig(gomock.Any(), workspaceFolder, configFile, gomock.Any()).Return(agentcontainers.DevcontainerConfig{}, nil).AnyTimes() + mDCCLI.EXPECT().Up(gomock.Any(), workspaceFolder, configFile, gomock.Any()). + DoAndReturn(func(_ context.Context, _, _ string, _ ...agentcontainers.DevcontainerCLIUpOptions) (string, error) { + close(upCalled) + + return "someid", nil + }).Times(1) + } devcontainerAPIOptions := []agentcontainers.Option{ - agentcontainers.WithContainerCLI(mccli), - agentcontainers.WithDevcontainerCLI(mdccli), + agentcontainers.WithContainerCLI(mCCLI), + agentcontainers.WithDevcontainerCLI(mDCCLI), agentcontainers.WithWatcher(watcher.NewNoop()), } - if tc.setupDevcontainers != nil { + if tc.devcontainers != nil { devcontainerAPIOptions = append(devcontainerAPIOptions, - agentcontainers.WithDevcontainers(tc.setupDevcontainers, nil)) + agentcontainers.WithDevcontainers(tc.devcontainers, nil)) } _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { @@ -1679,15 +1693,14 @@ func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { require.Len(t, resources[0].Agents, 1, "expected one agent") agentID := resources[0].Agents[0].ID - ctx := testutil.Context(t, testutil.WaitLong) - _, err := client.WorkspaceAgentRecreateDevcontainer(ctx, agentID, tc.devcontainerID) - if wantStatus > 0 { + if tc.expectErrorCode > 0 { cerr, ok := codersdk.AsError(err) require.True(t, ok, "expected error to be a coder error") - assert.Equal(t, wantStatus, cerr.StatusCode()) + assert.Equal(t, tc.expectErrorCode, cerr.StatusCode()) } else { require.NoError(t, err, "failed to recreate devcontainer") + testutil.TryReceive(ctx, t, upCalled) } }) }