diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 17aae8fe47f0f..0a30bf97cce22 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1582,6 +1582,13 @@ jobs: "type": "mrkdwn", "text": "*View failure:* <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|Click here>" } + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame." + } } ] }' ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }} diff --git a/.github/workflows/nightly-gauntlet.yaml b/.github/workflows/nightly-gauntlet.yaml index 7b20ee92554b2..7bbf690f5e2db 100644 --- a/.github/workflows/nightly-gauntlet.yaml +++ b/.github/workflows/nightly-gauntlet.yaml @@ -203,6 +203,13 @@ jobs: "type": "mrkdwn", "text": "*View failure:* <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|Click here>" } + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame." + } } ] }' ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }} diff --git a/Makefile b/Makefile index 9040a891700e1..a5341ee79f753 100644 --- a/Makefile +++ b/Makefile @@ -636,7 +636,8 @@ GEN_FILES := \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ - coderd/httpmw/loggermw/loggermock/loggermock.go + coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go # all gen targets should be added here and to gen/mark-fresh gen: gen/db gen/golden-files $(GEN_FILES) @@ -686,6 +687,7 @@ gen/mark-fresh: agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go \ " for file in $$files; do @@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g go generate ./coderd/httpmw/loggermw/loggermock/ touch "$@" +codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go + go generate ./codersdk/workspacesdk/agentconnmock/ + touch "$@" + agent/agentcontainers/dcspec/dcspec_gen.go: \ node_modules/.installed \ agent/agentcontainers/dcspec/devContainer.base.schema.json \ diff --git a/agent/agent_test.go b/agent/agent_test.go index 52d8cfc09d573..2425fd81a0ead 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) { switch l.Addr().Network() { case "tcp": - conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp) case "udp": - conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp) default: t.Fatalf("unknown network: %s", l.Addr().Network()) } @@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()}, DERPMap: derpMap, @@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) { // Connect from a second client and make sure it uses the new DERP map. conn2 := newClientConn(newDerpMap, "client2") - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. - conn1.SetDERPMap(newDerpMap) - require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + conn1.TailnetConn().SetDERPMap(newDerpMap) + require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs()) t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort( } func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( - *workspacesdk.AgentConn, + workspacesdk.AgentConn, *agenttest.Client, <-chan *proto.Stats, afero.Fs, diff --git a/agent/agentcontainers/api_test.go b/agent/agentcontainers/api_test.go index 8c8e3b5411ed0..263f1698a7117 100644 --- a/agent/agentcontainers/api_test.go +++ b/agent/agentcontainers/api_test.go @@ -1675,6 +1675,8 @@ func TestAPI(t *testing.T) { coderBin, err := os.Executable() require.NoError(t, err) + coderBin, err = filepath.EvalSymlinks(coderBin) + require.NoError(t, err) mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ Containers: []codersdk.WorkspaceAgentContainer{testContainer}, @@ -2455,6 +2457,8 @@ func TestAPI(t *testing.T) { coderBin, err := os.Executable() require.NoError(t, err) + coderBin, err = filepath.EvalSymlinks(coderBin) + require.NoError(t, err) // Mock the `List` function to always return out test container. mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ @@ -2549,6 +2553,8 @@ func TestAPI(t *testing.T) { coderBin, err := os.Executable() require.NoError(t, err) + coderBin, err = filepath.EvalSymlinks(coderBin) + require.NoError(t, err) // Mock the `List` function to always return out test container. mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ @@ -2654,6 +2660,8 @@ func TestAPI(t *testing.T) { coderBin, err := os.Executable() require.NoError(t, err) + coderBin, err = filepath.EvalSymlinks(coderBin) + require.NoError(t, err) // Mock the `List` function to always return our test container. mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{ diff --git a/cli/ping.go b/cli/ping.go index 0b9fde5c62eb8..29af06affeaee 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command { } defer conn.Close() - derpMap := conn.DERPMap() + derpMap := conn.TailnetConn().DERPMap() diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second) defer diagCancel() @@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command { // Silent ping to determine whether we should show diags _, didP2p, _, _ := conn.Ping(ctx) - ni := conn.GetNetInfo() + ni := conn.TailnetConn().GetNetInfo() connDiags := cliui.ConnDiags{ DisableDirect: r.disableDirect, LocalNetInfo: ni, diff --git a/cli/portforward.go b/cli/portforward.go index 59c1f5827b06f..1b055d9e4362e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command { func listenAndPortForward( ctx context.Context, inv *serpent.Invocation, - conn *workspacesdk.AgentConn, + conn workspacesdk.AgentConn, wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, diff --git a/cli/speedtest.go b/cli/speedtest.go index 3827b45125842..86d0e6a9ee63c 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command { if err != nil { continue } - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { continue } @@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command { outputResult.Intervals[i] = interval } } - conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) + conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) out, err := formatter.Format(inv.Context(), outputResult) if err != nil { return err diff --git a/cli/ssh.go b/cli/ssh.go index bc2bb24235ad2..a2f0db7327bef 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command { } err = sshSession.Wait() - conn.SendDisconnectedTelemetry() + conn.TailnetConn().SendDisconnectedTelemetry() if err != nil { if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { // Clear the error since it's not useful beyond @@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName { func setStatsCallback( ctx context.Context, - agentConn *workspacesdk.AgentConn, + agentConn workspacesdk.AgentConn, logger slog.Logger, networkInfoDir string, networkInfoInterval time.Duration, @@ -1437,7 +1437,7 @@ func setStatsCallback( now := time.Now() cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) - agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) + agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb) return errCh, nil } @@ -1451,13 +1451,13 @@ type sshNetworkStats struct { UsingCoderConnect bool `json:"using_coder_connect"` } -func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { +func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { latency, p2p, pingResult, err := agentConn.Ping(ctx) if err != nil { return nil, err } - node := agentConn.Node() - derpMap := agentConn.DERPMap() + node := agentConn.TailnetConn().Node() + derpMap := agentConn.TailnetConn().DERPMap() totalRx := uint64(0) totalTx := uint64(0) diff --git a/coderd/coderd.go b/coderd/coderd.go index a934536c0aef0..8ab204f8a31ef 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -325,6 +325,9 @@ func New(options *Options) *API { }) } + if options.PrometheusRegistry == nil { + options.PrometheusRegistry = prometheus.NewRegistry() + } if options.Authorizer == nil { options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) if buildinfo.IsDev() { @@ -381,9 +384,6 @@ func New(options *Options) *API { if options.FilesRateLimit == 0 { options.FilesRateLimit = 12 } - if options.PrometheusRegistry == nil { - options.PrometheusRegistry = prometheus.NewRegistry() - } if options.Clock == nil { options.Clock = quartz.NewReal() } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 7ae1e4bbf9b73..a716c04adc030 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1837,6 +1837,14 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt) } +func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + _, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID) + if err != nil { + return uuid.Nil, err + } + return q.db.FindMatchingPresetID(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index e4639c3ae0adf..ce70a9b1f112a 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -4965,6 +4965,22 @@ func (s *MethodTestSuite) TestPrebuilds() { template, policy.ActionUse, ).Errors(sql.ErrNoRows) })) + s.Run("FindMatchingPresetID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + t1 := testutil.Fake(s.T(), faker, database.Template{}) + tv := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}}) + dbm.EXPECT().FindMatchingPresetID(gomock.Any(), database.FindMatchingPresetIDParams{ + TemplateVersionID: tv.ID, + ParameterNames: []string{"test"}, + ParameterValues: []string{"test"}, + }).Return(uuid.Nil, nil).AnyTimes() + dbm.EXPECT().GetTemplateVersionByID(gomock.Any(), tv.ID).Return(tv, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), t1.ID).Return(t1, nil).AnyTimes() + check.Args(database.FindMatchingPresetIDParams{ + TemplateVersionID: tv.ID, + ParameterNames: []string{"test"}, + ParameterValues: []string{"test"}, + }).Asserts(tv.RBACObject(t1), policy.ActionRead).Returns(uuid.Nil) + })) s.Run("GetPrebuildMetrics", s.Subtest(func(_ database.Store, check *expects) { check.Args(). Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 12133997bf2c9..11d21eab3b593 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -565,6 +565,13 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context. return r0, r1 } +func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.FindMatchingPresetID(ctx, arg) + m.queryLatencies.WithLabelValues("FindMatchingPresetID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 96e277cd7af58..67244cf2b01e9 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1051,6 +1051,21 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt) } +// FindMatchingPresetID mocks base method. +func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindMatchingPresetID", ctx, arg) + ret0, _ := ret[0].(uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindMatchingPresetID indicates an expected call of FindMatchingPresetID. +func (mr *MockStoreMockRecorder) FindMatchingPresetID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMatchingPresetID", reflect.TypeOf((*MockStore)(nil).FindMatchingPresetID), ctx, arg) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 8ac974ff20ee8..c490a04d2b653 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -137,6 +137,11 @@ type sqlcQuerier interface { FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error) FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error) + // FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. + // It returns the preset ID if a match is found, or NULL if no match is found. + // The query finds presets where all preset parameters are present in the provided parameters, + // and returns the preset with the most parameters (largest subset). + FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1b63e7c1e960f..d16bd34f25f82 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -32,7 +32,7 @@ WITH latest AS ( -- be as if the workspace auto started at the given time and the -- original TTL was applied. -- - -- Sadly we can't define ` + "`" + `activity_bump_interval` + "`" + ` above since + -- Sadly we can't define 'activity_bump_interval' above since -- it won't be available for this CASE statement, so we have to -- copy the cast twice. WHEN NOW() + (templates.activity_bump / 1000 / 1000 / 1000 || ' seconds')::interval > $1 :: timestamptz @@ -62,7 +62,11 @@ WITH latest AS ( ON workspaces.id = workspace_builds.workspace_id JOIN templates ON templates.id = workspaces.template_id - WHERE workspace_builds.workspace_id = $2::uuid + WHERE + workspace_builds.workspace_id = $2::uuid + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- are managed by the reconciliation loop and not subject to activity bumping + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID ORDER BY workspace_builds.build_number DESC LIMIT 1 ) @@ -7252,6 +7256,47 @@ func (q *sqlQuerier) CountInProgressPrebuilds(ctx context.Context) ([]CountInPro return items, nil } +const findMatchingPresetID = `-- name: FindMatchingPresetID :one +WITH provided_params AS ( + SELECT + unnest($1::text[]) AS name, + unnest($2::text[]) AS value +), +preset_matches AS ( + SELECT + tvp.id AS template_version_preset_id, + COALESCE(COUNT(tvpp.name), 0) AS total_preset_params, + COALESCE(COUNT(pp.name), 0) AS matching_params + FROM template_version_presets tvp + LEFT JOIN template_version_preset_parameters tvpp ON tvpp.template_version_preset_id = tvp.id + LEFT JOIN provided_params pp ON pp.name = tvpp.name AND pp.value = tvpp.value + WHERE tvp.template_version_id = $3 + GROUP BY tvp.id +) +SELECT pm.template_version_preset_id +FROM preset_matches pm +WHERE pm.total_preset_params = pm.matching_params -- All preset parameters must match +ORDER BY pm.total_preset_params DESC -- Return the preset with the most parameters +LIMIT 1 +` + +type FindMatchingPresetIDParams struct { + ParameterNames []string `db:"parameter_names" json:"parameter_names"` + ParameterValues []string `db:"parameter_values" json:"parameter_values"` + TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"` +} + +// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. +// It returns the preset ID if a match is found, or NULL if no match is found. +// The query finds presets where all preset parameters are present in the provided parameters, +// and returns the preset with the most parameters (largest subset). +func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, findMatchingPresetID, pq.Array(arg.ParameterNames), pq.Array(arg.ParameterValues), arg.TemplateVersionID) + var template_version_preset_id uuid.UUID + err := row.Scan(&template_version_preset_id) + return template_version_preset_id, err +} + const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many SELECT t.name as template_name, diff --git a/coderd/database/queries/activitybump.sql b/coderd/database/queries/activitybump.sql index 09349d29e5d06..e367a93abf778 100644 --- a/coderd/database/queries/activitybump.sql +++ b/coderd/database/queries/activitybump.sql @@ -22,7 +22,7 @@ WITH latest AS ( -- be as if the workspace auto started at the given time and the -- original TTL was applied. -- - -- Sadly we can't define `activity_bump_interval` above since + -- Sadly we can't define 'activity_bump_interval' above since -- it won't be available for this CASE statement, so we have to -- copy the cast twice. WHEN NOW() + (templates.activity_bump / 1000 / 1000 / 1000 || ' seconds')::interval > @next_autostart :: timestamptz @@ -52,7 +52,11 @@ WITH latest AS ( ON workspaces.id = workspace_builds.workspace_id JOIN templates ON templates.id = workspaces.template_id - WHERE workspace_builds.workspace_id = @workspace_id::uuid + WHERE + workspace_builds.workspace_id = @workspace_id::uuid + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- are managed by the reconciliation loop and not subject to activity bumping + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID ORDER BY workspace_builds.build_number DESC LIMIT 1 ) diff --git a/coderd/database/queries/prebuilds.sql b/coderd/database/queries/prebuilds.sql index 87a713974c563..8654453554e8c 100644 --- a/coderd/database/queries/prebuilds.sql +++ b/coderd/database/queries/prebuilds.sql @@ -245,3 +245,30 @@ INNER JOIN organizations o ON o.id = w.organization_id WHERE NOT t.deleted AND wpb.build_number = 1 GROUP BY t.name, tvp.name, o.name ORDER BY t.name, tvp.name, o.name; + +-- name: FindMatchingPresetID :one +-- FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. +-- It returns the preset ID if a match is found, or NULL if no match is found. +-- The query finds presets where all preset parameters are present in the provided parameters, +-- and returns the preset with the most parameters (largest subset). +WITH provided_params AS ( + SELECT + unnest(@parameter_names::text[]) AS name, + unnest(@parameter_values::text[]) AS value +), +preset_matches AS ( + SELECT + tvp.id AS template_version_preset_id, + COALESCE(COUNT(tvpp.name), 0) AS total_preset_params, + COALESCE(COUNT(pp.name), 0) AS matching_params + FROM template_version_presets tvp + LEFT JOIN template_version_preset_parameters tvpp ON tvpp.template_version_preset_id = tvp.id + LEFT JOIN provided_params pp ON pp.name = tvpp.name AND pp.value = tvpp.value + WHERE tvp.template_version_id = @template_version_id + GROUP BY tvp.id +) +SELECT pm.template_version_preset_id +FROM preset_matches pm +WHERE pm.total_preset_params = pm.matching_params -- All preset parameters must match +ORDER BY pm.total_preset_params DESC -- Return the preset with the most parameters +LIMIT 1; diff --git a/coderd/prebuilds/parameters.go b/coderd/prebuilds/parameters.go new file mode 100644 index 0000000000000..63a1a7b78bfa7 --- /dev/null +++ b/coderd/prebuilds/parameters.go @@ -0,0 +1,42 @@ +package prebuilds + +import ( + "context" + "database/sql" + "errors" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +// FindMatchingPresetID finds a preset ID that matches the provided parameters. +// It returns the preset ID if a match is found, or uuid.Nil if no match is found. +// The function performs a bidirectional comparison to ensure all parameters match exactly. +func FindMatchingPresetID( + ctx context.Context, + store database.Store, + templateVersionID uuid.UUID, + parameterNames []string, + parameterValues []string, +) (uuid.UUID, error) { + if len(parameterNames) != len(parameterValues) { + return uuid.Nil, xerrors.New("parameter names and values must have the same length") + } + + result, err := store.FindMatchingPresetID(ctx, database.FindMatchingPresetIDParams{ + TemplateVersionID: templateVersionID, + ParameterNames: parameterNames, + ParameterValues: parameterValues, + }) + if err != nil { + // Handle the case where no matching preset is found (no rows returned) + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, nil + } + return uuid.Nil, xerrors.Errorf("find matching preset ID: %w", err) + } + + return result, nil +} diff --git a/coderd/prebuilds/parameters_test.go b/coderd/prebuilds/parameters_test.go new file mode 100644 index 0000000000000..e9366bb1da02b --- /dev/null +++ b/coderd/prebuilds/parameters_test.go @@ -0,0 +1,198 @@ +package prebuilds_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/testutil" +) + +func TestFindMatchingPresetID(t *testing.T) { + t.Parallel() + + presetIDs := []uuid.UUID{ + uuid.New(), + uuid.New(), + } + // Give each preset a meaningful name in alphabetical order + presetNames := map[uuid.UUID]string{ + presetIDs[0]: "development", + presetIDs[1]: "production", + } + tests := []struct { + name string + parameterNames []string + parameterValues []string + presetParameters []database.TemplateVersionPresetParameter + expectedPresetID uuid.UUID + expectError bool + errorContains string + }{ + { + name: "exact match", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], + expectError: false, + }, + { + name: "no match - different values", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-east-1", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: uuid.Nil, + expectError: false, + }, + { + name: "no match - fewer provided parameters", + parameterNames: []string{"region"}, + parameterValues: []string{"us-west-2"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: uuid.Nil, + expectError: false, + }, + { + name: "subset match - extra provided parameter", + parameterNames: []string{"region", "instance_type", "extra_param"}, + parameterValues: []string{"us-west-2", "t3.medium", "extra_value"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], // Should match because all preset parameters are present + expectError: false, + }, + { + name: "mismatched parameter names vs values", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2"}, + presetParameters: []database.TemplateVersionPresetParameter{}, + expectedPresetID: uuid.Nil, + expectError: true, + errorContains: "parameter names and values must have the same length", + }, + { + name: "multiple presets - match first", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-east-1"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], + expectError: false, + }, + { + name: "largest subset match", + parameterNames: []string{"region", "instance_type", "storage_size"}, + parameterValues: []string{"us-west-2", "t3.medium", "100gb"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + }, + expectedPresetID: presetIDs[0], // Should match the larger subset (2 params vs 1 param) + expectError: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + db, _ := dbtestutil.NewDB(t) + org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{}) + templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + JobID: uuid.New(), + }) + + // Group parameters by preset ID and create presets + presetMap := make(map[uuid.UUID][]database.TemplateVersionPresetParameter) + for _, param := range tt.presetParameters { + presetMap[param.TemplateVersionPresetID] = append(presetMap[param.TemplateVersionPresetID], param) + } + + // Create presets and insert their parameters + for presetID, params := range presetMap { + // Create the preset + _, err := db.InsertPreset(ctx, database.InsertPresetParams{ + ID: presetID, + TemplateVersionID: templateVersion.ID, + Name: presetNames[presetID], + CreatedAt: dbtestutil.NowInDefaultTimezone(), + }) + require.NoError(t, err) + + // Insert parameters for this preset + names := make([]string, len(params)) + values := make([]string, len(params)) + for i, param := range params { + names[i] = param.Name + values[i] = param.Value + } + + _, err = db.InsertPresetParameters(ctx, database.InsertPresetParametersParams{ + TemplateVersionPresetID: presetID, + Names: names, + Values: values, + }) + require.NoError(t, err) + } + + result, err := prebuilds.FindMatchingPresetID( + ctx, + db, + templateVersion.ID, + tt.parameterNames, + tt.parameterValues, + ) + + // Assert results + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedPresetID, result) + } + }) + } +} diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 172edea95a586..cdcf657fe732d 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) ( }, nil } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { var ( - conn *workspacesdk.AgentConn + conn workspacesdk.AgentConn ret func() ) diff --git a/coderd/workspaceagents_internal_test.go b/coderd/workspaceagents_internal_test.go new file mode 100644 index 0000000000000..c7520f05ab503 --- /dev/null +++ b/coderd/workspaceagents_internal_test.go @@ -0,0 +1,186 @@ +package coderd + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +type fakeAgentProvider struct { + agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) +} + +func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy { + panic("unimplemented") +} + +func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + if f.agentConn != nil { + return f.agentConn(ctx, agentID) + } + + panic("unimplemented") +} + +func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + panic("unimplemented") +} + +func (fakeAgentProvider) Close() error { + return nil +} + +func TestWatchAgentContainers(t *testing.T) { + t.Parallel() + + t.Run("WebSocketClosesProperly", func(t *testing.T) { + t.Parallel() + + // This test ensures that the agent containers `/watch` websocket can gracefully + // handle the underlying websocket unexpectedly closing. This test was created in + // response to this issue: https://github.com/coder/coder/issues/19372 + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd") + + mCtrl = gomock.NewController(t) + mDB = dbmock.NewMockStore(mCtrl) + mCoordinator = tailnettest.NewMockCoordinator(mCtrl) + mAgentConn = agentconnmock.NewMockAgentConn(mCtrl) + + fAgentProvider = fakeAgentProvider{ + agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + return mAgentConn, func() {}, nil + }, + } + + workspaceID = uuid.New() + agentID = uuid.New() + resourceID = uuid.New() + jobID = uuid.New() + buildID = uuid.New() + + containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse) + + r = chi.NewMux() + + api = API{ + ctx: ctx, + Options: &Options{ + AgentInactiveDisconnectTimeout: testutil.WaitShort, + Database: mDB, + Logger: logger, + DeploymentValues: &codersdk.DeploymentValues{}, + TailnetCoordinator: tailnettest.NewFakeCoordinator(), + }, + } + ) + + var tailnetCoordinator tailnet.Coordinator = mCoordinator + api.TailnetCoordinator.Store(&tailnetCoordinator) + api.agentProvider = fAgentProvider + + // Setup: Allow `ExtractWorkspaceAgentParams` to complete. + mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{ + ID: agentID, + ResourceID: resourceID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + }, nil) + mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{ + ID: resourceID, + JobID: jobID, + }, nil) + mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{ + ID: jobID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }, nil) + mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + ID: buildID, + }, nil) + + // And: Allow `db2dsk.WorkspaceAgent` to complete. + mCoordinator.EXPECT().Node(gomock.Any()).Return(nil) + + // And: Allow `WatchContainers` to be called, returing our `containersCh` channel. + mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()). + Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil) + + // And: We mount the HTTP Handler + r.With(httpmw.ExtractWorkspaceAgentParam(mDB)). + Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers) + + // Given: We create the HTTP server + srv := httptest.NewServer(r) + defer srv.Close() + + // And: Dial the WebSocket + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil) + require.NoError(t, err) + if resp.Body != nil { + defer resp.Body.Close() + } + + // And: Create a streaming decoder + decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger) + defer decoder.Close() + decodeCh := decoder.Chan() + + // And: We can successfully send through the channel. + testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{{ + ID: "test-container-id", + }}, + }) + + // And: Receive the data. + containerResp := testutil.RequireReceive(ctx, t, decodeCh) + require.Len(t, containerResp.Containers, 1) + require.Equal(t, "test-container-id", containerResp.Containers[0].ID) + + // When: We close the `containersCh` + close(containersCh) + + // Then: We expect `decodeCh` to be closed. + select { + case <-ctx.Done(): + t.Fail() + + case _, ok := <-decodeCh: + require.False(t, ok, "channel is expected to be closed") + } + }) +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 1855ed8a7e8fc..ac58df1b772ad 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -593,7 +593,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) { _ = agenttest.New(t, client.URL, r.AgentToken) resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) - conn, err := func() (*workspacesdk.AgentConn, error) { + conn, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -1574,82 +1574,6 @@ func TestWatchWorkspaceAgentDevcontainers(t *testing.T) { } } }) - - t.Run("PayloadTooLarge", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitSuperLong) - logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - mClock = quartz.NewMock(t) - updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") - mCtrl = gomock.NewController(t) - mCCLI = acmock.NewMockContainerCLI(mCtrl) - - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger}) - user = coderdtest.CreateFirstUser(t, client) - r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { - return agents - }).Do() - ) - - // WebSocket limit is 4MiB, so we want to ensure we create _more_ than 4MiB worth of payload. - // Creating 20,000 fake containers creates a payload of roughly 7MiB. - var fakeContainers []codersdk.WorkspaceAgentContainer - for range 20_000 { - fakeContainers = append(fakeContainers, codersdk.WorkspaceAgentContainer{ - CreatedAt: time.Now(), - ID: uuid.NewString(), - FriendlyName: uuid.NewString(), - Image: "busybox:latest", - Labels: map[string]string{ - agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project", - agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json", - }, - Running: false, - Ports: []codersdk.WorkspaceAgentContainerPort{}, - Status: string(codersdk.WorkspaceAgentDevcontainerStatusRunning), - Volumes: map[string]string{}, - }) - } - - mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: fakeContainers}, nil) - mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - - _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.Logger = logger.Named("agent") - o.Devcontainers = true - o.DevcontainerAPIOptions = []agentcontainers.Option{ - agentcontainers.WithClock(mClock), - agentcontainers.WithContainerCLI(mCCLI), - agentcontainers.WithWatcher(watcher.NewNoop()), - } - }) - - resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() - require.Len(t, resources, 1, "expected one resource") - require.Len(t, resources[0].Agents, 1, "expected one agent") - agentID := resources[0].Agents[0].ID - - updaterTickerTrap.MustWait(ctx).MustRelease(ctx) - defer updaterTickerTrap.Close() - - containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID) - require.NoError(t, err) - defer func() { - closer.Close() - }() - - select { - case <-ctx.Done(): - t.Fail() - case _, ok := <-containers: - require.False(t, ok) - } - }) } func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { @@ -2497,7 +2421,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { agentID := resources[0].Agents[0].ID // Connect from a client. - conn1, err := func() (*workspacesdk.AgentConn, error) { + conn1, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -2538,7 +2462,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { // Wait for the DERP map to be updated on the existing client. require.Eventually(t, func() bool { - regionIDs := conn1.Conn.DERPMap().RegionIDs() + regionIDs := conn1.TailnetConn().DERPMap().RegionIDs() return len(regionIDs) == 1 && regionIDs[0] == 2 }, testutil.WaitLong, testutil.IntervalFast) @@ -2555,7 +2479,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { defer conn2.Close() ok = conn2.AwaitReachable(ctx) require.True(t, ok) - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) } func TestWorkspaceAgentExternalAuthListen(t *testing.T) { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 2f1294558f67a..002bb1ea05aae 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -74,7 +74,7 @@ type AgentProvider interface { ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy // AgentConn returns a new connection to the specified agent. - AgentConn(ctx context.Context, agentID uuid.UUID) (_ *workspacesdk.AgentConn, release func(), _ error) + AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 29c9cac0ffa13..633acae328673 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -1638,6 +1638,8 @@ func TestPostWorkspaceBuild(t *testing.T) { t.Run("SetsPresetID", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) user := coderdtest.CreateFirstUser(t, client) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ @@ -1645,9 +1647,20 @@ func TestPostWorkspaceBuild(t *testing.T) { ProvisionPlan: []*proto.Response{{ Type: &proto.Response_Plan{ Plan: &proto.PlanComplete{ - Presets: []*proto.Preset{{ - Name: "test", - }}, + Presets: []*proto.Preset{ + { + Name: "autodetected", + }, + { + Name: "manual", + Parameters: []*proto.PresetParameter{ + { + Name: "param1", + Value: "value1", + }, + }, + }, + }, }, }, }}, @@ -1655,28 +1668,29 @@ func TestPostWorkspaceBuild(t *testing.T) { }) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - require.Nil(t, workspace.LatestBuild.TemplateVersionPresetID) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() presets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) - require.Equal(t, 1, len(presets)) - require.Equal(t, "test", presets[0].Name) + require.Equal(t, 2, len(presets)) + require.Equal(t, "autodetected", presets[0].Name) + require.Equal(t, "manual", presets[1].Name) + + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + // Preset ID was detected based on the workspace parameters: + require.Equal(t, presets[0].ID, *workspace.LatestBuild.TemplateVersionPresetID) build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ TemplateVersionID: version.ID, Transition: codersdk.WorkspaceTransitionStart, - TemplateVersionPresetID: presets[0].ID, + TemplateVersionPresetID: presets[1].ID, }) require.NoError(t, err) require.NotNil(t, build.TemplateVersionPresetID) workspace, err = client.Workspace(ctx, workspace.ID) require.NoError(t, err) + require.Equal(t, presets[1].ID, *workspace.LatestBuild.TemplateVersionPresetID) require.Equal(t, build.TemplateVersionPresetID, workspace.LatestBuild.TemplateVersionPresetID) }) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 23bd8c5f6ed9e..b2b2610ff1349 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -638,14 +638,35 @@ func createWorkspace( // Use injected Clock to allow time mocking in tests now := api.Clock.Now() - // If a template preset was chosen, try claim a prebuilt workspace. - if req.TemplateVersionPresetID != uuid.Nil { + templateVersionPresetID := req.TemplateVersionPresetID + + // If no preset was chosen, look for a matching preset by parameter values. + if templateVersionPresetID == uuid.Nil { + parameterNames := make([]string, len(req.RichParameterValues)) + parameterValues := make([]string, len(req.RichParameterValues)) + for i, parameter := range req.RichParameterValues { + parameterNames[i] = parameter.Name + parameterValues[i] = parameter.Value + } + var err error + templateVersionID := req.TemplateVersionID + if templateVersionID == uuid.Nil { + templateVersionID = template.ActiveVersionID + } + templateVersionPresetID, err = prebuilds.FindMatchingPresetID(ctx, db, templateVersionID, parameterNames, parameterValues) + if err != nil { + return xerrors.Errorf("find matching preset: %w", err) + } + } + + // Try to claim a prebuilt workspace. + if templateVersionPresetID != uuid.Nil { // Try and claim an eligible prebuild, if available. // On successful claim, initialize all lifecycle fields from template and workspace-level config // so the newly claimed workspace is properly managed by the lifecycle executor. claimedWorkspace, err = claimPrebuild( - ctx, prebuildsClaimer, db, api.Logger, now, req, owner, - dbAutostartSchedule, nextStartAt, dbTTL) + ctx, prebuildsClaimer, db, api.Logger, now, req.Name, owner, + templateVersionPresetID, dbAutostartSchedule, nextStartAt, dbTTL) // If claiming fails with an expected error (no claimable prebuilds or AGPL does not support prebuilds), // we fall back to creating a new workspace. Otherwise, propagate the unexpected error. if err != nil { @@ -654,7 +675,7 @@ func createWorkspace( fields := []any{ slog.Error(err), slog.F("workspace_name", req.Name), - slog.F("template_version_preset_id", req.TemplateVersionPresetID), + slog.F("template_version_preset_id", templateVersionPresetID), } if !isExpectedError { @@ -718,8 +739,8 @@ func createWorkspace( if req.TemplateVersionID != uuid.Nil { builder = builder.VersionID(req.TemplateVersionID) } - if req.TemplateVersionPresetID != uuid.Nil { - builder = builder.TemplateVersionPresetID(req.TemplateVersionPresetID) + if templateVersionPresetID != uuid.Nil { + builder = builder.TemplateVersionPresetID(templateVersionPresetID) } if claimedWorkspace != nil { builder = builder.MarkPrebuiltWorkspaceClaim() @@ -884,13 +905,14 @@ func claimPrebuild( db database.Store, logger slog.Logger, now time.Time, - req codersdk.CreateWorkspaceRequest, + name string, owner workspaceOwner, + templateVersionPresetID uuid.UUID, autostartSchedule sql.NullString, nextStartAt sql.NullTime, ttl sql.NullInt64, ) (*database.Workspace, error) { - claimedID, err := claimer.Claim(ctx, now, owner.ID, req.Name, req.TemplateVersionPresetID, autostartSchedule, nextStartAt, ttl) + claimedID, err := claimer.Claim(ctx, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl) if err != nil { // TODO: enhance this by clarifying whether this *specific* prebuild failed or whether there are none to claim. return nil, xerrors.Errorf("claim prebuild: %w", err) diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index 443098036af62..8fc11ef6c8ccb 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -4915,3 +4915,285 @@ func TestUpdateWorkspaceACL(t *testing.T) { require.Equal(t, cerr.Validations[0].Field, "user_roles") }) } + +func TestWorkspaceCreateWithImplicitPreset(t *testing.T) { + t.Parallel() + + // Helper function to create template with presets + createTemplateWithPresets := func(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, presets []*proto.Preset) (codersdk.Template, codersdk.TemplateVersion) { + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Presets: presets, + }, + }, + }, + }, + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + return template, version + } + + // Helper function to create workspace and verify preset usage + createWorkspaceAndVerifyPreset := func(t *testing.T, client *codersdk.Client, template codersdk.Template, expectedPresetID *uuid.UUID, params []codersdk.WorkspaceBuildParameter) codersdk.Workspace { + wsName := testutil.GetRandomNameHyphenated(t) + var ws codersdk.Workspace + if len(params) > 0 { + ws = coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = wsName + cwr.RichParameterValues = params + }) + } else { + ws = coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = wsName + }) + } + require.Equal(t, wsName, ws.Name) + + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) + + // Verify the preset was used if expected + if expectedPresetID != nil { + require.NotNil(t, ws.LatestBuild.TemplateVersionPresetID) + require.Equal(t, *expectedPresetID, *ws.LatestBuild.TemplateVersionPresetID) + } else { + require.Nil(t, ws.LatestBuild.TemplateVersionPresetID) + } + + return ws + } + + t.Run("NoPresets", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with no presets + template, _ := createTemplateWithPresets(t, client, user, []*proto.Preset{}) + + // Test workspace creation with no parameters + createWorkspaceAndVerifyPreset(t, client, template, nil, nil) + + // Test workspace creation with parameters (should still work, no preset matching) + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + }) + + t.Run("SinglePresetNoParameters", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with single preset that has no parameters + preset := &proto.Preset{ + Name: "empty-preset", + Description: "A preset with no parameters", + Parameters: []*proto.PresetParameter{}, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset}) + + // Get the preset ID from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + presetID := presets[0].ID + + // Test workspace creation with no parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, nil) + + // Test workspace creation with parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + }) + + t.Run("SinglePresetWithParameters", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with single preset that has parameters + preset := &proto.Preset{ + Name: "param-preset", + Description: "A preset with parameters", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset}) + + // Get the preset ID from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + presetID := presets[0].ID + + // Test workspace creation with no parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, nil) + + // Test workspace creation with exact matching parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }) + + // Test workspace creation with partial matching parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + + // Test workspace creation with different parameter values - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "different"}, + }) + + // Test workspace creation with extra parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + {Name: "param3", Value: "value3"}, + }) + }) + + t.Run("MultiplePresets", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with multiple presets + preset1 := &proto.Preset{ + Name: "empty-preset", + Description: "A preset with no parameters", + Parameters: []*proto.PresetParameter{}, + } + preset2 := &proto.Preset{ + Name: "single-param-preset", + Description: "A preset with one parameter", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + }, + } + preset3 := &proto.Preset{ + Name: "multi-param-preset", + Description: "A preset with multiple parameters", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset1, preset2, preset3}) + + // Get the preset IDs from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 3) + + // Sort presets by name to get consistent ordering + var emptyPresetID, singleParamPresetID, multiParamPresetID uuid.UUID + for _, p := range presets { + switch p.Name { + case "empty-preset": + emptyPresetID = p.ID + case "single-param-preset": + singleParamPresetID = p.ID + case "multi-param-preset": + multiParamPresetID = p.ID + } + } + + // Test workspace creation with no parameters - should match empty preset + createWorkspaceAndVerifyPreset(t, client, template, &emptyPresetID, nil) + + // Test workspace creation with single parameter - should match single param preset + createWorkspaceAndVerifyPreset(t, client, template, &singleParamPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + + // Test workspace creation with multiple parameters - should match multi param preset + createWorkspaceAndVerifyPreset(t, client, template, &multiParamPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }) + + // Test workspace creation with non-matching parameters - should not match any preset + createWorkspaceAndVerifyPreset(t, client, template, &emptyPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "different"}, + }) + }) + + t.Run("PresetSpecifiedExplicitly", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with multiple presets + preset1 := &proto.Preset{ + Name: "preset1", + Description: "First preset", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + }, + } + preset2 := &proto.Preset{ + Name: "preset2", + Description: "Second preset", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset1, preset2}) + + // Get the preset IDs from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 2) + + var preset1ID, preset2ID uuid.UUID + for _, p := range presets { + switch p.Name { + case "preset1": + preset1ID = p.ID + case "preset2": + preset2ID = p.ID + } + } + + // Test workspace creation with preset1 specified explicitly - should use preset1 regardless of parameters + ws := coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.TemplateVersionPresetID = preset1ID + cwr.RichParameterValues = []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value2"}, // This would normally match preset2 + } + }) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) + require.NotNil(t, ws.LatestBuild.TemplateVersionPresetID) + require.Equal(t, preset1ID, *ws.LatestBuild.TemplateVersionPresetID) + + // Test workspace creation with preset2 specified explicitly - should use preset2 regardless of parameters + ws2 := coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.TemplateVersionPresetID = preset2ID + cwr.RichParameterValues = []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, // This would normally match preset1 + } + }) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws2.LatestBuild.ID) + require.NotNil(t, ws2.LatestBuild.TemplateVersionPresetID) + require.Equal(t, preset2ID, *ws2.LatestBuild.TemplateVersionPresetID) + }) +} diff --git a/coderd/workspacestats/reporter.go b/coderd/workspacestats/reporter.go index 58d177f1c2071..f6b8a8dd0953b 100644 --- a/coderd/workspacestats/reporter.go +++ b/coderd/workspacestats/reporter.go @@ -149,33 +149,36 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac return nil } - // check next autostart - var nextAutostart time.Time - if workspace.AutostartSchedule.String != "" { - templateSchedule, err := (*(r.opts.TemplateScheduleStore.Load())).Get(ctx, r.opts.Database, workspace.TemplateID) - // If the template schedule fails to load, just default to bumping - // without the next transition and log it. - switch { - case err == nil: - next, allowed := schedule.NextAutostart(now, workspace.AutostartSchedule.String, templateSchedule) - if allowed { - nextAutostart = next + // Prebuilds are not subject to activity-based deadline bumps + if !workspace.IsPrebuild() { + // check next autostart + var nextAutostart time.Time + if workspace.AutostartSchedule.String != "" { + templateSchedule, err := (*(r.opts.TemplateScheduleStore.Load())).Get(ctx, r.opts.Database, workspace.TemplateID) + // If the template schedule fails to load, just default to bumping + // without the next transition and log it. + switch { + case err == nil: + next, allowed := schedule.NextAutostart(now, workspace.AutostartSchedule.String, templateSchedule) + if allowed { + nextAutostart = next + } + case database.IsQueryCanceledError(err): + r.opts.Logger.Debug(ctx, "query canceled while loading template schedule", + slog.F("workspace_id", workspace.ID), + slog.F("template_id", workspace.TemplateID)) + default: + r.opts.Logger.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", + slog.F("workspace_id", workspace.ID), + slog.F("template_id", workspace.TemplateID), + slog.Error(err), + ) } - case database.IsQueryCanceledError(err): - r.opts.Logger.Debug(ctx, "query canceled while loading template schedule", - slog.F("workspace_id", workspace.ID), - slog.F("template_id", workspace.TemplateID)) - default: - r.opts.Logger.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", - slog.F("workspace_id", workspace.ID), - slog.F("template_id", workspace.TemplateID), - slog.Error(err), - ) } - } - // bump workspace activity - ActivityBumpWorkspace(ctx, r.opts.Logger.Named("activity_bump"), r.opts.Database, workspace.ID, nextAutostart) + // bump workspace activity + ActivityBumpWorkspace(ctx, r.opts.Logger.Named("activity_bump"), r.opts.Database, workspace.ID, nextAutostart) + } // bump workspace last_used_at r.opts.UsageTracker.Add(workspace.ID) diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 73e449ee5bb93..223b8bec084ad 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -15,6 +15,7 @@ import ( "github.com/coder/coder/v2/coderd/dynamicparameters" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/provisioner/terraform/tfparse" @@ -442,6 +443,20 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object var workspaceBuild database.WorkspaceBuild err = b.store.InTx(func(store database.Store) error { + names, values, err := b.getParameters() + if err != nil { + // getParameters already wraps errors in BuildError + return err + } + + if b.templateVersionPresetID == uuid.Nil { + presetID, err := prebuilds.FindMatchingPresetID(b.ctx, b.store, templateVersionID, names, values) + if err != nil { + return BuildError{http.StatusInternalServerError, "find matching preset", err} + } + b.templateVersionPresetID = presetID + } + err = store.InsertWorkspaceBuild(b.ctx, database.InsertWorkspaceBuildParams{ ID: workspaceBuildID, CreatedAt: now, @@ -473,12 +488,6 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object return BuildError{code, "insert workspace build", err} } - names, values, err := b.getParameters() - if err != nil { - // getParameters already wraps errors in BuildError - return err - } - err = store.InsertWorkspaceBuildParameters(b.ctx, database.InsertWorkspaceBuildParametersParams{ WorkspaceBuildID: workspaceBuildID, Name: names, diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index ee421a8adb649..b862e6459c285 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -82,6 +82,7 @@ func TestBuilder_NoOptions(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -132,6 +133,7 @@ func TestBuilder_Initiator(t *testing.T) { asrt.Equal(otherUserID, job.InitiatorID) }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(otherUserID, bld.InitiatorID) }), @@ -180,6 +182,7 @@ func TestBuilder_Baggage(t *testing.T) { asrt.Contains(string(job.TraceMetadata.RawMessage), "ip=127.0.0.1") }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { }), expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) { @@ -219,6 +222,7 @@ func TestBuilder_Reason(t *testing.T) { expectProvisionerJob(func(_ database.InsertProvisionerJobParams) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(database.BuildReasonAutostart, bld.Reason) }), @@ -261,6 +265,7 @@ func TestBuilder_ActiveVersion(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(activeVersionID, bld.TemplateVersionID) // no previous build... @@ -386,6 +391,7 @@ func TestWorkspaceBuildWithTags(t *testing.T) { expectBuildParameters(func(_ database.InsertWorkspaceBuildParametersParams) { }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -470,6 +476,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -519,6 +526,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -661,6 +669,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -713,6 +722,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), // Outputs + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), @@ -775,6 +785,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), // Outputs + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), @@ -906,6 +917,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -968,6 +980,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -1041,6 +1054,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { // Outputs expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), withBuild, expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}), @@ -1485,6 +1499,14 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner } } +func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockStore) { + return func(mTx *dbmock.MockStore) { + mTx.EXPECT().FindMatchingPresetID(gomock.Any(), gomock.Any()). + Times(1). + Return(id, err) + } +} + type fakeUsageChecker struct { checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) } diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 9c65b7ee9a1e1..bb929c9ba2a04 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -34,8 +34,8 @@ import ( // to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the // conn is shared and closing it is undesirable, you may return ErrNoClose from // opts.CloseFunc. This will ensure the underlying conn is not closed. -func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { - return &AgentConn{ +func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn { + return &agentConn{ Conn: conn, opts: opts, } @@ -43,23 +43,54 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { // AgentConn represents a connection to a workspace agent. // @typescript-ignore AgentConn -type AgentConn struct { +type AgentConn interface { + TailnetConn() *tailnet.Conn + + AwaitReachable(ctx context.Context) bool + Close() error + DebugLogs(ctx context.Context) ([]byte, error) + DebugMagicsock(ctx context.Context) ([]byte, error) + DebugManifest(ctx context.Context) ([]byte, error) + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) + GetPeerDiagnostics() tailnet.PeerDiagnostics + ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) + ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) + Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) + Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) + PrometheusMetrics(ctx context.Context) ([]byte, error) + ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) + RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + SSH(ctx context.Context) (*gonet.TCPConn, error) + SSHClient(ctx context.Context) (*ssh.Client, error) + SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) + SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) + Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) + WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) +} + +// AgentConn represents a connection to a workspace agent. +// @typescript-ignore AgentConn +type agentConn struct { *tailnet.Conn opts AgentConnOptions } +func (c *agentConn) TailnetConn() *tailnet.Conn { + return c.Conn +} + // @typescript-ignore AgentConnOptions type AgentConnOptions struct { AgentID uuid.UUID CloseFunc func() error } -func (c *AgentConn) agentAddress() netip.Addr { +func (c *agentConn) agentAddress() netip.Addr { return tailnet.TailscaleServicePrefix.AddrFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable. -func (c *AgentConn) AwaitReachable(ctx context.Context) bool { +func (c *agentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -68,7 +99,7 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool { // Ping pings the agent and returns the round-trip time. // The bool returns true if the ping was made P2P. -func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { +func (c *agentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -76,7 +107,7 @@ func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.Pi } // Close ends the connection to the workspace agent. -func (c *AgentConn) Close() error { +func (c *agentConn) Close() error { var cerr error if c.opts.CloseFunc != nil { cerr = c.opts.CloseFunc() @@ -131,7 +162,7 @@ type ReconnectingPTYRequest struct { // ReconnectingPTY spawns a new reconnecting terminal session. // `ReconnectingPTYRequest` should be JSON marshaled and written to the returned net.Conn. // Raw terminal output will be read from the returned net.Conn. -func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { +func (c *agentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -171,13 +202,13 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w // SSH pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent. -func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { +func (c *agentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { return c.SSHOnPort(ctx, AgentSSHPort) } // SSHOnPort pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent on the specified port. -func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { +func (c *agentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -190,12 +221,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, } // SSHClient calls SSH to create a client -func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { +func (c *agentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { +func (c *agentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -218,7 +249,7 @@ func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Clie } // Speedtest runs a speedtest against the workspace agent. -func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { +func (c *agentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -242,7 +273,7 @@ func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction // DialContext dials the address provided in the workspace agent. // The network must be "tcp" or "udp". -func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *agentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -265,7 +296,7 @@ func (c *AgentConn) DialContext(ctx context.Context, network string, addr string } // ListeningPorts lists the ports that are currently in use by the workspace. -func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { +func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil) @@ -282,7 +313,7 @@ func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent } // Netcheck returns a network check report from the workspace agent. -func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { +func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/netcheck", nil) @@ -299,7 +330,7 @@ func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport } // DebugMagicsock makes a request to the workspace agent's magicsock debug endpoint. -func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/magicsock", nil) @@ -319,7 +350,7 @@ func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { // DebugManifest returns the agent's in-memory manifest. Unfortunately this must // be returns as a []byte to avoid an import cycle. -func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugManifest(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/manifest", nil) @@ -338,7 +369,7 @@ func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { } // DebugLogs returns up to the last 10MB of `/tmp/coder-agent.log` -func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugLogs(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/logs", nil) @@ -357,7 +388,7 @@ func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { } // PrometheusMetrics returns a response from the agent's prometheus metrics endpoint -func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { +func (c *agentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/prometheus", nil) @@ -376,7 +407,7 @@ func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { } // ListContainers returns a response from the agent's containers endpoint -func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { +func (c *agentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/containers", nil) @@ -391,7 +422,7 @@ func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } -func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { +func (c *agentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -427,7 +458,7 @@ func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<- // RecreateDevcontainer recreates a devcontainer with the given container. // This is a blocking call and will wait for the container to be recreated. -func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { +func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/containers/devcontainers/"+devcontainerID+"/recreate", nil) @@ -446,7 +477,7 @@ func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str } // apiRequest makes a request to the workspace agent's HTTP API server. -func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -463,7 +494,7 @@ func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io // apiClient returns an HTTP client that can be used to make // requests to the workspace agent's HTTP API server. -func (c *AgentConn) apiClient() *http.Client { +func (c *agentConn) apiClient() *http.Client { return &http.Client{ Transport: &http.Transport{ // Disable keep alives as we're usually only making a single @@ -504,6 +535,6 @@ func (c *AgentConn) apiClient() *http.Client { } } -func (c *AgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { +func (c *agentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { return c.Conn.GetPeerDiagnostics(c.opts.AgentID) } diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go new file mode 100644 index 0000000000000..eb55bb27938c0 --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -0,0 +1,373 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: .. (interfaces: AgentConn) +// +// Generated by this command: +// +// mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn +// + +// Package agentconnmock is a generated GoMock package. +package agentconnmock + +import ( + context "context" + io "io" + net "net" + reflect "reflect" + time "time" + + slog "cdr.dev/slog" + codersdk "github.com/coder/coder/v2/codersdk" + healthsdk "github.com/coder/coder/v2/codersdk/healthsdk" + workspacesdk "github.com/coder/coder/v2/codersdk/workspacesdk" + tailnet "github.com/coder/coder/v2/tailnet" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" + ssh "golang.org/x/crypto/ssh" + gonet "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + ipnstate "tailscale.com/ipn/ipnstate" + speedtest "tailscale.com/net/speedtest" +) + +// MockAgentConn is a mock of AgentConn interface. +type MockAgentConn struct { + ctrl *gomock.Controller + recorder *MockAgentConnMockRecorder + isgomock struct{} +} + +// MockAgentConnMockRecorder is the mock recorder for MockAgentConn. +type MockAgentConnMockRecorder struct { + mock *MockAgentConn +} + +// NewMockAgentConn creates a new mock instance. +func NewMockAgentConn(ctrl *gomock.Controller) *MockAgentConn { + mock := &MockAgentConn{ctrl: ctrl} + mock.recorder = &MockAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAgentConn) EXPECT() *MockAgentConnMockRecorder { + return m.recorder +} + +// AwaitReachable mocks base method. +func (m *MockAgentConn) AwaitReachable(ctx context.Context) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AwaitReachable", ctx) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AwaitReachable indicates an expected call of AwaitReachable. +func (mr *MockAgentConnMockRecorder) AwaitReachable(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitReachable", reflect.TypeOf((*MockAgentConn)(nil).AwaitReachable), ctx) +} + +// Close mocks base method. +func (m *MockAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close)) +} + +// DebugLogs mocks base method. +func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugLogs", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugLogs indicates an expected call of DebugLogs. +func (mr *MockAgentConnMockRecorder) DebugLogs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugLogs", reflect.TypeOf((*MockAgentConn)(nil).DebugLogs), ctx) +} + +// DebugMagicsock mocks base method. +func (m *MockAgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugMagicsock", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugMagicsock indicates an expected call of DebugMagicsock. +func (mr *MockAgentConnMockRecorder) DebugMagicsock(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugMagicsock", reflect.TypeOf((*MockAgentConn)(nil).DebugMagicsock), ctx) +} + +// DebugManifest mocks base method. +func (m *MockAgentConn) DebugManifest(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugManifest", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugManifest indicates an expected call of DebugManifest. +func (mr *MockAgentConnMockRecorder) DebugManifest(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugManifest", reflect.TypeOf((*MockAgentConn)(nil).DebugManifest), ctx) +} + +// DialContext mocks base method. +func (m *MockAgentConn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialContext", ctx, network, addr) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialContext indicates an expected call of DialContext. +func (mr *MockAgentConnMockRecorder) DialContext(ctx, network, addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockAgentConn)(nil).DialContext), ctx, network, addr) +} + +// GetPeerDiagnostics mocks base method. +func (m *MockAgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerDiagnostics") + ret0, _ := ret[0].(tailnet.PeerDiagnostics) + return ret0 +} + +// GetPeerDiagnostics indicates an expected call of GetPeerDiagnostics. +func (mr *MockAgentConnMockRecorder) GetPeerDiagnostics() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerDiagnostics", reflect.TypeOf((*MockAgentConn)(nil).GetPeerDiagnostics)) +} + +// ListContainers mocks base method. +func (m *MockAgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListContainers", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListContainers indicates an expected call of ListContainers. +func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx) +} + +// ListeningPorts mocks base method. +func (m *MockAgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListeningPorts", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListeningPortsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListeningPorts indicates an expected call of ListeningPorts. +func (mr *MockAgentConnMockRecorder) ListeningPorts(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListeningPorts", reflect.TypeOf((*MockAgentConn)(nil).ListeningPorts), ctx) +} + +// Netcheck mocks base method. +func (m *MockAgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Netcheck", ctx) + ret0, _ := ret[0].(healthsdk.AgentNetcheckReport) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Netcheck indicates an expected call of Netcheck. +func (mr *MockAgentConnMockRecorder) Netcheck(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Netcheck", reflect.TypeOf((*MockAgentConn)(nil).Netcheck), ctx) +} + +// Ping mocks base method. +func (m *MockAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(*ipnstate.PingResult) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Ping indicates an expected call of Ping. +func (mr *MockAgentConnMockRecorder) Ping(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockAgentConn)(nil).Ping), ctx) +} + +// PrometheusMetrics mocks base method. +func (m *MockAgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrometheusMetrics", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PrometheusMetrics indicates an expected call of PrometheusMetrics. +func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) +} + +// ReconnectingPTY mocks base method. +func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, id, height, width, command} + for _, a := range initOpts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReconnectingPTY", varargs...) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReconnectingPTY indicates an expected call of ReconnectingPTY. +func (mr *MockAgentConnMockRecorder) ReconnectingPTY(ctx, id, height, width, command any, initOpts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, id, height, width, command}, initOpts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconnectingPTY", reflect.TypeOf((*MockAgentConn)(nil).ReconnectingPTY), varargs...) +} + +// RecreateDevcontainer mocks base method. +func (m *MockAgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecreateDevcontainer", ctx, devcontainerID) + ret0, _ := ret[0].(codersdk.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RecreateDevcontainer indicates an expected call of RecreateDevcontainer. +func (mr *MockAgentConnMockRecorder) RecreateDevcontainer(ctx, devcontainerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecreateDevcontainer", reflect.TypeOf((*MockAgentConn)(nil).RecreateDevcontainer), ctx, devcontainerID) +} + +// SSH mocks base method. +func (m *MockAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSH", ctx) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSH indicates an expected call of SSH. +func (mr *MockAgentConnMockRecorder) SSH(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSH", reflect.TypeOf((*MockAgentConn)(nil).SSH), ctx) +} + +// SSHClient mocks base method. +func (m *MockAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClient", ctx) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClient indicates an expected call of SSHClient. +func (mr *MockAgentConnMockRecorder) SSHClient(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClient", reflect.TypeOf((*MockAgentConn)(nil).SSHClient), ctx) +} + +// SSHClientOnPort mocks base method. +func (m *MockAgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClientOnPort", ctx, port) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClientOnPort indicates an expected call of SSHClientOnPort. +func (mr *MockAgentConnMockRecorder) SSHClientOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClientOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHClientOnPort), ctx, port) +} + +// SSHOnPort mocks base method. +func (m *MockAgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHOnPort", ctx, port) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHOnPort indicates an expected call of SSHOnPort. +func (mr *MockAgentConnMockRecorder) SSHOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHOnPort), ctx, port) +} + +// Speedtest mocks base method. +func (m *MockAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Speedtest", ctx, direction, duration) + ret0, _ := ret[0].([]speedtest.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Speedtest indicates an expected call of Speedtest. +func (mr *MockAgentConnMockRecorder) Speedtest(ctx, direction, duration any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Speedtest", reflect.TypeOf((*MockAgentConn)(nil).Speedtest), ctx, direction, duration) +} + +// TailnetConn mocks base method. +func (m *MockAgentConn) TailnetConn() *tailnet.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TailnetConn") + ret0, _ := ret[0].(*tailnet.Conn) + return ret0 +} + +// TailnetConn indicates an expected call of TailnetConn. +func (mr *MockAgentConnMockRecorder) TailnetConn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TailnetConn", reflect.TypeOf((*MockAgentConn)(nil).TailnetConn)) +} + +// WatchContainers mocks base method. +func (m *MockAgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchContainers", ctx, logger) + ret0, _ := ret[0].(<-chan codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(io.Closer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// WatchContainers indicates an expected call of WatchContainers. +func (mr *MockAgentConnMockRecorder) WatchContainers(ctx, logger any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchContainers", reflect.TypeOf((*MockAgentConn)(nil).WatchContainers), ctx, logger) +} diff --git a/codersdk/workspacesdk/agentconnmock/doc.go b/codersdk/workspacesdk/agentconnmock/doc.go new file mode 100644 index 0000000000000..a795b21a4a89d --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/doc.go @@ -0,0 +1,4 @@ +// Package agentconnmock contains a mock implementation of workspacesdk.AgentConn for use in tests. +package agentconnmock + +//go:generate mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 9f587cf5267a8..ddaec06388238 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -202,7 +202,7 @@ func (c *Client) RewriteDERPMap(derpMap *tailcfg.DERPMap) { tailnet.RewriteDERPMapDefaultRelay(context.Background(), c.client.Logger(), derpMap, c.client.URL) } -func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn *AgentConn, err error) { +func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn AgentConn, err error) { if options == nil { options = &DialAgentOptions{} } diff --git a/docs/admin/templates/extending-templates/prebuilt-workspaces.md b/docs/admin/templates/extending-templates/prebuilt-workspaces.md index 8e61687ce0f01..70c2031d2a837 100644 --- a/docs/admin/templates/extending-templates/prebuilt-workspaces.md +++ b/docs/admin/templates/extending-templates/prebuilt-workspaces.md @@ -29,6 +29,7 @@ Prebuilt workspaces are tightly integrated with [workspace presets](./parameters 1. The preset must define all required parameters needed to build the workspace. 1. The preset parameters define the base configuration and are immutable once a prebuilt workspace is provisioned. 1. Parameters that are not defined in the preset can still be customized by users when they claim a workspace. +1. If a user does not select a preset but provides parameters that match one or more presets, Coder will automatically select the most specific matching preset and assign a prebuilt workspace if one is available. ## Prerequisites @@ -291,16 +292,6 @@ does not reconnect after a template update. This shortcoming is described in [th and will be addressed before the next release (v2.23). In the interim, a simple workaround is to restart the workspace when it is in this problematic state. -### Current limitations - -The prebuilt workspaces feature has these current limitations: - -- **Organizations** - - Prebuilt workspaces can only be used with the default organization. - - [View issue](https://github.com/coder/internal/issues/364) - ### Monitoring and observability #### Available metrics diff --git a/docs/manifest.json b/docs/manifest.json index 66f4e6dbaf476..bd08ccfe372e6 100644 --- a/docs/manifest.json +++ b/docs/manifest.json @@ -47,6 +47,18 @@ "path": "./about/contributing/documentation.md", "icon_path": "./images/icons/document.svg" }, + { + "title": "Modules", + "description": "Learn how to contribute modules to Coder", + "path": "./about/contributing/modules.md", + "icon_path": "./images/icons/gear.svg" + }, + { + "title": "Templates", + "description": "Learn how to contribute templates to Coder", + "path": "./about/contributing/templates.md", + "icon_path": "./images/icons/picture.svg" + }, { "title": "Backend", "description": "Our guide for backend development", diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 7004653e4ed60..dc44a8794e1c6 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -42,6 +42,7 @@ import ( agplschedule "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" entaudit "github.com/coder/coder/v2/enterprise/audit" "github.com/coder/coder/v2/enterprise/audit/backends" @@ -2767,6 +2768,114 @@ func TestPrebuildUpdateLifecycleParams(t *testing.T) { } } +func TestPrebuildActivityBump(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + clock.Set(dbtime.Now()) + + // Setup + log := testutil.Logger(t) + client, db, owner := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Clock: clock, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspacePrebuilds: 1, + }, + }, + }) + + // Given: a template and a template version with preset and a prebuilt workspace + presetID := uuid.New() + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + // Configure activity bump on the template + activityBump := time.Hour + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + ctr.ActivityBumpMillis = ptr.Ref[int64](activityBump.Milliseconds()) + }) + dbgen.Preset(t, db, database.InsertPresetParams{ + ID: presetID, + TemplateVersionID: version.ID, + DesiredInstances: sql.NullInt32{Int32: 1, Valid: true}, + }) + // Given: a prebuild with an expired Deadline + deadline := clock.Now().Add(-30 * time.Minute) + wb := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: database.PrebuildsSystemUserID, + TemplateID: template.ID, + }).Seed(database.WorkspaceBuild{ + TemplateVersionID: version.ID, + TemplateVersionPresetID: uuid.NullUUID{ + UUID: presetID, + Valid: true, + }, + Deadline: deadline, + }).WithAgent(func(agent []*proto.Agent) []*proto.Agent { + return agent + }).Do() + + // Mark the prebuilt workspace's agent as ready so the prebuild can be claimed + // nolint:gocritic + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) + agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(wb.AgentToken)) + require.NoError(t, err) + err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.WorkspaceAgent.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + }) + require.NoError(t, err) + + // Given: a prebuilt workspace with a Deadline and an empty MaxDeadline + prebuild := coderdtest.MustWorkspace(t, client, wb.Workspace.ID) + require.Equal(t, deadline.UTC(), prebuild.LatestBuild.Deadline.Time.UTC()) + require.Zero(t, prebuild.LatestBuild.MaxDeadline) + + // When: activity bump is applied to an unclaimed prebuild + workspacestats.ActivityBumpWorkspace(ctx, log, db, prebuild.ID, clock.Now().Add(10*time.Hour)) + + // Then: prebuild Deadline/MaxDeadline remain unchanged + prebuild = coderdtest.MustWorkspace(t, client, wb.Workspace.ID) + require.Equal(t, deadline.UTC(), prebuild.LatestBuild.Deadline.Time.UTC()) + require.Zero(t, prebuild.LatestBuild.MaxDeadline) + + // Given: the prebuilt workspace is claimed by a user + user, err := client.User(ctx, "testUser") + require.NoError(t, err) + claimedWorkspace, err := client.CreateUserWorkspace(ctx, user.ID.String(), codersdk.CreateWorkspaceRequest{ + TemplateVersionID: version.ID, + TemplateVersionPresetID: presetID, + Name: coderdtest.RandomUsername(t), + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, claimedWorkspace.LatestBuild.ID) + workspace := coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + require.Equal(t, prebuild.ID, workspace.ID) + // Claimed workspaces have an empty Deadline and MaxDeadline + require.Zero(t, workspace.LatestBuild.Deadline) + require.Zero(t, workspace.LatestBuild.MaxDeadline) + + // Given: the claimed workspace has an expired Deadline + err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ + ID: workspace.LatestBuild.ID, + Deadline: deadline, + UpdatedAt: clock.Now(), + }) + require.NoError(t, err) + workspace = coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + + // When: activity bump is applied to a claimed prebuild + workspacestats.ActivityBumpWorkspace(ctx, log, db, workspace.ID, clock.Now().Add(10*time.Hour)) + + // Then: Deadline is extended by the activity bump, MaxDeadline remains unset + workspace = coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + require.WithinDuration(t, clock.Now().Add(activityBump).UTC(), workspace.LatestBuild.Deadline.Time.UTC(), testutil.WaitMedium) + require.Zero(t, workspace.LatestBuild.MaxDeadline) +} + // TestWorkspaceTemplateParamsChange tests a workspace with a parameter that // validation changes on apply. The params used in create workspace are invalid // according to the static params on import. diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index b0051551a0f3d..72f5a4291c40e 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -75,7 +75,7 @@ func (c *Client) RequestIgnoreRedirects(ctx context.Context, method, path string // DialWorkspaceAgent calls the underlying codersdk.Client's DialWorkspaceAgent // method. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn *workspacesdk.AgentConn, err error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn workspacesdk.AgentConn, err error) { return workspacesdk.New(c.SDKClient).DialAgent(ctx, agentID, options) } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index dba21cc24e3a0..b0990d9cb11a6 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -89,7 +89,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { // Ensure DERP for completeness. if r.cfg.ConnectionMode == ConnectionModeDerp { - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) } @@ -133,7 +133,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { return nil } -func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDisco(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const pingAttempts = 10 const pingDelay = 1 * time.Second @@ -165,7 +165,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentC return nil } -func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDirectConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const directConnectionAttempts = 30 const directConnectionDelay = 1 * time.Second @@ -174,7 +174,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac for i := 0; i < directConnectionAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) - status := conn.Status() + status := conn.TailnetConn().Status() var err error if len(status.Peers()) != 1 { @@ -207,7 +207,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac return nil } -func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func verifyConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const verifyConnectionAttempts = 30 const verifyConnectionDelay = 1 * time.Second @@ -249,7 +249,7 @@ func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Ag return nil } -func performInitialConnections(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, specs []Connection) error { +func performInitialConnections(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, specs []Connection) error { if len(specs) == 0 { return nil } @@ -287,7 +287,7 @@ func performInitialConnections(ctx context.Context, logs io.Writer, conn *worksp return nil } -func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { +func holdConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -364,7 +364,7 @@ func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Agen return nil } -func agentHTTPClient(conn *workspacesdk.AgentConn) *http.Client { +func agentHTTPClient(conn workspacesdk.AgentConn) *http.Client { return &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, diff --git a/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.stories.tsx b/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.stories.tsx index 2073ff5bf2a7f..24b65093d384b 100644 --- a/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.stories.tsx +++ b/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.stories.tsx @@ -26,6 +26,23 @@ type Story = StoryObj; export const Default: Story = {}; +export const ZeroUsage: Story = { + args: { + managedAgentFeature: { + enabled: true, + actual: 0, + soft_limit: 60000, + limit: 120000, + usage_period: { + start: "February 27, 2025", + end: "February 27, 2026", + issued_at: "February 27, 2025", + }, + entitlement: "entitled", + }, + }, +}; + export const NearLimit: Story = { args: { managedAgentFeature: { diff --git a/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.tsx b/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.tsx index 08da49c96b710..022627c11dc02 100644 --- a/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.tsx +++ b/site/src/pages/DeploymentSettingsPage/LicensesSettingsPage/ManagedAgentsConsumption.tsx @@ -44,11 +44,16 @@ export const ManagedAgentsConsumption: FC = ({ const startDate = managedAgentFeature.usage_period?.start; const endDate = managedAgentFeature.usage_period?.end; - if (!usage || usage < 0) { + if (usage === undefined || usage < 0) { return ; } - if (!included || included < 0 || !limit || limit < 0) { + if ( + included === undefined || + included < 0 || + limit === undefined || + limit < 0 + ) { return ; } diff --git a/support/support.go b/support/support.go index 2fa41ce7eca8c..31080faaf023b 100644 --- a/support/support.go +++ b/support/support.go @@ -390,7 +390,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L if err := conn.Close(); err != nil { log.Error(ctx, "failed to close agent connection", slog.Error(err)) } - <-conn.Closed() + <-conn.TailnetConn().Closed() } eg.Go(func() error { @@ -399,7 +399,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L return xerrors.Errorf("create request: %w", err) } rr := httptest.NewRecorder() - conn.MagicsockServeHTTPDebug(rr, req) + conn.TailnetConn().MagicsockServeHTTPDebug(rr, req) a.ClientMagicsockHTML = rr.Body.Bytes() return nil })