From 5e84d257b7571d48686e81a0102971df9fdeb7f6 Mon Sep 17 00:00:00 2001 From: Danielle Maywood Date: Wed, 20 Aug 2025 10:00:44 +0100 Subject: [PATCH 01/11] refactor: convert workspacesdk.AgentConn to an interface (#19392) Fixes https://github.com/coder/internal/issues/907 We convert `workspacesdk.AgentConn` to an interface and generate a mock for it. This allows writing `coderd` tests that rely on the agent's HTTP api to not have to set up an entire tailnet networking stack. --- Makefile | 8 +- agent/agent_test.go | 14 +- cli/ping.go | 4 +- cli/portforward.go | 2 +- cli/speedtest.go | 4 +- cli/ssh.go | 12 +- coderd/coderd.go | 6 +- coderd/tailnet.go | 4 +- coderd/workspaceagents_internal_test.go | 186 +++++++++ coderd/workspaceagents_test.go | 84 +--- coderd/workspaceapps/proxy.go | 2 +- codersdk/workspacesdk/agentconn.go | 83 ++-- .../agentconnmock/agentconnmock.go | 373 ++++++++++++++++++ codersdk/workspacesdk/agentconnmock/doc.go | 4 + codersdk/workspacesdk/workspacesdk.go | 2 +- enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 2 +- scaletest/agentconn/run.go | 16 +- support/support.go | 4 +- 18 files changed, 667 insertions(+), 143 deletions(-) create mode 100644 coderd/workspaceagents_internal_test.go create mode 100644 codersdk/workspacesdk/agentconnmock/agentconnmock.go create mode 100644 codersdk/workspacesdk/agentconnmock/doc.go diff --git a/Makefile b/Makefile index 9040a891700e1..a5341ee79f753 100644 --- a/Makefile +++ b/Makefile @@ -636,7 +636,8 @@ GEN_FILES := \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ - coderd/httpmw/loggermw/loggermock/loggermock.go + coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go # all gen targets should be added here and to gen/mark-fresh gen: gen/db gen/golden-files $(GEN_FILES) @@ -686,6 +687,7 @@ gen/mark-fresh: agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go \ " for file in $$files; do @@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g go generate ./coderd/httpmw/loggermw/loggermock/ touch "$@" +codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go + go generate ./codersdk/workspacesdk/agentconnmock/ + touch "$@" + agent/agentcontainers/dcspec/dcspec_gen.go: \ node_modules/.installed \ agent/agentcontainers/dcspec/devContainer.base.schema.json \ diff --git a/agent/agent_test.go b/agent/agent_test.go index 52d8cfc09d573..2425fd81a0ead 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) { switch l.Addr().Network() { case "tcp": - conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp) case "udp": - conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp) default: t.Fatalf("unknown network: %s", l.Addr().Network()) } @@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()}, DERPMap: derpMap, @@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) { // Connect from a second client and make sure it uses the new DERP map. conn2 := newClientConn(newDerpMap, "client2") - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. - conn1.SetDERPMap(newDerpMap) - require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + conn1.TailnetConn().SetDERPMap(newDerpMap) + require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs()) t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort( } func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( - *workspacesdk.AgentConn, + workspacesdk.AgentConn, *agenttest.Client, <-chan *proto.Stats, afero.Fs, diff --git a/cli/ping.go b/cli/ping.go index 0b9fde5c62eb8..29af06affeaee 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command { } defer conn.Close() - derpMap := conn.DERPMap() + derpMap := conn.TailnetConn().DERPMap() diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second) defer diagCancel() @@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command { // Silent ping to determine whether we should show diags _, didP2p, _, _ := conn.Ping(ctx) - ni := conn.GetNetInfo() + ni := conn.TailnetConn().GetNetInfo() connDiags := cliui.ConnDiags{ DisableDirect: r.disableDirect, LocalNetInfo: ni, diff --git a/cli/portforward.go b/cli/portforward.go index 59c1f5827b06f..1b055d9e4362e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command { func listenAndPortForward( ctx context.Context, inv *serpent.Invocation, - conn *workspacesdk.AgentConn, + conn workspacesdk.AgentConn, wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, diff --git a/cli/speedtest.go b/cli/speedtest.go index 3827b45125842..86d0e6a9ee63c 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command { if err != nil { continue } - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { continue } @@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command { outputResult.Intervals[i] = interval } } - conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) + conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) out, err := formatter.Format(inv.Context(), outputResult) if err != nil { return err diff --git a/cli/ssh.go b/cli/ssh.go index bc2bb24235ad2..a2f0db7327bef 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command { } err = sshSession.Wait() - conn.SendDisconnectedTelemetry() + conn.TailnetConn().SendDisconnectedTelemetry() if err != nil { if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { // Clear the error since it's not useful beyond @@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName { func setStatsCallback( ctx context.Context, - agentConn *workspacesdk.AgentConn, + agentConn workspacesdk.AgentConn, logger slog.Logger, networkInfoDir string, networkInfoInterval time.Duration, @@ -1437,7 +1437,7 @@ func setStatsCallback( now := time.Now() cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) - agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) + agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb) return errCh, nil } @@ -1451,13 +1451,13 @@ type sshNetworkStats struct { UsingCoderConnect bool `json:"using_coder_connect"` } -func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { +func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { latency, p2p, pingResult, err := agentConn.Ping(ctx) if err != nil { return nil, err } - node := agentConn.Node() - derpMap := agentConn.DERPMap() + node := agentConn.TailnetConn().Node() + derpMap := agentConn.TailnetConn().DERPMap() totalRx := uint64(0) totalTx := uint64(0) diff --git a/coderd/coderd.go b/coderd/coderd.go index a934536c0aef0..8ab204f8a31ef 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -325,6 +325,9 @@ func New(options *Options) *API { }) } + if options.PrometheusRegistry == nil { + options.PrometheusRegistry = prometheus.NewRegistry() + } if options.Authorizer == nil { options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) if buildinfo.IsDev() { @@ -381,9 +384,6 @@ func New(options *Options) *API { if options.FilesRateLimit == 0 { options.FilesRateLimit = 12 } - if options.PrometheusRegistry == nil { - options.PrometheusRegistry = prometheus.NewRegistry() - } if options.Clock == nil { options.Clock = quartz.NewReal() } diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 172edea95a586..cdcf657fe732d 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) ( }, nil } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { var ( - conn *workspacesdk.AgentConn + conn workspacesdk.AgentConn ret func() ) diff --git a/coderd/workspaceagents_internal_test.go b/coderd/workspaceagents_internal_test.go new file mode 100644 index 0000000000000..c7520f05ab503 --- /dev/null +++ b/coderd/workspaceagents_internal_test.go @@ -0,0 +1,186 @@ +package coderd + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +type fakeAgentProvider struct { + agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) +} + +func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy { + panic("unimplemented") +} + +func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + if f.agentConn != nil { + return f.agentConn(ctx, agentID) + } + + panic("unimplemented") +} + +func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + panic("unimplemented") +} + +func (fakeAgentProvider) Close() error { + return nil +} + +func TestWatchAgentContainers(t *testing.T) { + t.Parallel() + + t.Run("WebSocketClosesProperly", func(t *testing.T) { + t.Parallel() + + // This test ensures that the agent containers `/watch` websocket can gracefully + // handle the underlying websocket unexpectedly closing. This test was created in + // response to this issue: https://github.com/coder/coder/issues/19372 + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd") + + mCtrl = gomock.NewController(t) + mDB = dbmock.NewMockStore(mCtrl) + mCoordinator = tailnettest.NewMockCoordinator(mCtrl) + mAgentConn = agentconnmock.NewMockAgentConn(mCtrl) + + fAgentProvider = fakeAgentProvider{ + agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + return mAgentConn, func() {}, nil + }, + } + + workspaceID = uuid.New() + agentID = uuid.New() + resourceID = uuid.New() + jobID = uuid.New() + buildID = uuid.New() + + containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse) + + r = chi.NewMux() + + api = API{ + ctx: ctx, + Options: &Options{ + AgentInactiveDisconnectTimeout: testutil.WaitShort, + Database: mDB, + Logger: logger, + DeploymentValues: &codersdk.DeploymentValues{}, + TailnetCoordinator: tailnettest.NewFakeCoordinator(), + }, + } + ) + + var tailnetCoordinator tailnet.Coordinator = mCoordinator + api.TailnetCoordinator.Store(&tailnetCoordinator) + api.agentProvider = fAgentProvider + + // Setup: Allow `ExtractWorkspaceAgentParams` to complete. + mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{ + ID: agentID, + ResourceID: resourceID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + }, nil) + mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{ + ID: resourceID, + JobID: jobID, + }, nil) + mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{ + ID: jobID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }, nil) + mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + ID: buildID, + }, nil) + + // And: Allow `db2dsk.WorkspaceAgent` to complete. + mCoordinator.EXPECT().Node(gomock.Any()).Return(nil) + + // And: Allow `WatchContainers` to be called, returing our `containersCh` channel. + mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()). + Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil) + + // And: We mount the HTTP Handler + r.With(httpmw.ExtractWorkspaceAgentParam(mDB)). + Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers) + + // Given: We create the HTTP server + srv := httptest.NewServer(r) + defer srv.Close() + + // And: Dial the WebSocket + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil) + require.NoError(t, err) + if resp.Body != nil { + defer resp.Body.Close() + } + + // And: Create a streaming decoder + decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger) + defer decoder.Close() + decodeCh := decoder.Chan() + + // And: We can successfully send through the channel. + testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{{ + ID: "test-container-id", + }}, + }) + + // And: Receive the data. + containerResp := testutil.RequireReceive(ctx, t, decodeCh) + require.Len(t, containerResp.Containers, 1) + require.Equal(t, "test-container-id", containerResp.Containers[0].ID) + + // When: We close the `containersCh` + close(containersCh) + + // Then: We expect `decodeCh` to be closed. + select { + case <-ctx.Done(): + t.Fail() + + case _, ok := <-decodeCh: + require.False(t, ok, "channel is expected to be closed") + } + }) +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 1855ed8a7e8fc..ac58df1b772ad 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -593,7 +593,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) { _ = agenttest.New(t, client.URL, r.AgentToken) resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) - conn, err := func() (*workspacesdk.AgentConn, error) { + conn, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -1574,82 +1574,6 @@ func TestWatchWorkspaceAgentDevcontainers(t *testing.T) { } } }) - - t.Run("PayloadTooLarge", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitSuperLong) - logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - mClock = quartz.NewMock(t) - updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") - mCtrl = gomock.NewController(t) - mCCLI = acmock.NewMockContainerCLI(mCtrl) - - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger}) - user = coderdtest.CreateFirstUser(t, client) - r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { - return agents - }).Do() - ) - - // WebSocket limit is 4MiB, so we want to ensure we create _more_ than 4MiB worth of payload. - // Creating 20,000 fake containers creates a payload of roughly 7MiB. - var fakeContainers []codersdk.WorkspaceAgentContainer - for range 20_000 { - fakeContainers = append(fakeContainers, codersdk.WorkspaceAgentContainer{ - CreatedAt: time.Now(), - ID: uuid.NewString(), - FriendlyName: uuid.NewString(), - Image: "busybox:latest", - Labels: map[string]string{ - agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project", - agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json", - }, - Running: false, - Ports: []codersdk.WorkspaceAgentContainerPort{}, - Status: string(codersdk.WorkspaceAgentDevcontainerStatusRunning), - Volumes: map[string]string{}, - }) - } - - mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: fakeContainers}, nil) - mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - - _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.Logger = logger.Named("agent") - o.Devcontainers = true - o.DevcontainerAPIOptions = []agentcontainers.Option{ - agentcontainers.WithClock(mClock), - agentcontainers.WithContainerCLI(mCCLI), - agentcontainers.WithWatcher(watcher.NewNoop()), - } - }) - - resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() - require.Len(t, resources, 1, "expected one resource") - require.Len(t, resources[0].Agents, 1, "expected one agent") - agentID := resources[0].Agents[0].ID - - updaterTickerTrap.MustWait(ctx).MustRelease(ctx) - defer updaterTickerTrap.Close() - - containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID) - require.NoError(t, err) - defer func() { - closer.Close() - }() - - select { - case <-ctx.Done(): - t.Fail() - case _, ok := <-containers: - require.False(t, ok) - } - }) } func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { @@ -2497,7 +2421,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { agentID := resources[0].Agents[0].ID // Connect from a client. - conn1, err := func() (*workspacesdk.AgentConn, error) { + conn1, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -2538,7 +2462,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { // Wait for the DERP map to be updated on the existing client. require.Eventually(t, func() bool { - regionIDs := conn1.Conn.DERPMap().RegionIDs() + regionIDs := conn1.TailnetConn().DERPMap().RegionIDs() return len(regionIDs) == 1 && regionIDs[0] == 2 }, testutil.WaitLong, testutil.IntervalFast) @@ -2555,7 +2479,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { defer conn2.Close() ok = conn2.AwaitReachable(ctx) require.True(t, ok) - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) } func TestWorkspaceAgentExternalAuthListen(t *testing.T) { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 2f1294558f67a..002bb1ea05aae 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -74,7 +74,7 @@ type AgentProvider interface { ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy // AgentConn returns a new connection to the specified agent. - AgentConn(ctx context.Context, agentID uuid.UUID) (_ *workspacesdk.AgentConn, release func(), _ error) + AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 9c65b7ee9a1e1..bb929c9ba2a04 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -34,8 +34,8 @@ import ( // to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the // conn is shared and closing it is undesirable, you may return ErrNoClose from // opts.CloseFunc. This will ensure the underlying conn is not closed. -func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { - return &AgentConn{ +func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn { + return &agentConn{ Conn: conn, opts: opts, } @@ -43,23 +43,54 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { // AgentConn represents a connection to a workspace agent. // @typescript-ignore AgentConn -type AgentConn struct { +type AgentConn interface { + TailnetConn() *tailnet.Conn + + AwaitReachable(ctx context.Context) bool + Close() error + DebugLogs(ctx context.Context) ([]byte, error) + DebugMagicsock(ctx context.Context) ([]byte, error) + DebugManifest(ctx context.Context) ([]byte, error) + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) + GetPeerDiagnostics() tailnet.PeerDiagnostics + ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) + ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) + Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) + Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) + PrometheusMetrics(ctx context.Context) ([]byte, error) + ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) + RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + SSH(ctx context.Context) (*gonet.TCPConn, error) + SSHClient(ctx context.Context) (*ssh.Client, error) + SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) + SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) + Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) + WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) +} + +// AgentConn represents a connection to a workspace agent. +// @typescript-ignore AgentConn +type agentConn struct { *tailnet.Conn opts AgentConnOptions } +func (c *agentConn) TailnetConn() *tailnet.Conn { + return c.Conn +} + // @typescript-ignore AgentConnOptions type AgentConnOptions struct { AgentID uuid.UUID CloseFunc func() error } -func (c *AgentConn) agentAddress() netip.Addr { +func (c *agentConn) agentAddress() netip.Addr { return tailnet.TailscaleServicePrefix.AddrFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable. -func (c *AgentConn) AwaitReachable(ctx context.Context) bool { +func (c *agentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -68,7 +99,7 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool { // Ping pings the agent and returns the round-trip time. // The bool returns true if the ping was made P2P. -func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { +func (c *agentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -76,7 +107,7 @@ func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.Pi } // Close ends the connection to the workspace agent. -func (c *AgentConn) Close() error { +func (c *agentConn) Close() error { var cerr error if c.opts.CloseFunc != nil { cerr = c.opts.CloseFunc() @@ -131,7 +162,7 @@ type ReconnectingPTYRequest struct { // ReconnectingPTY spawns a new reconnecting terminal session. // `ReconnectingPTYRequest` should be JSON marshaled and written to the returned net.Conn. // Raw terminal output will be read from the returned net.Conn. -func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { +func (c *agentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -171,13 +202,13 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w // SSH pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent. -func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { +func (c *agentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { return c.SSHOnPort(ctx, AgentSSHPort) } // SSHOnPort pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent on the specified port. -func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { +func (c *agentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -190,12 +221,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, } // SSHClient calls SSH to create a client -func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { +func (c *agentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { +func (c *agentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -218,7 +249,7 @@ func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Clie } // Speedtest runs a speedtest against the workspace agent. -func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { +func (c *agentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -242,7 +273,7 @@ func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction // DialContext dials the address provided in the workspace agent. // The network must be "tcp" or "udp". -func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *agentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -265,7 +296,7 @@ func (c *AgentConn) DialContext(ctx context.Context, network string, addr string } // ListeningPorts lists the ports that are currently in use by the workspace. -func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { +func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil) @@ -282,7 +313,7 @@ func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent } // Netcheck returns a network check report from the workspace agent. -func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { +func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/netcheck", nil) @@ -299,7 +330,7 @@ func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport } // DebugMagicsock makes a request to the workspace agent's magicsock debug endpoint. -func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/magicsock", nil) @@ -319,7 +350,7 @@ func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { // DebugManifest returns the agent's in-memory manifest. Unfortunately this must // be returns as a []byte to avoid an import cycle. -func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugManifest(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/manifest", nil) @@ -338,7 +369,7 @@ func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { } // DebugLogs returns up to the last 10MB of `/tmp/coder-agent.log` -func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugLogs(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/logs", nil) @@ -357,7 +388,7 @@ func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { } // PrometheusMetrics returns a response from the agent's prometheus metrics endpoint -func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { +func (c *agentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/prometheus", nil) @@ -376,7 +407,7 @@ func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { } // ListContainers returns a response from the agent's containers endpoint -func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { +func (c *agentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/containers", nil) @@ -391,7 +422,7 @@ func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } -func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { +func (c *agentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -427,7 +458,7 @@ func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<- // RecreateDevcontainer recreates a devcontainer with the given container. // This is a blocking call and will wait for the container to be recreated. -func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { +func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/containers/devcontainers/"+devcontainerID+"/recreate", nil) @@ -446,7 +477,7 @@ func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str } // apiRequest makes a request to the workspace agent's HTTP API server. -func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -463,7 +494,7 @@ func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io // apiClient returns an HTTP client that can be used to make // requests to the workspace agent's HTTP API server. -func (c *AgentConn) apiClient() *http.Client { +func (c *agentConn) apiClient() *http.Client { return &http.Client{ Transport: &http.Transport{ // Disable keep alives as we're usually only making a single @@ -504,6 +535,6 @@ func (c *AgentConn) apiClient() *http.Client { } } -func (c *AgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { +func (c *agentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { return c.Conn.GetPeerDiagnostics(c.opts.AgentID) } diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go new file mode 100644 index 0000000000000..eb55bb27938c0 --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -0,0 +1,373 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: .. (interfaces: AgentConn) +// +// Generated by this command: +// +// mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn +// + +// Package agentconnmock is a generated GoMock package. +package agentconnmock + +import ( + context "context" + io "io" + net "net" + reflect "reflect" + time "time" + + slog "cdr.dev/slog" + codersdk "github.com/coder/coder/v2/codersdk" + healthsdk "github.com/coder/coder/v2/codersdk/healthsdk" + workspacesdk "github.com/coder/coder/v2/codersdk/workspacesdk" + tailnet "github.com/coder/coder/v2/tailnet" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" + ssh "golang.org/x/crypto/ssh" + gonet "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + ipnstate "tailscale.com/ipn/ipnstate" + speedtest "tailscale.com/net/speedtest" +) + +// MockAgentConn is a mock of AgentConn interface. +type MockAgentConn struct { + ctrl *gomock.Controller + recorder *MockAgentConnMockRecorder + isgomock struct{} +} + +// MockAgentConnMockRecorder is the mock recorder for MockAgentConn. +type MockAgentConnMockRecorder struct { + mock *MockAgentConn +} + +// NewMockAgentConn creates a new mock instance. +func NewMockAgentConn(ctrl *gomock.Controller) *MockAgentConn { + mock := &MockAgentConn{ctrl: ctrl} + mock.recorder = &MockAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAgentConn) EXPECT() *MockAgentConnMockRecorder { + return m.recorder +} + +// AwaitReachable mocks base method. +func (m *MockAgentConn) AwaitReachable(ctx context.Context) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AwaitReachable", ctx) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AwaitReachable indicates an expected call of AwaitReachable. +func (mr *MockAgentConnMockRecorder) AwaitReachable(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitReachable", reflect.TypeOf((*MockAgentConn)(nil).AwaitReachable), ctx) +} + +// Close mocks base method. +func (m *MockAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close)) +} + +// DebugLogs mocks base method. +func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugLogs", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugLogs indicates an expected call of DebugLogs. +func (mr *MockAgentConnMockRecorder) DebugLogs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugLogs", reflect.TypeOf((*MockAgentConn)(nil).DebugLogs), ctx) +} + +// DebugMagicsock mocks base method. +func (m *MockAgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugMagicsock", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugMagicsock indicates an expected call of DebugMagicsock. +func (mr *MockAgentConnMockRecorder) DebugMagicsock(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugMagicsock", reflect.TypeOf((*MockAgentConn)(nil).DebugMagicsock), ctx) +} + +// DebugManifest mocks base method. +func (m *MockAgentConn) DebugManifest(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugManifest", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugManifest indicates an expected call of DebugManifest. +func (mr *MockAgentConnMockRecorder) DebugManifest(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugManifest", reflect.TypeOf((*MockAgentConn)(nil).DebugManifest), ctx) +} + +// DialContext mocks base method. +func (m *MockAgentConn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialContext", ctx, network, addr) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialContext indicates an expected call of DialContext. +func (mr *MockAgentConnMockRecorder) DialContext(ctx, network, addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockAgentConn)(nil).DialContext), ctx, network, addr) +} + +// GetPeerDiagnostics mocks base method. +func (m *MockAgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerDiagnostics") + ret0, _ := ret[0].(tailnet.PeerDiagnostics) + return ret0 +} + +// GetPeerDiagnostics indicates an expected call of GetPeerDiagnostics. +func (mr *MockAgentConnMockRecorder) GetPeerDiagnostics() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerDiagnostics", reflect.TypeOf((*MockAgentConn)(nil).GetPeerDiagnostics)) +} + +// ListContainers mocks base method. +func (m *MockAgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListContainers", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListContainers indicates an expected call of ListContainers. +func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx) +} + +// ListeningPorts mocks base method. +func (m *MockAgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListeningPorts", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListeningPortsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListeningPorts indicates an expected call of ListeningPorts. +func (mr *MockAgentConnMockRecorder) ListeningPorts(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListeningPorts", reflect.TypeOf((*MockAgentConn)(nil).ListeningPorts), ctx) +} + +// Netcheck mocks base method. +func (m *MockAgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Netcheck", ctx) + ret0, _ := ret[0].(healthsdk.AgentNetcheckReport) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Netcheck indicates an expected call of Netcheck. +func (mr *MockAgentConnMockRecorder) Netcheck(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Netcheck", reflect.TypeOf((*MockAgentConn)(nil).Netcheck), ctx) +} + +// Ping mocks base method. +func (m *MockAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(*ipnstate.PingResult) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Ping indicates an expected call of Ping. +func (mr *MockAgentConnMockRecorder) Ping(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockAgentConn)(nil).Ping), ctx) +} + +// PrometheusMetrics mocks base method. +func (m *MockAgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrometheusMetrics", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PrometheusMetrics indicates an expected call of PrometheusMetrics. +func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) +} + +// ReconnectingPTY mocks base method. +func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, id, height, width, command} + for _, a := range initOpts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReconnectingPTY", varargs...) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReconnectingPTY indicates an expected call of ReconnectingPTY. +func (mr *MockAgentConnMockRecorder) ReconnectingPTY(ctx, id, height, width, command any, initOpts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, id, height, width, command}, initOpts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconnectingPTY", reflect.TypeOf((*MockAgentConn)(nil).ReconnectingPTY), varargs...) +} + +// RecreateDevcontainer mocks base method. +func (m *MockAgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecreateDevcontainer", ctx, devcontainerID) + ret0, _ := ret[0].(codersdk.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RecreateDevcontainer indicates an expected call of RecreateDevcontainer. +func (mr *MockAgentConnMockRecorder) RecreateDevcontainer(ctx, devcontainerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecreateDevcontainer", reflect.TypeOf((*MockAgentConn)(nil).RecreateDevcontainer), ctx, devcontainerID) +} + +// SSH mocks base method. +func (m *MockAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSH", ctx) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSH indicates an expected call of SSH. +func (mr *MockAgentConnMockRecorder) SSH(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSH", reflect.TypeOf((*MockAgentConn)(nil).SSH), ctx) +} + +// SSHClient mocks base method. +func (m *MockAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClient", ctx) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClient indicates an expected call of SSHClient. +func (mr *MockAgentConnMockRecorder) SSHClient(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClient", reflect.TypeOf((*MockAgentConn)(nil).SSHClient), ctx) +} + +// SSHClientOnPort mocks base method. +func (m *MockAgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClientOnPort", ctx, port) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClientOnPort indicates an expected call of SSHClientOnPort. +func (mr *MockAgentConnMockRecorder) SSHClientOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClientOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHClientOnPort), ctx, port) +} + +// SSHOnPort mocks base method. +func (m *MockAgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHOnPort", ctx, port) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHOnPort indicates an expected call of SSHOnPort. +func (mr *MockAgentConnMockRecorder) SSHOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHOnPort), ctx, port) +} + +// Speedtest mocks base method. +func (m *MockAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Speedtest", ctx, direction, duration) + ret0, _ := ret[0].([]speedtest.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Speedtest indicates an expected call of Speedtest. +func (mr *MockAgentConnMockRecorder) Speedtest(ctx, direction, duration any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Speedtest", reflect.TypeOf((*MockAgentConn)(nil).Speedtest), ctx, direction, duration) +} + +// TailnetConn mocks base method. +func (m *MockAgentConn) TailnetConn() *tailnet.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TailnetConn") + ret0, _ := ret[0].(*tailnet.Conn) + return ret0 +} + +// TailnetConn indicates an expected call of TailnetConn. +func (mr *MockAgentConnMockRecorder) TailnetConn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TailnetConn", reflect.TypeOf((*MockAgentConn)(nil).TailnetConn)) +} + +// WatchContainers mocks base method. +func (m *MockAgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchContainers", ctx, logger) + ret0, _ := ret[0].(<-chan codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(io.Closer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// WatchContainers indicates an expected call of WatchContainers. +func (mr *MockAgentConnMockRecorder) WatchContainers(ctx, logger any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchContainers", reflect.TypeOf((*MockAgentConn)(nil).WatchContainers), ctx, logger) +} diff --git a/codersdk/workspacesdk/agentconnmock/doc.go b/codersdk/workspacesdk/agentconnmock/doc.go new file mode 100644 index 0000000000000..a795b21a4a89d --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/doc.go @@ -0,0 +1,4 @@ +// Package agentconnmock contains a mock implementation of workspacesdk.AgentConn for use in tests. +package agentconnmock + +//go:generate mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 9f587cf5267a8..ddaec06388238 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -202,7 +202,7 @@ func (c *Client) RewriteDERPMap(derpMap *tailcfg.DERPMap) { tailnet.RewriteDERPMapDefaultRelay(context.Background(), c.client.Logger(), derpMap, c.client.URL) } -func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn *AgentConn, err error) { +func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn AgentConn, err error) { if options == nil { options = &DialAgentOptions{} } diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index b0051551a0f3d..72f5a4291c40e 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -75,7 +75,7 @@ func (c *Client) RequestIgnoreRedirects(ctx context.Context, method, path string // DialWorkspaceAgent calls the underlying codersdk.Client's DialWorkspaceAgent // method. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn *workspacesdk.AgentConn, err error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn workspacesdk.AgentConn, err error) { return workspacesdk.New(c.SDKClient).DialAgent(ctx, agentID, options) } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index dba21cc24e3a0..b0990d9cb11a6 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -89,7 +89,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { // Ensure DERP for completeness. if r.cfg.ConnectionMode == ConnectionModeDerp { - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) } @@ -133,7 +133,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { return nil } -func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDisco(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const pingAttempts = 10 const pingDelay = 1 * time.Second @@ -165,7 +165,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentC return nil } -func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDirectConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const directConnectionAttempts = 30 const directConnectionDelay = 1 * time.Second @@ -174,7 +174,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac for i := 0; i < directConnectionAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) - status := conn.Status() + status := conn.TailnetConn().Status() var err error if len(status.Peers()) != 1 { @@ -207,7 +207,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac return nil } -func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func verifyConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const verifyConnectionAttempts = 30 const verifyConnectionDelay = 1 * time.Second @@ -249,7 +249,7 @@ func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Ag return nil } -func performInitialConnections(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, specs []Connection) error { +func performInitialConnections(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, specs []Connection) error { if len(specs) == 0 { return nil } @@ -287,7 +287,7 @@ func performInitialConnections(ctx context.Context, logs io.Writer, conn *worksp return nil } -func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { +func holdConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -364,7 +364,7 @@ func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Agen return nil } -func agentHTTPClient(conn *workspacesdk.AgentConn) *http.Client { +func agentHTTPClient(conn workspacesdk.AgentConn) *http.Client { return &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, diff --git a/support/support.go b/support/support.go index 2fa41ce7eca8c..31080faaf023b 100644 --- a/support/support.go +++ b/support/support.go @@ -390,7 +390,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L if err := conn.Close(); err != nil { log.Error(ctx, "failed to close agent connection", slog.Error(err)) } - <-conn.Closed() + <-conn.TailnetConn().Closed() } eg.Go(func() error { @@ -399,7 +399,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L return xerrors.Errorf("create request: %w", err) } rr := httptest.NewRecorder() - conn.MagicsockServeHTTPDebug(rr, req) + conn.TailnetConn().MagicsockServeHTTPDebug(rr, req) a.ClientMagicsockHTML = rr.Body.Bytes() return nil }) From f9a6adc70452e2ee4028e0025fd358ef2c9dbb5f Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 20 Aug 2025 11:02:53 +0200 Subject: [PATCH 02/11] feat: claim prebuilds based on workspace parameters instead of preset id (#19279) Closes https://github.com/coder/coder/issues/18356. This change finds and selects a matching preset if one was not chosen during workspace creation. This solidifies the relationship between presets and parameters. When a workspace is created without in explicitly chosen preset, it will now still be eligible to claim a prebuilt workspace if one is available. --- coderd/database/dbauthz/dbauthz.go | 8 + coderd/database/dbauthz/dbauthz_test.go | 16 + coderd/database/dbmetrics/querymetrics.go | 7 + coderd/database/dbmock/dbmock.go | 15 + coderd/database/querier.go | 5 + coderd/database/queries.sql.go | 41 +++ coderd/database/queries/prebuilds.sql | 27 ++ coderd/prebuilds/parameters.go | 42 +++ coderd/prebuilds/parameters_test.go | 198 ++++++++++++ coderd/workspacebuilds_test.go | 38 ++- coderd/workspaces.go | 40 ++- coderd/workspaces_test.go | 282 ++++++++++++++++++ coderd/wsbuilder/wsbuilder.go | 21 +- coderd/wsbuilder/wsbuilder_test.go | 22 ++ .../prebuilt-workspaces.md | 11 +- 15 files changed, 736 insertions(+), 37 deletions(-) create mode 100644 coderd/prebuilds/parameters.go create mode 100644 coderd/prebuilds/parameters_test.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 7ae1e4bbf9b73..a716c04adc030 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1837,6 +1837,14 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt) } +func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + _, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID) + if err != nil { + return uuid.Nil, err + } + return q.db.FindMatchingPresetID(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index e4639c3ae0adf..ce70a9b1f112a 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -4965,6 +4965,22 @@ func (s *MethodTestSuite) TestPrebuilds() { template, policy.ActionUse, ).Errors(sql.ErrNoRows) })) + s.Run("FindMatchingPresetID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + t1 := testutil.Fake(s.T(), faker, database.Template{}) + tv := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}}) + dbm.EXPECT().FindMatchingPresetID(gomock.Any(), database.FindMatchingPresetIDParams{ + TemplateVersionID: tv.ID, + ParameterNames: []string{"test"}, + ParameterValues: []string{"test"}, + }).Return(uuid.Nil, nil).AnyTimes() + dbm.EXPECT().GetTemplateVersionByID(gomock.Any(), tv.ID).Return(tv, nil).AnyTimes() + dbm.EXPECT().GetTemplateByID(gomock.Any(), t1.ID).Return(t1, nil).AnyTimes() + check.Args(database.FindMatchingPresetIDParams{ + TemplateVersionID: tv.ID, + ParameterNames: []string{"test"}, + ParameterValues: []string{"test"}, + }).Asserts(tv.RBACObject(t1), policy.ActionRead).Returns(uuid.Nil) + })) s.Run("GetPrebuildMetrics", s.Subtest(func(_ database.Store, check *expects) { check.Args(). Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 12133997bf2c9..11d21eab3b593 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -565,6 +565,13 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context. return r0, r1 } +func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.FindMatchingPresetID(ctx, arg) + m.queryLatencies.WithLabelValues("FindMatchingPresetID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 96e277cd7af58..67244cf2b01e9 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1051,6 +1051,21 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt) } +// FindMatchingPresetID mocks base method. +func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindMatchingPresetID", ctx, arg) + ret0, _ := ret[0].(uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindMatchingPresetID indicates an expected call of FindMatchingPresetID. +func (mr *MockStoreMockRecorder) FindMatchingPresetID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMatchingPresetID", reflect.TypeOf((*MockStore)(nil).FindMatchingPresetID), ctx, arg) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 8ac974ff20ee8..c490a04d2b653 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -137,6 +137,11 @@ type sqlcQuerier interface { FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error) FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error) + // FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. + // It returns the preset ID if a match is found, or NULL if no match is found. + // The query finds presets where all preset parameters are present in the provided parameters, + // and returns the preset with the most parameters (largest subset). + FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1b63e7c1e960f..70558724a664d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -7252,6 +7252,47 @@ func (q *sqlQuerier) CountInProgressPrebuilds(ctx context.Context) ([]CountInPro return items, nil } +const findMatchingPresetID = `-- name: FindMatchingPresetID :one +WITH provided_params AS ( + SELECT + unnest($1::text[]) AS name, + unnest($2::text[]) AS value +), +preset_matches AS ( + SELECT + tvp.id AS template_version_preset_id, + COALESCE(COUNT(tvpp.name), 0) AS total_preset_params, + COALESCE(COUNT(pp.name), 0) AS matching_params + FROM template_version_presets tvp + LEFT JOIN template_version_preset_parameters tvpp ON tvpp.template_version_preset_id = tvp.id + LEFT JOIN provided_params pp ON pp.name = tvpp.name AND pp.value = tvpp.value + WHERE tvp.template_version_id = $3 + GROUP BY tvp.id +) +SELECT pm.template_version_preset_id +FROM preset_matches pm +WHERE pm.total_preset_params = pm.matching_params -- All preset parameters must match +ORDER BY pm.total_preset_params DESC -- Return the preset with the most parameters +LIMIT 1 +` + +type FindMatchingPresetIDParams struct { + ParameterNames []string `db:"parameter_names" json:"parameter_names"` + ParameterValues []string `db:"parameter_values" json:"parameter_values"` + TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"` +} + +// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. +// It returns the preset ID if a match is found, or NULL if no match is found. +// The query finds presets where all preset parameters are present in the provided parameters, +// and returns the preset with the most parameters (largest subset). +func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, findMatchingPresetID, pq.Array(arg.ParameterNames), pq.Array(arg.ParameterValues), arg.TemplateVersionID) + var template_version_preset_id uuid.UUID + err := row.Scan(&template_version_preset_id) + return template_version_preset_id, err +} + const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many SELECT t.name as template_name, diff --git a/coderd/database/queries/prebuilds.sql b/coderd/database/queries/prebuilds.sql index 87a713974c563..8654453554e8c 100644 --- a/coderd/database/queries/prebuilds.sql +++ b/coderd/database/queries/prebuilds.sql @@ -245,3 +245,30 @@ INNER JOIN organizations o ON o.id = w.organization_id WHERE NOT t.deleted AND wpb.build_number = 1 GROUP BY t.name, tvp.name, o.name ORDER BY t.name, tvp.name, o.name; + +-- name: FindMatchingPresetID :one +-- FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. +-- It returns the preset ID if a match is found, or NULL if no match is found. +-- The query finds presets where all preset parameters are present in the provided parameters, +-- and returns the preset with the most parameters (largest subset). +WITH provided_params AS ( + SELECT + unnest(@parameter_names::text[]) AS name, + unnest(@parameter_values::text[]) AS value +), +preset_matches AS ( + SELECT + tvp.id AS template_version_preset_id, + COALESCE(COUNT(tvpp.name), 0) AS total_preset_params, + COALESCE(COUNT(pp.name), 0) AS matching_params + FROM template_version_presets tvp + LEFT JOIN template_version_preset_parameters tvpp ON tvpp.template_version_preset_id = tvp.id + LEFT JOIN provided_params pp ON pp.name = tvpp.name AND pp.value = tvpp.value + WHERE tvp.template_version_id = @template_version_id + GROUP BY tvp.id +) +SELECT pm.template_version_preset_id +FROM preset_matches pm +WHERE pm.total_preset_params = pm.matching_params -- All preset parameters must match +ORDER BY pm.total_preset_params DESC -- Return the preset with the most parameters +LIMIT 1; diff --git a/coderd/prebuilds/parameters.go b/coderd/prebuilds/parameters.go new file mode 100644 index 0000000000000..63a1a7b78bfa7 --- /dev/null +++ b/coderd/prebuilds/parameters.go @@ -0,0 +1,42 @@ +package prebuilds + +import ( + "context" + "database/sql" + "errors" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +// FindMatchingPresetID finds a preset ID that matches the provided parameters. +// It returns the preset ID if a match is found, or uuid.Nil if no match is found. +// The function performs a bidirectional comparison to ensure all parameters match exactly. +func FindMatchingPresetID( + ctx context.Context, + store database.Store, + templateVersionID uuid.UUID, + parameterNames []string, + parameterValues []string, +) (uuid.UUID, error) { + if len(parameterNames) != len(parameterValues) { + return uuid.Nil, xerrors.New("parameter names and values must have the same length") + } + + result, err := store.FindMatchingPresetID(ctx, database.FindMatchingPresetIDParams{ + TemplateVersionID: templateVersionID, + ParameterNames: parameterNames, + ParameterValues: parameterValues, + }) + if err != nil { + // Handle the case where no matching preset is found (no rows returned) + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, nil + } + return uuid.Nil, xerrors.Errorf("find matching preset ID: %w", err) + } + + return result, nil +} diff --git a/coderd/prebuilds/parameters_test.go b/coderd/prebuilds/parameters_test.go new file mode 100644 index 0000000000000..e9366bb1da02b --- /dev/null +++ b/coderd/prebuilds/parameters_test.go @@ -0,0 +1,198 @@ +package prebuilds_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/testutil" +) + +func TestFindMatchingPresetID(t *testing.T) { + t.Parallel() + + presetIDs := []uuid.UUID{ + uuid.New(), + uuid.New(), + } + // Give each preset a meaningful name in alphabetical order + presetNames := map[uuid.UUID]string{ + presetIDs[0]: "development", + presetIDs[1]: "production", + } + tests := []struct { + name string + parameterNames []string + parameterValues []string + presetParameters []database.TemplateVersionPresetParameter + expectedPresetID uuid.UUID + expectError bool + errorContains string + }{ + { + name: "exact match", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], + expectError: false, + }, + { + name: "no match - different values", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-east-1", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: uuid.Nil, + expectError: false, + }, + { + name: "no match - fewer provided parameters", + parameterNames: []string{"region"}, + parameterValues: []string{"us-west-2"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: uuid.Nil, + expectError: false, + }, + { + name: "subset match - extra provided parameter", + parameterNames: []string{"region", "instance_type", "extra_param"}, + parameterValues: []string{"us-west-2", "t3.medium", "extra_value"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + // antagonist: + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], // Should match because all preset parameters are present + expectError: false, + }, + { + name: "mismatched parameter names vs values", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2"}, + presetParameters: []database.TemplateVersionPresetParameter{}, + expectedPresetID: uuid.Nil, + expectError: true, + errorContains: "parameter names and values must have the same length", + }, + { + name: "multiple presets - match first", + parameterNames: []string{"region", "instance_type"}, + parameterValues: []string{"us-west-2", "t3.medium"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-east-1"}, + {TemplateVersionPresetID: presetIDs[1], Name: "instance_type", Value: "t3.large"}, + }, + expectedPresetID: presetIDs[0], + expectError: false, + }, + { + name: "largest subset match", + parameterNames: []string{"region", "instance_type", "storage_size"}, + parameterValues: []string{"us-west-2", "t3.medium", "100gb"}, + presetParameters: []database.TemplateVersionPresetParameter{ + {TemplateVersionPresetID: presetIDs[0], Name: "region", Value: "us-west-2"}, + {TemplateVersionPresetID: presetIDs[0], Name: "instance_type", Value: "t3.medium"}, + {TemplateVersionPresetID: presetIDs[1], Name: "region", Value: "us-west-2"}, + }, + expectedPresetID: presetIDs[0], // Should match the larger subset (2 params vs 1 param) + expectError: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + db, _ := dbtestutil.NewDB(t) + org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{}) + templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + JobID: uuid.New(), + }) + + // Group parameters by preset ID and create presets + presetMap := make(map[uuid.UUID][]database.TemplateVersionPresetParameter) + for _, param := range tt.presetParameters { + presetMap[param.TemplateVersionPresetID] = append(presetMap[param.TemplateVersionPresetID], param) + } + + // Create presets and insert their parameters + for presetID, params := range presetMap { + // Create the preset + _, err := db.InsertPreset(ctx, database.InsertPresetParams{ + ID: presetID, + TemplateVersionID: templateVersion.ID, + Name: presetNames[presetID], + CreatedAt: dbtestutil.NowInDefaultTimezone(), + }) + require.NoError(t, err) + + // Insert parameters for this preset + names := make([]string, len(params)) + values := make([]string, len(params)) + for i, param := range params { + names[i] = param.Name + values[i] = param.Value + } + + _, err = db.InsertPresetParameters(ctx, database.InsertPresetParametersParams{ + TemplateVersionPresetID: presetID, + Names: names, + Values: values, + }) + require.NoError(t, err) + } + + result, err := prebuilds.FindMatchingPresetID( + ctx, + db, + templateVersion.ID, + tt.parameterNames, + tt.parameterValues, + ) + + // Assert results + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedPresetID, result) + } + }) + } +} diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 29c9cac0ffa13..633acae328673 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -1638,6 +1638,8 @@ func TestPostWorkspaceBuild(t *testing.T) { t.Run("SetsPresetID", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) user := coderdtest.CreateFirstUser(t, client) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ @@ -1645,9 +1647,20 @@ func TestPostWorkspaceBuild(t *testing.T) { ProvisionPlan: []*proto.Response{{ Type: &proto.Response_Plan{ Plan: &proto.PlanComplete{ - Presets: []*proto.Preset{{ - Name: "test", - }}, + Presets: []*proto.Preset{ + { + Name: "autodetected", + }, + { + Name: "manual", + Parameters: []*proto.PresetParameter{ + { + Name: "param1", + Value: "value1", + }, + }, + }, + }, }, }, }}, @@ -1655,28 +1668,29 @@ func TestPostWorkspaceBuild(t *testing.T) { }) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - require.Nil(t, workspace.LatestBuild.TemplateVersionPresetID) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() presets, err := client.TemplateVersionPresets(ctx, version.ID) require.NoError(t, err) - require.Equal(t, 1, len(presets)) - require.Equal(t, "test", presets[0].Name) + require.Equal(t, 2, len(presets)) + require.Equal(t, "autodetected", presets[0].Name) + require.Equal(t, "manual", presets[1].Name) + + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + // Preset ID was detected based on the workspace parameters: + require.Equal(t, presets[0].ID, *workspace.LatestBuild.TemplateVersionPresetID) build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ TemplateVersionID: version.ID, Transition: codersdk.WorkspaceTransitionStart, - TemplateVersionPresetID: presets[0].ID, + TemplateVersionPresetID: presets[1].ID, }) require.NoError(t, err) require.NotNil(t, build.TemplateVersionPresetID) workspace, err = client.Workspace(ctx, workspace.ID) require.NoError(t, err) + require.Equal(t, presets[1].ID, *workspace.LatestBuild.TemplateVersionPresetID) require.Equal(t, build.TemplateVersionPresetID, workspace.LatestBuild.TemplateVersionPresetID) }) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 23bd8c5f6ed9e..b2b2610ff1349 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -638,14 +638,35 @@ func createWorkspace( // Use injected Clock to allow time mocking in tests now := api.Clock.Now() - // If a template preset was chosen, try claim a prebuilt workspace. - if req.TemplateVersionPresetID != uuid.Nil { + templateVersionPresetID := req.TemplateVersionPresetID + + // If no preset was chosen, look for a matching preset by parameter values. + if templateVersionPresetID == uuid.Nil { + parameterNames := make([]string, len(req.RichParameterValues)) + parameterValues := make([]string, len(req.RichParameterValues)) + for i, parameter := range req.RichParameterValues { + parameterNames[i] = parameter.Name + parameterValues[i] = parameter.Value + } + var err error + templateVersionID := req.TemplateVersionID + if templateVersionID == uuid.Nil { + templateVersionID = template.ActiveVersionID + } + templateVersionPresetID, err = prebuilds.FindMatchingPresetID(ctx, db, templateVersionID, parameterNames, parameterValues) + if err != nil { + return xerrors.Errorf("find matching preset: %w", err) + } + } + + // Try to claim a prebuilt workspace. + if templateVersionPresetID != uuid.Nil { // Try and claim an eligible prebuild, if available. // On successful claim, initialize all lifecycle fields from template and workspace-level config // so the newly claimed workspace is properly managed by the lifecycle executor. claimedWorkspace, err = claimPrebuild( - ctx, prebuildsClaimer, db, api.Logger, now, req, owner, - dbAutostartSchedule, nextStartAt, dbTTL) + ctx, prebuildsClaimer, db, api.Logger, now, req.Name, owner, + templateVersionPresetID, dbAutostartSchedule, nextStartAt, dbTTL) // If claiming fails with an expected error (no claimable prebuilds or AGPL does not support prebuilds), // we fall back to creating a new workspace. Otherwise, propagate the unexpected error. if err != nil { @@ -654,7 +675,7 @@ func createWorkspace( fields := []any{ slog.Error(err), slog.F("workspace_name", req.Name), - slog.F("template_version_preset_id", req.TemplateVersionPresetID), + slog.F("template_version_preset_id", templateVersionPresetID), } if !isExpectedError { @@ -718,8 +739,8 @@ func createWorkspace( if req.TemplateVersionID != uuid.Nil { builder = builder.VersionID(req.TemplateVersionID) } - if req.TemplateVersionPresetID != uuid.Nil { - builder = builder.TemplateVersionPresetID(req.TemplateVersionPresetID) + if templateVersionPresetID != uuid.Nil { + builder = builder.TemplateVersionPresetID(templateVersionPresetID) } if claimedWorkspace != nil { builder = builder.MarkPrebuiltWorkspaceClaim() @@ -884,13 +905,14 @@ func claimPrebuild( db database.Store, logger slog.Logger, now time.Time, - req codersdk.CreateWorkspaceRequest, + name string, owner workspaceOwner, + templateVersionPresetID uuid.UUID, autostartSchedule sql.NullString, nextStartAt sql.NullTime, ttl sql.NullInt64, ) (*database.Workspace, error) { - claimedID, err := claimer.Claim(ctx, now, owner.ID, req.Name, req.TemplateVersionPresetID, autostartSchedule, nextStartAt, ttl) + claimedID, err := claimer.Claim(ctx, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl) if err != nil { // TODO: enhance this by clarifying whether this *specific* prebuild failed or whether there are none to claim. return nil, xerrors.Errorf("claim prebuild: %w", err) diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index 443098036af62..8fc11ef6c8ccb 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -4915,3 +4915,285 @@ func TestUpdateWorkspaceACL(t *testing.T) { require.Equal(t, cerr.Validations[0].Field, "user_roles") }) } + +func TestWorkspaceCreateWithImplicitPreset(t *testing.T) { + t.Parallel() + + // Helper function to create template with presets + createTemplateWithPresets := func(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, presets []*proto.Preset) (codersdk.Template, codersdk.TemplateVersion) { + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Presets: presets, + }, + }, + }, + }, + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + return template, version + } + + // Helper function to create workspace and verify preset usage + createWorkspaceAndVerifyPreset := func(t *testing.T, client *codersdk.Client, template codersdk.Template, expectedPresetID *uuid.UUID, params []codersdk.WorkspaceBuildParameter) codersdk.Workspace { + wsName := testutil.GetRandomNameHyphenated(t) + var ws codersdk.Workspace + if len(params) > 0 { + ws = coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = wsName + cwr.RichParameterValues = params + }) + } else { + ws = coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = wsName + }) + } + require.Equal(t, wsName, ws.Name) + + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) + + // Verify the preset was used if expected + if expectedPresetID != nil { + require.NotNil(t, ws.LatestBuild.TemplateVersionPresetID) + require.Equal(t, *expectedPresetID, *ws.LatestBuild.TemplateVersionPresetID) + } else { + require.Nil(t, ws.LatestBuild.TemplateVersionPresetID) + } + + return ws + } + + t.Run("NoPresets", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with no presets + template, _ := createTemplateWithPresets(t, client, user, []*proto.Preset{}) + + // Test workspace creation with no parameters + createWorkspaceAndVerifyPreset(t, client, template, nil, nil) + + // Test workspace creation with parameters (should still work, no preset matching) + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + }) + + t.Run("SinglePresetNoParameters", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with single preset that has no parameters + preset := &proto.Preset{ + Name: "empty-preset", + Description: "A preset with no parameters", + Parameters: []*proto.PresetParameter{}, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset}) + + // Get the preset ID from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + presetID := presets[0].ID + + // Test workspace creation with no parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, nil) + + // Test workspace creation with parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + }) + + t.Run("SinglePresetWithParameters", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with single preset that has parameters + preset := &proto.Preset{ + Name: "param-preset", + Description: "A preset with parameters", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset}) + + // Get the preset ID from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + presetID := presets[0].ID + + // Test workspace creation with no parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, nil) + + // Test workspace creation with exact matching parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }) + + // Test workspace creation with partial matching parameters - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + + // Test workspace creation with different parameter values - should not match the preset + createWorkspaceAndVerifyPreset(t, client, template, nil, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "different"}, + }) + + // Test workspace creation with extra parameters - should match the preset + createWorkspaceAndVerifyPreset(t, client, template, &presetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + {Name: "param3", Value: "value3"}, + }) + }) + + t.Run("MultiplePresets", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with multiple presets + preset1 := &proto.Preset{ + Name: "empty-preset", + Description: "A preset with no parameters", + Parameters: []*proto.PresetParameter{}, + } + preset2 := &proto.Preset{ + Name: "single-param-preset", + Description: "A preset with one parameter", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + }, + } + preset3 := &proto.Preset{ + Name: "multi-param-preset", + Description: "A preset with multiple parameters", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset1, preset2, preset3}) + + // Get the preset IDs from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 3) + + // Sort presets by name to get consistent ordering + var emptyPresetID, singleParamPresetID, multiParamPresetID uuid.UUID + for _, p := range presets { + switch p.Name { + case "empty-preset": + emptyPresetID = p.ID + case "single-param-preset": + singleParamPresetID = p.ID + case "multi-param-preset": + multiParamPresetID = p.ID + } + } + + // Test workspace creation with no parameters - should match empty preset + createWorkspaceAndVerifyPreset(t, client, template, &emptyPresetID, nil) + + // Test workspace creation with single parameter - should match single param preset + createWorkspaceAndVerifyPreset(t, client, template, &singleParamPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + }) + + // Test workspace creation with multiple parameters - should match multi param preset + createWorkspaceAndVerifyPreset(t, client, template, &multiParamPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, + {Name: "param2", Value: "value2"}, + }) + + // Test workspace creation with non-matching parameters - should not match any preset + createWorkspaceAndVerifyPreset(t, client, template, &emptyPresetID, []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "different"}, + }) + }) + + t.Run("PresetSpecifiedExplicitly", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + // Create template with multiple presets + preset1 := &proto.Preset{ + Name: "preset1", + Description: "First preset", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value1"}, + }, + } + preset2 := &proto.Preset{ + Name: "preset2", + Description: "Second preset", + Parameters: []*proto.PresetParameter{ + {Name: "param1", Value: "value2"}, + }, + } + template, version := createTemplateWithPresets(t, client, user, []*proto.Preset{preset1, preset2}) + + // Get the preset IDs from the database + ctx := context.Background() + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 2) + + var preset1ID, preset2ID uuid.UUID + for _, p := range presets { + switch p.Name { + case "preset1": + preset1ID = p.ID + case "preset2": + preset2ID = p.ID + } + } + + // Test workspace creation with preset1 specified explicitly - should use preset1 regardless of parameters + ws := coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.TemplateVersionPresetID = preset1ID + cwr.RichParameterValues = []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value2"}, // This would normally match preset2 + } + }) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) + require.NotNil(t, ws.LatestBuild.TemplateVersionPresetID) + require.Equal(t, preset1ID, *ws.LatestBuild.TemplateVersionPresetID) + + // Test workspace creation with preset2 specified explicitly - should use preset2 regardless of parameters + ws2 := coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.TemplateVersionPresetID = preset2ID + cwr.RichParameterValues = []codersdk.WorkspaceBuildParameter{ + {Name: "param1", Value: "value1"}, // This would normally match preset1 + } + }) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws2.LatestBuild.ID) + require.NotNil(t, ws2.LatestBuild.TemplateVersionPresetID) + require.Equal(t, preset2ID, *ws2.LatestBuild.TemplateVersionPresetID) + }) +} diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 73e449ee5bb93..223b8bec084ad 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -15,6 +15,7 @@ import ( "github.com/coder/coder/v2/coderd/dynamicparameters" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/provisioner/terraform/tfparse" @@ -442,6 +443,20 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object var workspaceBuild database.WorkspaceBuild err = b.store.InTx(func(store database.Store) error { + names, values, err := b.getParameters() + if err != nil { + // getParameters already wraps errors in BuildError + return err + } + + if b.templateVersionPresetID == uuid.Nil { + presetID, err := prebuilds.FindMatchingPresetID(b.ctx, b.store, templateVersionID, names, values) + if err != nil { + return BuildError{http.StatusInternalServerError, "find matching preset", err} + } + b.templateVersionPresetID = presetID + } + err = store.InsertWorkspaceBuild(b.ctx, database.InsertWorkspaceBuildParams{ ID: workspaceBuildID, CreatedAt: now, @@ -473,12 +488,6 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object return BuildError{code, "insert workspace build", err} } - names, values, err := b.getParameters() - if err != nil { - // getParameters already wraps errors in BuildError - return err - } - err = store.InsertWorkspaceBuildParameters(b.ctx, database.InsertWorkspaceBuildParametersParams{ WorkspaceBuildID: workspaceBuildID, Name: names, diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index ee421a8adb649..b862e6459c285 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -82,6 +82,7 @@ func TestBuilder_NoOptions(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -132,6 +133,7 @@ func TestBuilder_Initiator(t *testing.T) { asrt.Equal(otherUserID, job.InitiatorID) }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(otherUserID, bld.InitiatorID) }), @@ -180,6 +182,7 @@ func TestBuilder_Baggage(t *testing.T) { asrt.Contains(string(job.TraceMetadata.RawMessage), "ip=127.0.0.1") }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { }), expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) { @@ -219,6 +222,7 @@ func TestBuilder_Reason(t *testing.T) { expectProvisionerJob(func(_ database.InsertProvisionerJobParams) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(database.BuildReasonAutostart, bld.Reason) }), @@ -261,6 +265,7 @@ func TestBuilder_ActiveVersion(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(activeVersionID, bld.TemplateVersionID) // no previous build... @@ -386,6 +391,7 @@ func TestWorkspaceBuildWithTags(t *testing.T) { expectBuildParameters(func(_ database.InsertWorkspaceBuildParametersParams) { }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -470,6 +476,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -519,6 +526,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -661,6 +669,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { } }), withBuild, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -713,6 +722,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), // Outputs + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), @@ -775,6 +785,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), // Outputs + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), @@ -906,6 +917,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -968,6 +980,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) { asrt.Equal(inactiveVersionID, bld.TemplateVersionID) asrt.Equal(workspaceID, bld.WorkspaceID) @@ -1041,6 +1054,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { // Outputs expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), withInTx, + expectFindMatchingPresetID(uuid.Nil, sql.ErrNoRows), expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), withBuild, expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}), @@ -1485,6 +1499,14 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner } } +func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockStore) { + return func(mTx *dbmock.MockStore) { + mTx.EXPECT().FindMatchingPresetID(gomock.Any(), gomock.Any()). + Times(1). + Return(id, err) + } +} + type fakeUsageChecker struct { checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) } diff --git a/docs/admin/templates/extending-templates/prebuilt-workspaces.md b/docs/admin/templates/extending-templates/prebuilt-workspaces.md index 8e61687ce0f01..70c2031d2a837 100644 --- a/docs/admin/templates/extending-templates/prebuilt-workspaces.md +++ b/docs/admin/templates/extending-templates/prebuilt-workspaces.md @@ -29,6 +29,7 @@ Prebuilt workspaces are tightly integrated with [workspace presets](./parameters 1. The preset must define all required parameters needed to build the workspace. 1. The preset parameters define the base configuration and are immutable once a prebuilt workspace is provisioned. 1. Parameters that are not defined in the preset can still be customized by users when they claim a workspace. +1. If a user does not select a preset but provides parameters that match one or more presets, Coder will automatically select the most specific matching preset and assign a prebuilt workspace if one is available. ## Prerequisites @@ -291,16 +292,6 @@ does not reconnect after a template update. This shortcoming is described in [th and will be addressed before the next release (v2.23). In the interim, a simple workaround is to restart the workspace when it is in this problematic state. -### Current limitations - -The prebuilt workspaces feature has these current limitations: - -- **Organizations** - - Prebuilt workspaces can only be used with the default organization. - - [View issue](https://github.com/coder/internal/issues/364) - ### Monitoring and observability #### Available metrics From cd1faffeff834d2c96653485a38d745e3bf5bb69 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 20 Aug 2025 11:14:41 +0200 Subject: [PATCH 03/11] docs: re-add missing Templates and Modules entries to manifest.json (#19442) --- docs/manifest.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/manifest.json b/docs/manifest.json index 66f4e6dbaf476..bd08ccfe372e6 100644 --- a/docs/manifest.json +++ b/docs/manifest.json @@ -47,6 +47,18 @@ "path": "./about/contributing/documentation.md", "icon_path": "./images/icons/document.svg" }, + { + "title": "Modules", + "description": "Learn how to contribute modules to Coder", + "path": "./about/contributing/modules.md", + "icon_path": "./images/icons/gear.svg" + }, + { + "title": "Templates", + "description": "Learn how to contribute templates to Coder", + "path": "./about/contributing/templates.md", + "icon_path": "./images/icons/picture.svg" + }, { "title": "Backend", "description": "Our guide for backend development", From 560cf84251bd124dc343bcca251816214c9fb507 Mon Sep 17 00:00:00 2001 From: Susana Ferreira Date: Wed, 20 Aug 2025 12:19:14 +0100 Subject: [PATCH 04/11] fix: prevent activity bump for prebuilt workspaces (#19263) ## Description This PR ensures that activity-based deadline extensions ("activity bumping") are not applied to prebuilt workspaces. Prebuilds are managed by the reconciliation loop and must not have `deadline` or `max_deadline` values set or extended, as they are not part of the regular lifecycle executor path. ## Changes - Update `ActivityBumpWorkspace` SQL query to discard prebuilt workspaces - Update application layer to avoid calling activity bump logic on prebuilt workspaces Related with: * Issue: https://github.com/coder/coder/issues/18898 * PR: https://github.com/coder/coder/pull/19252 --- coderd/database/queries.sql.go | 8 +- coderd/database/queries/activitybump.sql | 8 +- coderd/workspacestats/reporter.go | 51 ++++++----- enterprise/coderd/workspaces_test.go | 109 +++++++++++++++++++++++ 4 files changed, 148 insertions(+), 28 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 70558724a664d..d16bd34f25f82 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -32,7 +32,7 @@ WITH latest AS ( -- be as if the workspace auto started at the given time and the -- original TTL was applied. -- - -- Sadly we can't define ` + "`" + `activity_bump_interval` + "`" + ` above since + -- Sadly we can't define 'activity_bump_interval' above since -- it won't be available for this CASE statement, so we have to -- copy the cast twice. WHEN NOW() + (templates.activity_bump / 1000 / 1000 / 1000 || ' seconds')::interval > $1 :: timestamptz @@ -62,7 +62,11 @@ WITH latest AS ( ON workspaces.id = workspace_builds.workspace_id JOIN templates ON templates.id = workspaces.template_id - WHERE workspace_builds.workspace_id = $2::uuid + WHERE + workspace_builds.workspace_id = $2::uuid + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- are managed by the reconciliation loop and not subject to activity bumping + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID ORDER BY workspace_builds.build_number DESC LIMIT 1 ) diff --git a/coderd/database/queries/activitybump.sql b/coderd/database/queries/activitybump.sql index 09349d29e5d06..e367a93abf778 100644 --- a/coderd/database/queries/activitybump.sql +++ b/coderd/database/queries/activitybump.sql @@ -22,7 +22,7 @@ WITH latest AS ( -- be as if the workspace auto started at the given time and the -- original TTL was applied. -- - -- Sadly we can't define `activity_bump_interval` above since + -- Sadly we can't define 'activity_bump_interval' above since -- it won't be available for this CASE statement, so we have to -- copy the cast twice. WHEN NOW() + (templates.activity_bump / 1000 / 1000 / 1000 || ' seconds')::interval > @next_autostart :: timestamptz @@ -52,7 +52,11 @@ WITH latest AS ( ON workspaces.id = workspace_builds.workspace_id JOIN templates ON templates.id = workspaces.template_id - WHERE workspace_builds.workspace_id = @workspace_id::uuid + WHERE + workspace_builds.workspace_id = @workspace_id::uuid + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- are managed by the reconciliation loop and not subject to activity bumping + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID ORDER BY workspace_builds.build_number DESC LIMIT 1 ) diff --git a/coderd/workspacestats/reporter.go b/coderd/workspacestats/reporter.go index 58d177f1c2071..f6b8a8dd0953b 100644 --- a/coderd/workspacestats/reporter.go +++ b/coderd/workspacestats/reporter.go @@ -149,33 +149,36 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac return nil } - // check next autostart - var nextAutostart time.Time - if workspace.AutostartSchedule.String != "" { - templateSchedule, err := (*(r.opts.TemplateScheduleStore.Load())).Get(ctx, r.opts.Database, workspace.TemplateID) - // If the template schedule fails to load, just default to bumping - // without the next transition and log it. - switch { - case err == nil: - next, allowed := schedule.NextAutostart(now, workspace.AutostartSchedule.String, templateSchedule) - if allowed { - nextAutostart = next + // Prebuilds are not subject to activity-based deadline bumps + if !workspace.IsPrebuild() { + // check next autostart + var nextAutostart time.Time + if workspace.AutostartSchedule.String != "" { + templateSchedule, err := (*(r.opts.TemplateScheduleStore.Load())).Get(ctx, r.opts.Database, workspace.TemplateID) + // If the template schedule fails to load, just default to bumping + // without the next transition and log it. + switch { + case err == nil: + next, allowed := schedule.NextAutostart(now, workspace.AutostartSchedule.String, templateSchedule) + if allowed { + nextAutostart = next + } + case database.IsQueryCanceledError(err): + r.opts.Logger.Debug(ctx, "query canceled while loading template schedule", + slog.F("workspace_id", workspace.ID), + slog.F("template_id", workspace.TemplateID)) + default: + r.opts.Logger.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", + slog.F("workspace_id", workspace.ID), + slog.F("template_id", workspace.TemplateID), + slog.Error(err), + ) } - case database.IsQueryCanceledError(err): - r.opts.Logger.Debug(ctx, "query canceled while loading template schedule", - slog.F("workspace_id", workspace.ID), - slog.F("template_id", workspace.TemplateID)) - default: - r.opts.Logger.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", - slog.F("workspace_id", workspace.ID), - slog.F("template_id", workspace.TemplateID), - slog.Error(err), - ) } - } - // bump workspace activity - ActivityBumpWorkspace(ctx, r.opts.Logger.Named("activity_bump"), r.opts.Database, workspace.ID, nextAutostart) + // bump workspace activity + ActivityBumpWorkspace(ctx, r.opts.Logger.Named("activity_bump"), r.opts.Database, workspace.ID, nextAutostart) + } // bump workspace last_used_at r.opts.UsageTracker.Add(workspace.ID) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 7004653e4ed60..dc44a8794e1c6 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -42,6 +42,7 @@ import ( agplschedule "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" entaudit "github.com/coder/coder/v2/enterprise/audit" "github.com/coder/coder/v2/enterprise/audit/backends" @@ -2767,6 +2768,114 @@ func TestPrebuildUpdateLifecycleParams(t *testing.T) { } } +func TestPrebuildActivityBump(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + clock.Set(dbtime.Now()) + + // Setup + log := testutil.Logger(t) + client, db, owner := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Clock: clock, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspacePrebuilds: 1, + }, + }, + }) + + // Given: a template and a template version with preset and a prebuilt workspace + presetID := uuid.New() + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + // Configure activity bump on the template + activityBump := time.Hour + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + ctr.ActivityBumpMillis = ptr.Ref[int64](activityBump.Milliseconds()) + }) + dbgen.Preset(t, db, database.InsertPresetParams{ + ID: presetID, + TemplateVersionID: version.ID, + DesiredInstances: sql.NullInt32{Int32: 1, Valid: true}, + }) + // Given: a prebuild with an expired Deadline + deadline := clock.Now().Add(-30 * time.Minute) + wb := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: database.PrebuildsSystemUserID, + TemplateID: template.ID, + }).Seed(database.WorkspaceBuild{ + TemplateVersionID: version.ID, + TemplateVersionPresetID: uuid.NullUUID{ + UUID: presetID, + Valid: true, + }, + Deadline: deadline, + }).WithAgent(func(agent []*proto.Agent) []*proto.Agent { + return agent + }).Do() + + // Mark the prebuilt workspace's agent as ready so the prebuild can be claimed + // nolint:gocritic + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) + agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(wb.AgentToken)) + require.NoError(t, err) + err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.WorkspaceAgent.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + }) + require.NoError(t, err) + + // Given: a prebuilt workspace with a Deadline and an empty MaxDeadline + prebuild := coderdtest.MustWorkspace(t, client, wb.Workspace.ID) + require.Equal(t, deadline.UTC(), prebuild.LatestBuild.Deadline.Time.UTC()) + require.Zero(t, prebuild.LatestBuild.MaxDeadline) + + // When: activity bump is applied to an unclaimed prebuild + workspacestats.ActivityBumpWorkspace(ctx, log, db, prebuild.ID, clock.Now().Add(10*time.Hour)) + + // Then: prebuild Deadline/MaxDeadline remain unchanged + prebuild = coderdtest.MustWorkspace(t, client, wb.Workspace.ID) + require.Equal(t, deadline.UTC(), prebuild.LatestBuild.Deadline.Time.UTC()) + require.Zero(t, prebuild.LatestBuild.MaxDeadline) + + // Given: the prebuilt workspace is claimed by a user + user, err := client.User(ctx, "testUser") + require.NoError(t, err) + claimedWorkspace, err := client.CreateUserWorkspace(ctx, user.ID.String(), codersdk.CreateWorkspaceRequest{ + TemplateVersionID: version.ID, + TemplateVersionPresetID: presetID, + Name: coderdtest.RandomUsername(t), + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, claimedWorkspace.LatestBuild.ID) + workspace := coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + require.Equal(t, prebuild.ID, workspace.ID) + // Claimed workspaces have an empty Deadline and MaxDeadline + require.Zero(t, workspace.LatestBuild.Deadline) + require.Zero(t, workspace.LatestBuild.MaxDeadline) + + // Given: the claimed workspace has an expired Deadline + err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ + ID: workspace.LatestBuild.ID, + Deadline: deadline, + UpdatedAt: clock.Now(), + }) + require.NoError(t, err) + workspace = coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + + // When: activity bump is applied to a claimed prebuild + workspacestats.ActivityBumpWorkspace(ctx, log, db, workspace.ID, clock.Now().Add(10*time.Hour)) + + // Then: Deadline is extended by the activity bump, MaxDeadline remains unset + workspace = coderdtest.MustWorkspace(t, client, claimedWorkspace.ID) + require.WithinDuration(t, clock.Now().Add(activityBump).UTC(), workspace.LatestBuild.Deadline.Time.UTC(), testutil.WaitMedium) + require.Zero(t, workspace.LatestBuild.MaxDeadline) +} + // TestWorkspaceTemplateParamsChange tests a workspace with a parameter that // validation changes on apply. The params used in create workspace are invalid // according to the static params on import. From dd867bd7434e57ce7e5cb8a51fd0a60fc66028ec Mon Sep 17 00:00:00 2001 From: Garrett Delfosse Date: Wed, 20 Aug 2025 05:39:08 -0700 Subject: [PATCH 05/11] fix: fix jetbrains toolbox connection tracking (#19348) Fixes https://github.com/coder/coder/issues/18350 I attempted the route of relying on just the session env vars, in hopes that this issue was fixed in Toolbox and the process name matching was no longer need, but it was not a fruitful endeavor and it seems to be using the same connection logic as it did in gateway, just with new binary and flag names. --- agent/agentssh/agentssh.go | 2 ++ agent/agentssh/jetbrainstrack.go | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f53fe207c72cf..f9c28a3e6ee25 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -46,6 +46,8 @@ const ( // MagicProcessCmdlineJetBrains is a string in a process's command line that // uniquely identifies it as JetBrains software. MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains" + MagicProcessCmdlineToolbox = "com.jetbrains.toolbox" + MagicProcessCmdlineGateway = "remote-dev-server" // BlockedFileTransferErrorCode indicates that SSH server restricted the raw command from performing // the file transfer. diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 9b2fdf83b21d0..874f4c278ce79 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -53,7 +53,7 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConne // If this is not JetBrains, then we do not need to do anything special. We // attempt to match on something that appears unique to JetBrains software. - if !strings.Contains(strings.ToLower(cmdline), strings.ToLower(MagicProcessCmdlineJetBrains)) { + if !isJetbrainsProcess(cmdline) { return newChannel } @@ -104,3 +104,18 @@ func (c *ChannelOnClose) Close() error { c.once.Do(c.done) return c.Channel.Close() } + +func isJetbrainsProcess(cmdline string) bool { + opts := []string{ + MagicProcessCmdlineJetBrains, + MagicProcessCmdlineToolbox, + MagicProcessCmdlineGateway, + } + + for _, opt := range opts { + if strings.Contains(strings.ToLower(cmdline), strings.ToLower(opt)) { + return true + } + } + return false +} From 6eb02d1c2a22a99b7b57f9338f82d578c7bccc2e Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 20 Aug 2025 23:38:09 +1000 Subject: [PATCH 06/11] chore: wire up usage tracking for managed agents (#19096) Wires up the usage collector and publisher to coderd. Relates to coder/internal#814 --- cli/delete_test.go | 1 - cli/provisioners_test.go | 1 - coderd/agentapi/subagent_test.go | 26 +- coderd/coderd.go | 14 + coderd/database/dbauthz/dbauthz.go | 28 +- coderd/database/dbauthz/dbauthz_test.go | 15 +- coderd/externalauth/externalauth_test.go | 1 - coderd/files/cache_test.go | 6 - coderd/idpsync/group_test.go | 2 - coderd/idpsync/role_test.go | 1 - coderd/insights_test.go | 2 - coderd/notifications/manager_test.go | 4 - coderd/notifications/metrics_test.go | 4 - coderd/notifications/notifications_test.go | 22 - .../reports/generator_internal_test.go | 1 - .../insights/metricscollector_test.go | 1 - .../provisionerdserver/provisionerdserver.go | 21 + .../provisionerdserver_test.go | 496 ++++++++++++------ coderd/rbac/authz.go | 2 +- coderd/usage/inserter.go | 2 + coderd/userauth_test.go | 2 - coderd/users_test.go | 5 - coderd/workspaceagents_test.go | 6 +- coderd/workspacebuilds_test.go | 2 - coderd/workspaces_test.go | 14 +- enterprise/cli/prebuilds_test.go | 1 - enterprise/cli/server.go | 45 +- enterprise/coderd/coderd.go | 10 + enterprise/coderd/coderd_test.go | 4 - .../coderd/enidpsync/organizations_test.go | 1 - enterprise/coderd/idpsync_test.go | 1 - .../coderd/prebuilds/metricscollector_test.go | 2 - enterprise/coderd/provisionerdaemons.go | 1 + enterprise/coderd/provisionerdaemons_test.go | 3 - enterprise/coderd/schedule/template_test.go | 1 - enterprise/coderd/usage/inserter.go | 23 +- enterprise/coderd/usage/inserter_test.go | 8 +- enterprise/coderd/usage/publisher.go | 51 +- enterprise/coderd/usage/publisher_test.go | 30 +- enterprise/coderd/userauth_test.go | 1 - enterprise/coderd/workspacequota_test.go | 8 - enterprise/coderd/workspaces_test.go | 11 - scripts/rules.go | 4 +- 43 files changed, 539 insertions(+), 345 deletions(-) diff --git a/cli/delete_test.go b/cli/delete_test.go index c01893419f80f..2e550d74849ab 100644 --- a/cli/delete_test.go +++ b/cli/delete_test.go @@ -111,7 +111,6 @@ func TestDelete(t *testing.T) { // The API checks if the user has any workspaces, so we cannot delete a user // this way. ctx := testutil.Context(t, testutil.WaitShort) - // nolint:gocritic // Unit test err := api.Database.UpdateUserDeletedByID(dbauthz.AsSystemRestricted(ctx), deleteMeUser.ID) require.NoError(t, err) diff --git a/cli/provisioners_test.go b/cli/provisioners_test.go index 30a89714ff57f..0c3fe5ae2f6d1 100644 --- a/cli/provisioners_test.go +++ b/cli/provisioners_test.go @@ -31,7 +31,6 @@ func TestProvisioners_Golden(t *testing.T) { // Replace UUIDs with predictable values for golden files. replace := make(map[string]string) updateReplaceUUIDs := func(coderdAPI *coderd.API) { - //nolint:gocritic // This is a test. systemCtx := dbauthz.AsSystemRestricted(context.Background()) provisioners, err := coderdAPI.Database.GetProvisionerDaemons(systemCtx) require.NoError(t, err) diff --git a/coderd/agentapi/subagent_test.go b/coderd/agentapi/subagent_test.go index 0a95a70e5216d..1b6eef936f827 100644 --- a/coderd/agentapi/subagent_test.go +++ b/coderd/agentapi/subagent_test.go @@ -163,7 +163,7 @@ func TestSubAgentAPI(t *testing.T) { agentID, err := uuid.FromBytes(createResp.Agent.Id) require.NoError(t, err) - agent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + agent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) assert.Equal(t, tt.agentName, agent.Name) @@ -621,7 +621,7 @@ func TestSubAgentAPI(t *testing.T) { agentID, err := uuid.FromBytes(createResp.Agent.Id) require.NoError(t, err) - apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) // Sort the apps for determinism @@ -751,7 +751,7 @@ func TestSubAgentAPI(t *testing.T) { agentID, err := uuid.FromBytes(createResp.Agent.Id) require.NoError(t, err) - apps, err := db.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + apps, err := db.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) require.Len(t, apps, 1) require.Equal(t, "k5jd7a99-duplicate-slug", apps[0].Slug) @@ -789,7 +789,7 @@ func TestSubAgentAPI(t *testing.T) { require.NoError(t, err) // Then: It is deleted. - _, err = db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgent.ID) //nolint:gocritic // this is a test. + _, err = db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgent.ID) require.ErrorIs(t, err, sql.ErrNoRows) }) @@ -830,10 +830,10 @@ func TestSubAgentAPI(t *testing.T) { require.NoError(t, err) // Then: The correct one is deleted. - _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentOne.ID) //nolint:gocritic // this is a test. + _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentOne.ID) require.ErrorIs(t, err, sql.ErrNoRows) - _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentTwo.ID) //nolint:gocritic // this is a test. + _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentTwo.ID) require.NoError(t, err) }) @@ -871,7 +871,7 @@ func TestSubAgentAPI(t *testing.T) { var notAuthorizedError dbauthz.NotAuthorizedError require.ErrorAs(t, err, ¬AuthorizedError) - _, err = db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentOne.ID) //nolint:gocritic // this is a test. + _, err = db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), childAgentOne.ID) require.NoError(t, err) }) @@ -912,7 +912,7 @@ func TestSubAgentAPI(t *testing.T) { require.NoError(t, err) // Verify that the apps were created - apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), subAgentID) //nolint:gocritic // this is a test. + apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), subAgentID) require.NoError(t, err) require.Len(t, apps, 2) @@ -923,7 +923,7 @@ func TestSubAgentAPI(t *testing.T) { require.NoError(t, err) // Then: The agent is deleted - _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), subAgentID) //nolint:gocritic // this is a test. + _, err = api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), subAgentID) require.ErrorIs(t, err, sql.ErrNoRows) // And: The apps are *retained* to avoid causing issues @@ -1068,7 +1068,7 @@ func TestSubAgentAPI(t *testing.T) { agentID, err := uuid.FromBytes(createResp.Agent.Id) require.NoError(t, err) - subAgent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + subAgent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) require.Equal(t, len(tt.expectedApps), len(subAgent.DisplayApps), "display apps count mismatch") @@ -1118,14 +1118,14 @@ func TestSubAgentAPI(t *testing.T) { require.NoError(t, err) // Verify display apps - subAgent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + subAgent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) require.Len(t, subAgent.DisplayApps, 2) require.Equal(t, database.DisplayAppVscode, subAgent.DisplayApps[0]) require.Equal(t, database.DisplayAppWebTerminal, subAgent.DisplayApps[1]) // Verify regular apps - apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) //nolint:gocritic // this is a test. + apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) require.NoError(t, err) require.Len(t, apps, 1) require.Equal(t, "v4qhkq17-custom-app", apps[0].Slug) @@ -1190,7 +1190,7 @@ func TestSubAgentAPI(t *testing.T) { }) // When: We list the sub agents. - listResp, err := api.ListSubAgents(ctx, &proto.ListSubAgentsRequest{}) //nolint:gocritic // this is a test. + listResp, err := api.ListSubAgents(ctx, &proto.ListSubAgentsRequest{}) require.NoError(t, err) listedChildAgents := listResp.Agents diff --git a/coderd/coderd.go b/coderd/coderd.go index 8ab204f8a31ef..5debc13d21431 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/oauth2provider" "github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/andybalholm/brotli" @@ -200,6 +201,7 @@ type Options struct { TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore] + UsageInserter *atomic.Pointer[usage.Inserter] // CoordinatorResumeTokenProvider is used to provide and validate resume // tokens issued by and passed to the coordinator DRPC API. CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider @@ -428,6 +430,13 @@ func New(options *Options) *API { v := schedule.NewAGPLUserQuietHoursScheduleStore() options.UserQuietHoursScheduleStore.Store(&v) } + if options.UsageInserter == nil { + options.UsageInserter = &atomic.Pointer[usage.Inserter]{} + } + if options.UsageInserter.Load() == nil { + inserter := usage.NewAGPLInserter() + options.UsageInserter.Store(&inserter) + } if options.OneTimePasscodeValidityPeriod == 0 { options.OneTimePasscodeValidityPeriod = 20 * time.Minute } @@ -590,6 +599,7 @@ func New(options *Options) *API { UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore, AccessControlStore: options.AccessControlStore, BuildUsageChecker: &buildUsageChecker, + UsageInserter: options.UsageInserter, FileCache: files.New(options.PrometheusRegistry, options.Authorizer), Experiments: experiments, WebpushDispatcher: options.WebPushDispatcher, @@ -1690,6 +1700,9 @@ type API struct { // BuildUsageChecker is a pointer as it's passed around to multiple // components. BuildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] + // UsageInserter is a pointer to an atomic pointer because it is passed to + // multiple components. + UsageInserter *atomic.Pointer[usage.Inserter] UpdatesProvider tailnet.WorkspaceUpdatesProvider @@ -1905,6 +1918,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n &api.Auditor, api.TemplateScheduleStore, api.UserQuietHoursScheduleStore, + api.UsageInserter, api.DeploymentValues, provisionerdserver.Options{ OIDCConfig: api.OIDCConfig, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a716c04adc030..94e60db47cb30 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -213,6 +213,8 @@ var ( // Provisionerd creates workspaces resources monitor rbac.ResourceWorkspaceAgentResourceMonitor.Type: {policy.ActionCreate}, rbac.ResourceWorkspaceAgentDevcontainers.Type: {policy.ActionCreate}, + // Provisionerd creates usage events + rbac.ResourceUsageEvent.Type: {policy.ActionCreate}, }), Org: map[string][]rbac.Permission{}, User: []rbac.Permission{}, @@ -510,17 +512,19 @@ var ( Scope: rbac.ScopeAll, }.WithCachedASTValue() - subjectUsageTracker = rbac.Subject{ - Type: rbac.SubjectTypeUsageTracker, - FriendlyName: "Usage Tracker", + subjectUsagePublisher = rbac.Subject{ + Type: rbac.SubjectTypeUsagePublisher, + FriendlyName: "Usage Publisher", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ { - Identifier: rbac.RoleIdentifier{Name: "usage-tracker"}, - DisplayName: "Usage Tracker", + Identifier: rbac.RoleIdentifier{Name: "usage-publisher"}, + DisplayName: "Usage Publisher", Site: rbac.Permissions(map[string][]policy.Action{ - rbac.ResourceLicense.Type: {policy.ActionRead}, - rbac.ResourceUsageEvent.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate}, + rbac.ResourceLicense.Type: {policy.ActionRead}, + // The usage publisher doesn't create events, just + // reads/processes them. + rbac.ResourceUsageEvent.Type: {policy.ActionRead, policy.ActionUpdate}, }), Org: map[string][]rbac.Permission{}, User: []rbac.Permission{}, @@ -604,10 +608,10 @@ func AsFileReader(ctx context.Context) context.Context { return As(ctx, subjectFileReader) } -// AsUsageTracker returns a context with an actor that has permissions required -// for creating, reading, and updating usage events. -func AsUsageTracker(ctx context.Context) context.Context { - return As(ctx, subjectUsageTracker) +// AsUsagePublisher returns a context with an actor that has permissions +// required for creating, reading, and updating usage events. +func AsUsagePublisher(ctx context.Context) context.Context { + return As(ctx, subjectUsagePublisher) } var AsRemoveActor = rbac.Subject{ @@ -3038,7 +3042,7 @@ func (q *querier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTe } func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil { return nil, err } return q.db.GetUnexpiredLicenses(ctx) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index ce70a9b1f112a..ad444d1025514 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -758,6 +758,18 @@ func (s *MethodTestSuite) TestLicense() { check.Args().Asserts(l, policy.ActionRead). Returns([]database.License{l}) })) + s.Run("GetUnexpiredLicenses", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + l := database.License{ + ID: 1, + Exp: time.Now().Add(time.Hour * 24 * 30), + UUID: uuid.New(), + } + db.EXPECT().GetUnexpiredLicenses(gomock.Any()). + Return([]database.License{l}, nil). + AnyTimes() + check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead). + Returns([]database.License{l}) + })) s.Run("InsertLicense", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertLicenseParams{}). Asserts(rbac.ResourceLicense, policy.ActionCreate) @@ -3770,9 +3782,6 @@ func (s *MethodTestSuite) TestSystemFunctions() { s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { check.Args(false).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns(int64(0)) })) - s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead) - })) s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) { u := dbgen.User(s.T(), db, database.User{}) check.Args(u.ID).Asserts(rbac.ResourceSystem, policy.ActionRead) diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 484d59beabb9b..8e46566ed2738 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -337,7 +337,6 @@ func TestRefreshToken(t *testing.T) { require.Equal(t, 1, validateCalls, "token is validated") require.Equal(t, 1, refreshCalls, "token is refreshed") require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated") - //nolint:gocritic // testing dbLink, err := db.GetExternalAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetExternalAuthLinkParams{ ProviderID: link.ProviderID, UserID: link.UserID, diff --git a/coderd/files/cache_test.go b/coderd/files/cache_test.go index 6f8f74e74fe8e..b81deae5d9714 100644 --- a/coderd/files/cache_test.go +++ b/coderd/files/cache_test.go @@ -45,7 +45,6 @@ func TestCancelledFetch(t *testing.T) { cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // Cancel the context for the first call; should fail. - //nolint:gocritic // Unit testing ctx, cancel := context.WithCancel(dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort))) cancel() _, err := cache.Acquire(ctx, dbM, fileID) @@ -71,7 +70,6 @@ func TestCancelledConcurrentFetch(t *testing.T) { cache := files.LeakCache{Cache: files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})} - //nolint:gocritic // Unit testing ctx := dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort)) // Cancel the context for the first call; should fail. @@ -99,7 +97,6 @@ func TestConcurrentFetch(t *testing.T) { }) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - //nolint:gocritic // Unit testing ctx := dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort)) // Expect 2 calls to Acquire before we continue the test @@ -151,7 +148,6 @@ func TestCacheRBAC(t *testing.T) { Scope: rbac.ScopeAll, }) - //nolint:gocritic // Unit testing cacheReader := dbauthz.AsFileReader(ctx) t.Run("NoRolesOpen", func(t *testing.T) { @@ -207,7 +203,6 @@ func cachePromMetricName(metric string) string { func TestConcurrency(t *testing.T) { t.Parallel() - //nolint:gocritic // Unit testing ctx := dbauthz.AsFileReader(t.Context()) const fileSize = 10 @@ -268,7 +263,6 @@ func TestConcurrency(t *testing.T) { func TestRelease(t *testing.T) { t.Parallel() - //nolint:gocritic // Unit testing ctx := dbauthz.AsFileReader(t.Context()) const fileSize = 10 diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 478d6557de551..7f4ee9f435813 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -328,7 +328,6 @@ func TestGroupSyncTable(t *testing.T) { }, } - //nolint:gocritic // testing defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) require.NoError(t, err) SetupOrganization(t, s, db, user, defOrg.ID, def) @@ -527,7 +526,6 @@ func TestApplyGroupDifference(t *testing.T) { db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitMedium) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) org := dbgen.Organization(t, db, database.Organization{}) diff --git a/coderd/idpsync/role_test.go b/coderd/idpsync/role_test.go index 6df091097b966..db172e0ee4237 100644 --- a/coderd/idpsync/role_test.go +++ b/coderd/idpsync/role_test.go @@ -273,7 +273,6 @@ func TestRoleSyncTable(t *testing.T) { } // Also assert site wide roles - //nolint:gocritic // unit testing assertions allRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user.ID) require.NoError(t, err) diff --git a/coderd/insights_test.go b/coderd/insights_test.go index d916b20fea26e..cf5f63065df99 100644 --- a/coderd/insights_test.go +++ b/coderd/insights_test.go @@ -754,7 +754,6 @@ func TestTemplateInsights_Golden(t *testing.T) { Database: db, AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize, }) - //nolint:gocritic // This is a test. err = reporter.ReportAppStats(dbauthz.AsSystemRestricted(ctx), stats) require.NoError(t, err, "want no error inserting app stats") @@ -1646,7 +1645,6 @@ func TestUserActivityInsights_Golden(t *testing.T) { Database: db, AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize, }) - //nolint:gocritic // This is a test. err = reporter.ReportAppStats(dbauthz.AsSystemRestricted(ctx), stats) require.NoError(t, err, "want no error inserting app stats") diff --git a/coderd/notifications/manager_test.go b/coderd/notifications/manager_test.go index e9c309f0a09d3..30af0c88b852c 100644 --- a/coderd/notifications/manager_test.go +++ b/coderd/notifications/manager_test.go @@ -31,7 +31,6 @@ func TestBufferedUpdates(t *testing.T) { // setup - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, ps := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -108,7 +107,6 @@ func TestBuildPayload(t *testing.T) { // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -166,7 +164,6 @@ func TestStopBeforeRun(t *testing.T) { // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, ps := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -187,7 +184,6 @@ func TestRunStopRace(t *testing.T) { // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) store, ps := dbtestutil.NewDB(t) logger := testutil.Logger(t) diff --git a/coderd/notifications/metrics_test.go b/coderd/notifications/metrics_test.go index 5517f86061cc0..6ba6635a50c4c 100644 --- a/coderd/notifications/metrics_test.go +++ b/coderd/notifications/metrics_test.go @@ -37,7 +37,6 @@ func TestMetrics(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -226,7 +225,6 @@ func TestPendingUpdatesMetric(t *testing.T) { t.Parallel() // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -320,7 +318,6 @@ func TestInflightDispatchesMetric(t *testing.T) { t.Parallel() // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -400,7 +397,6 @@ func TestCustomMethodMetricCollection(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) diff --git a/coderd/notifications/notifications_test.go b/coderd/notifications/notifications_test.go index e213a62df9996..f5e72a8327d7e 100644 --- a/coderd/notifications/notifications_test.go +++ b/coderd/notifications/notifications_test.go @@ -70,7 +70,6 @@ func TestBasicNotificationRoundtrip(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -137,7 +136,6 @@ func TestSMTPDispatch(t *testing.T) { // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -203,7 +201,6 @@ func TestWebhookDispatch(t *testing.T) { // SETUP - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -287,7 +284,6 @@ func TestBackpressure(t *testing.T) { store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitShort)) const method = database.NotificationMethodWebhook @@ -416,7 +412,6 @@ func TestRetries(t *testing.T) { } const maxAttempts = 3 - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -516,7 +511,6 @@ func TestExpiredLeaseIsRequeued(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -536,7 +530,6 @@ func TestExpiredLeaseIsRequeued(t *testing.T) { noopInterceptor := newNoopStoreSyncer(store) - // nolint:gocritic // Unit test. mgrCtx, cancelManagerCtx := context.WithCancel(dbauthz.AsNotifier(context.Background())) t.Cleanup(cancelManagerCtx) @@ -645,7 +638,6 @@ func TestNotifierPaused(t *testing.T) { // Setup. - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -1323,7 +1315,6 @@ func TestNotificationTemplates_Golden(t *testing.T) { return &db, &api.Logger, &user }() - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) _, pubsub := dbtestutil.NewDB(t) @@ -1406,13 +1397,11 @@ func TestNotificationTemplates_Golden(t *testing.T) { // as appearance changes are enterprise features and we do not want to mix those // can't use the api if tc.appName != "" { - // nolint:gocritic // Unit test. err = (*db).UpsertApplicationName(dbauthz.AsSystemRestricted(ctx), "Custom Application") require.NoError(t, err) } if tc.logoURL != "" { - // nolint:gocritic // Unit test. err = (*db).UpsertLogoURL(dbauthz.AsSystemRestricted(ctx), "https://custom.application/logo.png") require.NoError(t, err) } @@ -1510,7 +1499,6 @@ func TestNotificationTemplates_Golden(t *testing.T) { }() _, pubsub := dbtestutil.NewDB(t) - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) // Spin up the mock webhook server @@ -1650,7 +1638,6 @@ func TestDisabledByDefaultBeforeEnqueue(t *testing.T) { t.Skip("This test requires postgres; it is testing business-logic implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -1676,7 +1663,6 @@ func TestDisabledBeforeEnqueue(t *testing.T) { t.Skip("This test requires postgres; it is testing business-logic implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -1712,7 +1698,6 @@ func TestDisabledAfterEnqueue(t *testing.T) { t.Skip("This test requires postgres; it is testing business-logic implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -1769,7 +1754,6 @@ func TestCustomNotificationMethod(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -1873,7 +1857,6 @@ func TestNotificationsTemplates(t *testing.T) { t.Skip("This test requires postgres; it relies on business-logic only implemented in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) api := coderdtest.New(t, createOpts(t)) @@ -1910,7 +1893,6 @@ func TestNotificationDuplicates(t *testing.T) { t.Skip("This test requires postgres; it is testing the dedupe hash trigger in the database") } - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -2007,7 +1989,6 @@ func TestNotificationTargetMatrix(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -2051,7 +2032,6 @@ func TestNotificationOneTimePasswordDeliveryTargets(t *testing.T) { t.Run("Inbox", func(t *testing.T) { t.Parallel() - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -2076,7 +2056,6 @@ func TestNotificationOneTimePasswordDeliveryTargets(t *testing.T) { t.Run("SMTP", func(t *testing.T) { t.Parallel() - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) @@ -2100,7 +2079,6 @@ func TestNotificationOneTimePasswordDeliveryTargets(t *testing.T) { t.Run("Webhook", func(t *testing.T) { t.Parallel() - // nolint:gocritic // Unit test. ctx := dbauthz.AsNotifier(testutil.Context(t, testutil.WaitSuperLong)) store, _ := dbtestutil.NewDB(t) logger := testutil.Logger(t) diff --git a/coderd/notifications/reports/generator_internal_test.go b/coderd/notifications/reports/generator_internal_test.go index f61064c4e0b23..6dcff173118cb 100644 --- a/coderd/notifications/reports/generator_internal_test.go +++ b/coderd/notifications/reports/generator_internal_test.go @@ -505,7 +505,6 @@ func TestReportFailedWorkspaceBuilds(t *testing.T) { func setup(t *testing.T) (context.Context, slog.Logger, database.Store, pubsub.Pubsub, *notificationstest.FakeEnqueuer, *quartz.Mock) { t.Helper() - // nolint:gocritic // reportFailedWorkspaceBuilds is called by system. ctx := dbauthz.AsSystemRestricted(context.Background()) logger := slogtest.Make(t, &slogtest.Options{}) db, ps := dbtestutil.NewDB(t) diff --git a/coderd/prometheusmetrics/insights/metricscollector_test.go b/coderd/prometheusmetrics/insights/metricscollector_test.go index 9382fa5013525..5c18ec6d1a60f 100644 --- a/coderd/prometheusmetrics/insights/metricscollector_test.go +++ b/coderd/prometheusmetrics/insights/metricscollector_test.go @@ -128,7 +128,6 @@ func TestCollectInsights(t *testing.T) { AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize, }) refTime := time.Now().Add(-3 * time.Minute).Truncate(time.Minute) - //nolint:gocritic // This is a test. err = reporter.ReportAppStats(dbauthz.AsSystemRestricted(context.Background()), []workspaceapps.StatsReport{ { UserID: user.ID, diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index f0091ca63ed5a..93573131a04f8 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -29,6 +29,7 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk/drpcsdk" @@ -121,6 +122,7 @@ type server struct { DeploymentValues *codersdk.DeploymentValues NotificationsEnqueuer notifications.Enqueuer PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator] + UsageInserter *atomic.Pointer[usage.Inserter] OIDCConfig promoauth.OAuth2Config @@ -174,6 +176,7 @@ func NewServer( auditor *atomic.Pointer[audit.Auditor], templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore], userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore], + usageInserter *atomic.Pointer[usage.Inserter], deploymentValues *codersdk.DeploymentValues, options Options, enqueuer notifications.Enqueuer, @@ -195,6 +198,9 @@ func NewServer( if userQuietHoursScheduleStore == nil { return nil, xerrors.New("userQuietHoursScheduleStore is nil") } + if usageInserter == nil { + return nil, xerrors.New("usageCollector is nil") + } if deploymentValues == nil { return nil, xerrors.New("deploymentValues is nil") } @@ -244,6 +250,7 @@ func NewServer( heartbeatInterval: options.HeartbeatInterval, heartbeatFn: options.HeartbeatFn, PrebuildsOrchestrator: prebuildsOrchestrator, + UsageInserter: usageInserter, } if s.heartbeatFn == nil { @@ -2030,6 +2037,20 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro sidebarAppID = uuid.NullUUID{} } + if hasAITask && workspaceBuild.Transition == database.WorkspaceTransitionStart { + // Insert usage event for managed agents. + usageInserter := s.UsageInserter.Load() + if usageInserter != nil { + event := usage.DCManagedAgentsV1{ + Count: 1, + } + err = (*usageInserter).InsertDiscreteUsageEvent(ctx, db, event) + if err != nil { + return xerrors.Errorf("insert %q event: %w", event.EventType(), err) + } + } + } + hasExternalAgent := false for _, resource := range jobType.WorkspaceBuild.Resources { if resource.Type == "coder_external_agent" { diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 7fb351bf0c0da..8bb06eb52cd70 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" @@ -30,7 +31,9 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" @@ -44,6 +47,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/telemetry" + "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -67,6 +71,13 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc return ptr } +func testUsageInserter() *atomic.Pointer[usage.Inserter] { + ptr := &atomic.Pointer[usage.Inserter]{} + inserter := usage.NewAGPLInserter() + ptr.Store(&inserter) + return ptr +} + func TestAcquireJob_LongPoll(t *testing.T) { t.Parallel() //nolint:dogsled @@ -681,12 +692,20 @@ func TestUpdateJob(t *testing.T) { t.Run("NotRunning", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, nil) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: json.RawMessage("{}"), + ID: version.JobID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -700,12 +719,20 @@ func TestUpdateJob(t *testing.T) { t.Run("NotOwner", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, nil) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: json.RawMessage("{}"), + ID: version.JobID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -730,38 +757,57 @@ func TestUpdateJob(t *testing.T) { require.ErrorContains(t, err, "you don't own this job") }) - setupJob := func(t *testing.T, db database.Store, srvID, orgID uuid.UUID, tags database.StringMap) uuid.UUID { - job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - OrganizationID: orgID, - Provisioner: database.ProvisionerTypeEcho, - Type: database.ProvisionerJobTypeTemplateVersionImport, - StorageMethod: database.ProvisionerStorageMethodFile, - Input: json.RawMessage("{}"), - Tags: tags, - }) - require.NoError(t, err) - _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ - WorkerID: uuid.NullUUID{ - UUID: srvID, - Valid: true, - }, - Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, - StartedAt: sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - }, - OrganizationID: orgID, - ProvisionerTags: must(json.Marshal(job.Tags)), - }) + setupJob := func(t *testing.T, db database.Store, srvID, orgID uuid.UUID, tags database.StringMap) (templateVersionID, jobID uuid.UUID) { + templateVersionID = uuid.New() + jobID = uuid.New() + err := db.InTx(func(db database.Store) error { + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + ID: templateVersionID, + CreatedBy: user.ID, + OrganizationID: orgID, + JobID: jobID, + }) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: version.JobID, + OrganizationID: orgID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeTemplateVersionImport, + StorageMethod: database.ProvisionerStorageMethodFile, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), + Tags: tags, + }) + if err != nil { + return xerrors.Errorf("insert provisioner job: %w", err) + } + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srvID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + StartedAt: sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + }, + OrganizationID: orgID, + ProvisionerTags: must(json.Marshal(job.Tags)), + }) + if err != nil { + return xerrors.Errorf("acquire provisioner job: %w", err) + } + return nil + }, nil) require.NoError(t, err) - return job.ID + return templateVersionID, jobID } t.Run("Success", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.String(), }) @@ -771,7 +817,7 @@ func TestUpdateJob(t *testing.T) { t.Run("Logs", func(t *testing.T) { t.Parallel() srv, db, ps, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) published := make(chan struct{}) @@ -796,23 +842,14 @@ func TestUpdateJob(t *testing.T) { t.Run("Readme", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) - versionID := uuid.New() - user := dbgen.User(t, db, database.User{}) - err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: versionID, - CreatedBy: user.ID, - OrganizationID: pd.OrganizationID, - JobID: job, - }) - require.NoError(t, err) - _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + templateVersionID, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.String(), Readme: []byte("# hello world"), }) require.NoError(t, err) - version, err := db.GetTemplateVersionByID(ctx, versionID) + version, err := db.GetTemplateVersionByID(ctx, templateVersionID) require.NoError(t, err) require.Equal(t, "# hello world", version.Readme) }) @@ -825,16 +862,7 @@ func TestUpdateJob(t *testing.T) { defer cancel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) - versionID := uuid.New() - user := dbgen.User(t, db, database.User{}) - err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: versionID, - CreatedBy: user.ID, - JobID: job, - OrganizationID: pd.OrganizationID, - }) - require.NoError(t, err) + templateVersionID, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) firstTemplateVariable := &sdkproto.TemplateVariable{ Name: "first", Type: "string", @@ -863,7 +891,7 @@ func TestUpdateJob(t *testing.T) { require.NoError(t, err) require.Len(t, response.VariableValues, 2) - templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID) + templateVariables, err := db.GetTemplateVersionVariables(ctx, templateVersionID) require.NoError(t, err) require.Len(t, templateVariables, 2) require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue) @@ -875,16 +903,7 @@ func TestUpdateJob(t *testing.T) { defer cancel() srv, db, _, pd := setup(t, false, &overrides{}) - user := dbgen.User(t, db, database.User{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) - versionID := uuid.New() - err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - CreatedBy: user.ID, - ID: versionID, - JobID: job, - OrganizationID: pd.OrganizationID, - }) - require.NoError(t, err) + templateVersionID, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) firstTemplateVariable := &sdkproto.TemplateVariable{ Name: "first", Type: "string", @@ -909,7 +928,7 @@ func TestUpdateJob(t *testing.T) { // Even though there is an error returned, variables are stored in the database // to show the schema in the site UI. - templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID) + templateVariables, err := db.GetTemplateVersionVariables(ctx, templateVersionID) require.NoError(t, err) require.Len(t, templateVariables, 2) require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue) @@ -923,18 +942,9 @@ func TestUpdateJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) - versionID := uuid.New() - user := dbgen.User(t, db, database.User{}) - err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: versionID, - CreatedBy: user.ID, - JobID: job, - OrganizationID: pd.OrganizationID, - }) - require.NoError(t, err) - _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + srv, db, _, pd := setup(t, false, nil) + templateVersionID, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.String(), WorkspaceTags: map[string]string{ "bird": "tweety", @@ -943,7 +953,7 @@ func TestUpdateJob(t *testing.T) { }) require.NoError(t, err) - workspaceTags, err := db.GetTemplateVersionWorkspaceTags(ctx, versionID) + workspaceTags, err := db.GetTemplateVersionWorkspaceTags(ctx, templateVersionID) require.NoError(t, err) require.Len(t, workspaceTags, 2) require.Equal(t, workspaceTags[0].Key, "bird") @@ -955,7 +965,7 @@ func TestUpdateJob(t *testing.T) { t.Run("LogSizeLimit", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) // Create a log message that exceeds the 1MB limit largeOutput := strings.Repeat("a", 1048577) // 1MB + 1 byte @@ -979,7 +989,7 @@ func TestUpdateJob(t *testing.T) { t.Run("IncrementalLogSizeOverflow", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) // Send logs that together exceed the limit mediumOutput := strings.Repeat("b", 524289) // Half a MB + 1 byte @@ -1020,7 +1030,7 @@ func TestUpdateJob(t *testing.T) { t.Run("LogSizeTracking", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) logOutput := "test log message" expectedSize := int32(len(logOutput)) // #nosec G115 - Log length is 16. @@ -1045,7 +1055,7 @@ func TestUpdateJob(t *testing.T) { t.Run("LogOverflowStopsProcessing", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) - job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) + _, job := setupJob(t, db, pd.ID, pd.OrganizationID, pd.Tags) // First: trigger overflow largeOutput := strings.Repeat("a", 1048577) // 1MB + 1 byte @@ -1108,12 +1118,20 @@ func TestFailJob(t *testing.T) { t.Run("NotOwner", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, nil) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeTemplateVersionImport, - Input: json.RawMessage("{}"), + ID: version.JobID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: version.ID, + })), OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -1139,13 +1157,21 @@ func TestFailJob(t *testing.T) { }) t.Run("AlreadyCompleted", func(t *testing.T) { t.Parallel() - srv, db, _, pd := setup(t, false, &overrides{}) + srv, db, _, pd := setup(t, false, nil) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - Type: database.ProvisionerJobTypeTemplateVersionImport, - StorageMethod: database.ProvisionerStorageMethodFile, - Input: json.RawMessage("{}"), + ID: version.JobID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeTemplateVersionImport, + StorageMethod: database.ProvisionerStorageMethodFile, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: version.ID, + })), OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -1310,14 +1336,22 @@ func TestCompleteJob(t *testing.T) { t.Run("NotOwner", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, nil) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), + ID: version.JobID, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeWorkspaceBuild, + Type: database.ProvisionerJobTypeTemplateVersionImport, OrganizationID: pd.OrganizationID, - Input: json.RawMessage("{}"), - Tags: pd.Tags, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), + Tags: pd.Tags, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -1361,10 +1395,12 @@ func TestCompleteJob(t *testing.T) { OrganizationID: pd.OrganizationID, ID: jobID, Provisioner: database.ProvisionerTypeEcho, - Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeTemplateVersionImport, - Tags: pd.Tags, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: versionID, + })), + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Tags: pd.Tags, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -1410,14 +1446,22 @@ func TestCompleteJob(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: org.ID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), OrganizationID: org.ID, Provisioner: database.ProvisionerTypeEcho, Type: database.ProvisionerJobTypeTemplateVersionDryRun, StorageMethod: database.ProvisionerStorageMethodFile, - Input: json.RawMessage("{}"), - Tags: pd.Tags, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), + Tags: pd.Tags, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -1628,25 +1672,49 @@ func TestCompleteJob(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) jobID := uuid.New() - versionID := uuid.New() user := dbgen.User(t, db, database.User{}) - err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ CreatedBy: user.ID, - ID: versionID, - JobID: jobID, OrganizationID: pd.OrganizationID, + JobID: jobID, + }) + template := dbgen.Template(t, db, database.Template{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + ActiveVersionID: tv.ID, + }) + err := db.UpdateTemplateVersionByID(ctx, database.UpdateTemplateVersionByIDParams{ + ID: tv.ID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + UpdatedAt: dbtime.Now(), + Name: tv.Name, + Message: tv.Message, }) require.NoError(t, err) + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: pd.OrganizationID, + TemplateID: template.ID, + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: jobID, Provisioner: database.ProvisionerTypeEcho, - Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), + Input: json.RawMessage("{}"), StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) require.NoError(t, err) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: tv.ID, + InitiatorID: user.ID, + JobID: jobID, + }) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ OrganizationID: pd.OrganizationID, WorkerID: uuid.NullUUID{ @@ -1697,11 +1765,13 @@ func TestCompleteJob(t *testing.T) { }) require.NoError(t, err) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: jobID, - Provisioner: database.ProvisionerTypeEcho, - Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), + ID: jobID, + Provisioner: database.ProvisionerTypeEcho, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: versionID, + })), StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeWorkspaceBuild, + Type: database.ProvisionerJobTypeTemplateVersionImport, OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -1766,10 +1836,12 @@ func TestCompleteJob(t *testing.T) { OrganizationID: pd.OrganizationID, ID: jobID, Provisioner: database.ProvisionerTypeEcho, - Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), - StorageMethod: database.ProvisionerStorageMethodFile, - Type: database.ProvisionerJobTypeWorkspaceBuild, - Tags: pd.Tags, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: versionID, + })), + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Tags: pd.Tags, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -2091,12 +2163,20 @@ func TestCompleteJob(t *testing.T) { t.Run("TemplateDryRun", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{}) + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + CreatedBy: user.ID, + OrganizationID: pd.OrganizationID, + JobID: uuid.New(), + }) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: uuid.New(), - Provisioner: database.ProvisionerTypeEcho, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - StorageMethod: database.ProvisionerStorageMethodFile, - Input: json.RawMessage("{}"), + ID: version.JobID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + StorageMethod: database.ProvisionerStorageMethodFile, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + })), OrganizationID: pd.OrganizationID, Tags: pd.Tags, }) @@ -2191,8 +2271,10 @@ func TestCompleteJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, }}, provisionerJobParams: database.InsertProvisionerJobParams{ - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: json.RawMessage("{}"), + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: templateVersionID, + })), }, }, { @@ -2349,22 +2431,26 @@ func TestCompleteJob(t *testing.T) { OrganizationID: pd.OrganizationID, }) tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + ID: templateVersionID, CreatedBy: user.ID, OrganizationID: pd.OrganizationID, TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: job.ID, }) - workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ - TemplateID: tpl.ID, - OrganizationID: pd.OrganizationID, - OwnerID: user.ID, - }) - _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - ID: workspaceBuildID, - JobID: job.ID, - WorkspaceID: workspace.ID, - TemplateVersionID: tv.ID, - }) + + if jobParams.Type == database.ProvisionerJobTypeWorkspaceBuild { + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: tpl.ID, + OrganizationID: pd.OrganizationID, + OwnerID: user.ID, + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + ID: workspaceBuildID, + JobID: job.ID, + WorkspaceID: workspace.ID, + TemplateVersionID: tv.ID, + }) + } require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -2672,7 +2758,10 @@ func TestCompleteJob(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - srv, db, _, pd := setup(t, false, &overrides{}) + fakeUsageInserter, usageInserterPtr := newFakeUsageInserter() + srv, db, _, pd := setup(t, false, &overrides{ + usageInserter: usageInserterPtr, + }) importJobID := uuid.New() tvID := uuid.New() @@ -2741,6 +2830,10 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) require.True(t, version.HasAITask.Valid) // We ALWAYS expect a value to be set, therefore not nil, i.e. valid = true. require.Equal(t, tc.expected, version.HasAITask.Bool) + + // We never expect a usage event to be collected for + // template imports. + require.Empty(t, fakeUsageInserter.collectedEvents) }) } }) @@ -2750,22 +2843,27 @@ func TestCompleteJob(t *testing.T) { // will be set as well in that case. t.Run("WorkspaceBuild", func(t *testing.T) { type testcase struct { - name string - input *proto.CompletedJob_WorkspaceBuild - expected bool + name string + transition database.WorkspaceTransition + input *proto.CompletedJob_WorkspaceBuild + expectHasAiTask bool + expectUsageEvent bool } sidebarAppID := uuid.NewString() for _, tc := range []testcase{ { - name: "has_ai_task is false by default", - input: &proto.CompletedJob_WorkspaceBuild{ + name: "has_ai_task is false by default", + transition: database.WorkspaceTransitionStart, + input: &proto.CompletedJob_WorkspaceBuild{ // No AiTasks defined. }, - expected: false, + expectHasAiTask: false, + expectUsageEvent: false, }, { - name: "has_ai_task is set to true", + name: "has_ai_task is set to true", + transition: database.WorkspaceTransitionStart, input: &proto.CompletedJob_WorkspaceBuild{ AiTasks: []*sdkproto.AITask{ { @@ -2792,11 +2890,13 @@ func TestCompleteJob(t *testing.T) { }, }, }, - expected: true, + expectHasAiTask: true, + expectUsageEvent: true, }, // Checks regression for https://github.com/coder/coder/issues/18776 { - name: "non-existing app", + name: "non-existing app", + transition: database.WorkspaceTransitionStart, input: &proto.CompletedJob_WorkspaceBuild{ AiTasks: []*sdkproto.AITask{ { @@ -2808,13 +2908,49 @@ func TestCompleteJob(t *testing.T) { }, }, }, - expected: false, + expectHasAiTask: false, + expectUsageEvent: false, + }, + { + name: "has_ai_task is set to true, but transition is not start", + transition: database.WorkspaceTransitionStop, + input: &proto.CompletedJob_WorkspaceBuild{ + AiTasks: []*sdkproto.AITask{ + { + Id: uuid.NewString(), + SidebarApp: &sdkproto.AITaskSidebarApp{ + Id: sidebarAppID, + }, + }, + }, + Resources: []*sdkproto.Resource{ + { + Agents: []*sdkproto.Agent{ + { + Id: uuid.NewString(), + Name: "a", + Apps: []*sdkproto.App{ + { + Id: sidebarAppID, + Slug: "test-app", + }, + }, + }, + }, + }, + }, + }, + expectHasAiTask: true, + expectUsageEvent: false, }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() - srv, db, _, pd := setup(t, false, &overrides{}) + fakeUsageInserter, usageInserterPtr := newFakeUsageInserter() + srv, db, _, pd := setup(t, false, &overrides{ + usageInserter: usageInserterPtr, + }) importJobID := uuid.New() tvID := uuid.New() @@ -2868,7 +3004,7 @@ func TestCompleteJob(t *testing.T) { WorkspaceID: workspaceTable.ID, TemplateVersionID: version.ID, InitiatorID: user.ID, - Transition: database.WorkspaceTransitionStart, + Transition: tc.transition, }) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ @@ -2899,11 +3035,22 @@ func TestCompleteJob(t *testing.T) { build, err = db.GetWorkspaceBuildByID(ctx, build.ID) require.NoError(t, err) require.True(t, build.HasAITask.Valid) // We ALWAYS expect a value to be set, therefore not nil, i.e. valid = true. - require.Equal(t, tc.expected, build.HasAITask.Bool) + require.Equal(t, tc.expectHasAiTask, build.HasAITask.Bool) - if tc.expected { + if tc.expectHasAiTask { require.Equal(t, sidebarAppID, build.AITaskSidebarAppID.UUID.String()) } + + if tc.expectUsageEvent { + // Check that a usage event was collected. + require.Len(t, fakeUsageInserter.collectedEvents, 1) + require.Equal(t, usage.DCManagedAgentsV1{ + Count: 1, + }, fakeUsageInserter.collectedEvents[0]) + } else { + // Check that no usage event was collected. + require.Empty(t, fakeUsageInserter.collectedEvents) + } }) } }) @@ -3835,6 +3982,7 @@ type overrides struct { externalAuthConfigs []*externalauth.Config templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] + usageInserter *atomic.Pointer[usage.Inserter] clock *quartz.Mock acquireJobLongPollDuration time.Duration heartbeatFn func(ctx context.Context) error @@ -3855,13 +4003,14 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi var externalAuthConfigs []*externalauth.Config tss := testTemplateScheduleStore() uqhss := testUserQuietHoursScheduleStore() + usageInserter := testUsageInserter() clock := quartz.NewReal() pollDur := time.Duration(0) if ov == nil { ov = &overrides{} } if ov.ctx == nil { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(dbauthz.AsProvisionerd(context.Background())) t.Cleanup(cancel) ov.ctx = ctx } @@ -3892,6 +4041,15 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi require.True(t, swapped) } } + if ov.usageInserter != nil { + tUsageInserter := usageInserter.Load() + // keep the initial test value if the override hasn't set the atomic pointer. + usageInserter = ov.usageInserter + if usageInserter.Load() == nil { + swapped := usageInserter.CompareAndSwap(nil, tUsageInserter) + require.True(t, swapped) + } + } if ov.clock != nil { clock = ov.clock } @@ -3929,6 +4087,10 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi var op atomic.Pointer[agplprebuilds.ReconciliationOrchestrator] op.Store(&prebuildsOrchestrator) + // Use an authz wrapped database for the server to ensure permission checks + // work. + authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) + serverDB := dbauthz.New(db, authorizer, logger, coderdtest.AccessControlStorePointer()) srv, err := provisionerdserver.NewServer( ov.ctx, proto.CurrentVersion.String(), @@ -3938,7 +4100,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), []database.ProvisionerType{database.ProvisionerTypeEcho}, provisionerdserver.Tags(daemon.Tags), - db, + serverDB, ps, provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps), telemetry.NewNoop(), @@ -3947,6 +4109,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi auditPtr, tss, uqhss, + usageInserter, deploymentValues, provisionerdserver.Options{ ExternalAuthConfigs: externalAuthConfigs, @@ -4061,3 +4224,22 @@ func (s *fakeStream) cancel() { s.canceled = true s.c.Broadcast() } + +type fakeUsageInserter struct { + collectedEvents []usage.Event +} + +var _ usage.Inserter = &fakeUsageInserter{} + +func newFakeUsageInserter() (*fakeUsageInserter, *atomic.Pointer[usage.Inserter]) { + ptr := &atomic.Pointer[usage.Inserter]{} + fake := &fakeUsageInserter{} + var inserter usage.Inserter = fake + ptr.Store(&inserter) + return fake, ptr +} + +func (f *fakeUsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usage.DiscreteEvent) error { + f.collectedEvents = append(f.collectedEvents, event) + return nil +} diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index a8130bea17ad3..0b48a24aebe83 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -76,7 +76,7 @@ const ( SubjectTypeNotifier SubjectType = "notifier" SubjectTypeSubAgentAPI SubjectType = "sub_agent_api" SubjectTypeFileReader SubjectType = "file_reader" - SubjectTypeUsageTracker SubjectType = "usage_tracker" + SubjectTypeUsagePublisher SubjectType = "usage_publisher" ) const ( diff --git a/coderd/usage/inserter.go b/coderd/usage/inserter.go index 08ca8dec3e881..3a0e85f273abb 100644 --- a/coderd/usage/inserter.go +++ b/coderd/usage/inserter.go @@ -10,6 +10,8 @@ import ( type Inserter interface { // InsertDiscreteUsageEvent writes a discrete usage event to the database // within the given transaction. + // The caller context must be authorized to create usage events in the + // database. InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event DiscreteEvent) error } diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 4c9412fda3fb7..504b102e9ee5b 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -335,7 +335,6 @@ func TestUserOAuth2Github(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - // nolint:gocritic // Unit test count, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false) require.NoError(t, err) require.Equal(t, int64(1), count) @@ -897,7 +896,6 @@ func TestUserOAuth2Github(t *testing.T) { require.Empty(t, links) // Make sure a user_link cannot be created with a deleted user. - // nolint:gocritic // Unit test _, err = db.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{ UserID: deleted.ID, LoginType: "github", diff --git a/coderd/users_test.go b/coderd/users_test.go index 5928fc6486f51..22c9fad5eebea 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -1544,7 +1544,6 @@ func TestUsersFilter(t *testing.T) { } userClient, userData := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, roles...) // Set the last seen for each user to a unique day - // nolint:gocritic // Unit test _, err := api.Database.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{ ID: userData.ID, LastSeenAt: lastSeenNow.Add(-1 * time.Hour * 24 * time.Duration(i)), @@ -1572,7 +1571,6 @@ func TestUsersFilter(t *testing.T) { // Add users with different creation dates for testing date filters for i := 0; i < 3; i++ { - // nolint:gocritic // Using system context is necessary to seed data in tests user1, err := api.Database.InsertUser(dbauthz.AsSystemRestricted(ctx), database.InsertUserParams{ ID: uuid.New(), Email: fmt.Sprintf("before%d@coder.com", i), @@ -1594,7 +1592,6 @@ func TestUsersFilter(t *testing.T) { require.NoError(t, err) users = append(users, sdkUser1) - // nolint:gocritic //Using system context is necessary to seed data in tests user2, err := api.Database.InsertUser(dbauthz.AsSystemRestricted(ctx), database.InsertUserParams{ ID: uuid.New(), Email: fmt.Sprintf("during%d@coder.com", i), @@ -1615,7 +1612,6 @@ func TestUsersFilter(t *testing.T) { require.NoError(t, err) users = append(users, sdkUser2) - // nolint:gocritic // Using system context is necessary to seed data in tests user3, err := api.Database.InsertUser(dbauthz.AsSystemRestricted(ctx), database.InsertUserParams{ ID: uuid.New(), Email: fmt.Sprintf("after%d@coder.com", i), @@ -1912,7 +1908,6 @@ func TestGetUsers(t *testing.T) { Email: "test2@coder.com", Username: "test2", }) - // nolint:gocritic // Unit test err := db.UpdateUserGithubComUserID(dbauthz.AsSystemRestricted(ctx), database.UpdateUserGithubComUserIDParams{ ID: first.UserID, GithubComUserID: sql.NullInt64{ diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index ac58df1b772ad..6f28b12af5ae0 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -562,7 +562,6 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { seed := database.WorkspaceTable{OrganizationID: user.OrganizationID, OwnerID: user.UserID} wsb := dbfake.WorkspaceBuild(t, db, seed).WithAgent().Do() // When: the workspace is marked as soft-deleted - // nolint:gocritic // this is a test err := db.UpdateWorkspaceDeletedByID( dbauthz.AsProvisionerd(ctx), database.UpdateWorkspaceDeletedByIDParams{ID: wsb.Workspace.ID, Deleted: true}, @@ -633,7 +632,6 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) agentToken, err := uuid.Parse(r.AgentToken) require.NoError(t, err) - //nolint: gocritic // testing ao, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentToken) require.NoError(t, err) @@ -724,7 +722,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { agentTokenUUID, err := uuid.Parse(r.AgentToken) require.NoError(t, err) ctx := testutil.Context(t, testutil.WaitLong) - agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) require.NoError(t, err) // Connect with no resume token, and ensure that the peer ID is set to a @@ -796,7 +794,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { agentTokenUUID, err := uuid.Parse(r.AgentToken) require.NoError(t, err) ctx := testutil.Context(t, testutil.WaitLong) - agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) require.NoError(t, err) // Connect with no resume token, and ensure that the peer ID is set to a diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 633acae328673..e888115093a9b 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -55,7 +55,6 @@ func TestWorkspaceBuild(t *testing.T) { Auditor: auditor, }) user := coderdtest.CreateFirstUser(t, client) - //nolint:gocritic // testing up, err := db.UpdateUserProfile(dbauthz.AsSystemRestricted(ctx), database.UpdateUserProfileParams{ ID: user.UserID, Email: coderdtest.FirstUserParams.Email, @@ -518,7 +517,6 @@ func TestWorkspaceBuildsProvisionerState(t *testing.T) { OrganizationID: first.OrganizationID, }).Do() - // nolint:gocritic // For testing daemons, err := store.GetProvisionerDaemons(dbauthz.AsSystemReadProvisionerDaemons(ctx)) require.NoError(t, err) require.Empty(t, daemons, "Provisioner daemons should be empty for this test") diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index 8fc11ef6c8ccb..4df83114c68a1 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -1427,7 +1427,6 @@ func TestWorkspaceFilterAllStatus(t *testing.T) { t.Parallel() // For this test, we do not care about permissions. - // nolint:gocritic // unit testing ctx := dbauthz.AsSystemRestricted(context.Background()) db, pubsub := dbtestutil.NewDB(t) client := coderdtest.New(t, &coderdtest.Options{ @@ -2215,15 +2214,12 @@ func TestWorkspaceFilterManual(t *testing.T) { after := coderdtest.CreateWorkspace(t, client, template.ID) _ = coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, after.LatestBuild.ID) - //nolint:gocritic // Unit testing context err := api.Database.UpdateWorkspaceLastUsedAt(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceLastUsedAtParams{ ID: before.ID, LastUsedAt: now.UTC().Add(time.Hour * -1), }) require.NoError(t, err) - // Unit testing context - //nolint:gocritic // Unit testing context err = api.Database.UpdateWorkspaceLastUsedAt(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceLastUsedAtParams{ ID: after.ID, LastUsedAt: now.UTC().Add(time.Hour * 1), @@ -2916,14 +2912,14 @@ func TestWorkspaceUpdateTTL(t *testing.T) { // This is a hack, but the max_deadline isn't precisely configurable // without a lot of unnecessary hassle. - dbBuild, err := db.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), build.ID) //nolint:gocritic // test + dbBuild, err := db.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), build.ID) require.NoError(t, err) - dbJob, err := db.GetProvisionerJobByID(dbauthz.AsSystemRestricted(ctx), dbBuild.JobID) //nolint:gocritic // test + dbJob, err := db.GetProvisionerJobByID(dbauthz.AsSystemRestricted(ctx), dbBuild.JobID) require.NoError(t, err) require.True(t, dbJob.CompletedAt.Valid) initialDeadline := dbJob.CompletedAt.Time.Add(deadline) expectedMaxDeadline := dbJob.CompletedAt.Time.Add(maxDeadline) - err = db.UpdateWorkspaceBuildDeadlineByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceBuildDeadlineByIDParams{ //nolint:gocritic // test + err = db.UpdateWorkspaceBuildDeadlineByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceBuildDeadlineByIDParams{ ID: build.ID, Deadline: initialDeadline, MaxDeadline: expectedMaxDeadline, @@ -4507,14 +4503,12 @@ func TestOIDCRemoved(t *testing.T) { user, userData := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, rbac.ScopedRoleOrgAdmin(first.OrganizationID)) ctx := testutil.Context(t, testutil.WaitMedium) - //nolint:gocritic // unit test _, err := db.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ NewLoginType: database.LoginTypeOIDC, UserID: userData.ID, }) require.NoError(t, err) - //nolint:gocritic // unit test _, err = db.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{ UserID: userData.ID, LoginType: database.LoginTypeOIDC, @@ -4603,7 +4597,6 @@ func TestWorkspaceFilterHasAITask(t *testing.T) { }) if aiTaskPrompt != nil { - //nolint:gocritic // unit test err := db.InsertWorkspaceBuildParameters(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceBuildParametersParams{ WorkspaceBuildID: build.ID, Name: []string{provider.TaskPromptParameterName}, @@ -4806,7 +4799,6 @@ func TestMultipleAITasksDisallowed(t *testing.T) { ws := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) - //nolint: gocritic // testing ctx := dbauthz.AsSystemRestricted(t.Context()) pj, err := db.GetProvisionerJobByID(ctx, ws.LatestBuild.Job.ID) require.NoError(t, err) diff --git a/enterprise/cli/prebuilds_test.go b/enterprise/cli/prebuilds_test.go index 76d11a41d67f0..cf0c74105020c 100644 --- a/enterprise/cli/prebuilds_test.go +++ b/enterprise/cli/prebuilds_test.go @@ -434,7 +434,6 @@ func TestSchedulePrebuilds(t *testing.T) { }).Do() // Mark the prebuilt workspace's agent as ready so the prebuild can be claimed - // nolint:gocritic ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(workspaceBuild.AgentToken)) require.NoError(t, err) diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index 3b1fd63ab1c4c..f58ec86b58a43 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/v2/enterprise/audit/backends" "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/dormancy" + "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/coder/v2/enterprise/dbcrypt" "github.com/coder/coder/v2/enterprise/trialer" "github.com/coder/coder/v2/tailnet" @@ -116,11 +117,33 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { o.ExternalTokenEncryption = cs } + if o.LicenseKeys == nil { + o.LicenseKeys = coderd.Keys + } + + closers := &multiCloser{} + + // Create the enterprise API. api, err := coderd.New(ctx, o) if err != nil { return nil, nil, err } - return api.AGPL, api, nil + closers.Add(api) + + // Start the enterprise usage publisher routine. This won't do anything + // unless the deployment is licensed and one of the licenses has usage + // publishing enabled. + publisher := usage.NewTallymanPublisher(ctx, options.Logger, options.Database, o.LicenseKeys, + usage.PublisherWithHTTPClient(api.HTTPClient), + ) + err = publisher.Start() + if err != nil { + _ = closers.Close() + return nil, nil, xerrors.Errorf("start usage publisher: %w", err) + } + closers.Add(publisher) + + return api.AGPL, closers, nil }) cmd.AddSubcommands( @@ -128,3 +151,23 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { ) return cmd } + +type multiCloser struct { + closers []io.Closer +} + +var _ io.Closer = &multiCloser{} + +func (m *multiCloser) Add(closer io.Closer) { + m.closers = append(m.closers, closer) +} + +func (m *multiCloser) Close() error { + var errs []error + for _, closer := range m.closers { + if err := closer.Close(); err != nil { + errs = append(errs, xerrors.Errorf("close %T: %w", closer, err)) + } + } + return errors.Join(errs...) +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 8190de103cd7a..a81e16585473b 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/coder/coder/v2/buildinfo" @@ -21,10 +22,12 @@ import ( "github.com/coder/coder/v2/coderd/pproflabel" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" + agplusage "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" + "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/quartz" "golang.org/x/xerrors" @@ -90,6 +93,13 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { if options.Entitlements == nil { options.Entitlements = entitlements.New() } + if options.Options.UsageInserter == nil { + options.Options.UsageInserter = &atomic.Pointer[agplusage.Inserter]{} + } + if options.Options.UsageInserter.Load() == nil { + collector := usage.NewDBInserter() + options.Options.UsageInserter.Store(&collector) + } ctx, cancelFunc := context.WithCancel(ctx) diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 94d9e4fda20df..302b367c304cd 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -154,7 +154,6 @@ func TestEntitlements(t *testing.T) { entitlements, err := anotherClient.Entitlements(context.Background()) require.NoError(t, err) require.False(t, entitlements.HasLicense) - //nolint:gocritic // unit test ctx := testDBAuthzRole(context.Background()) _, err = api.Database.InsertLicense(ctx, database.InsertLicenseParams{ UploadedAt: dbtime.Now(), @@ -186,7 +185,6 @@ func TestEntitlements(t *testing.T) { require.False(t, entitlements.HasLicense) // Valid ctx := context.Background() - //nolint:gocritic // unit test _, err = api.Database.InsertLicense(testDBAuthzRole(ctx), database.InsertLicenseParams{ UploadedAt: dbtime.Now(), Exp: dbtime.Now().AddDate(1, 0, 0), @@ -198,7 +196,6 @@ func TestEntitlements(t *testing.T) { }) require.NoError(t, err) // Expired - //nolint:gocritic // unit test _, err = api.Database.InsertLicense(testDBAuthzRole(ctx), database.InsertLicenseParams{ UploadedAt: dbtime.Now(), Exp: dbtime.Now().AddDate(-1, 0, 0), @@ -208,7 +205,6 @@ func TestEntitlements(t *testing.T) { }) require.NoError(t, err) // Invalid - //nolint:gocritic // unit test _, err = api.Database.InsertLicense(testDBAuthzRole(ctx), database.InsertLicenseParams{ UploadedAt: dbtime.Now(), Exp: dbtime.Now().AddDate(1, 0, 0), diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 13a9bd69ed8fd..c3bae7cd1d848 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -56,7 +56,6 @@ func TestOrganizationSync(t *testing.T) { requireUserOrgs := func(t *testing.T, db database.Store, user database.User, expected []uuid.UUID) { t.Helper() - // nolint:gocritic // in testing members, err := db.OrganizationMembers(dbauthz.AsSystemRestricted(context.Background()), database.OrganizationMembersParams{ UserID: user.ID, }) diff --git a/enterprise/coderd/idpsync_test.go b/enterprise/coderd/idpsync_test.go index d34701c3f6936..49d83a62688ba 100644 --- a/enterprise/coderd/idpsync_test.go +++ b/enterprise/coderd/idpsync_test.go @@ -39,7 +39,6 @@ func TestGetGroupSyncSettings(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) dbresv := runtimeconfig.OrganizationResolver(user.OrganizationID, runtimeconfig.NewStoreResolver(db)) entry := runtimeconfig.MustNew[*idpsync.GroupSyncSettings]("group-sync-settings") - //nolint:gocritic // Requires system context to set runtime config err := entry.SetRuntimeValue(dbauthz.AsSystemRestricted(ctx), dbresv, &idpsync.GroupSyncSettings{Field: "august"}) require.NoError(t, err) diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go index 1e9f3f5082806..b852079beb2af 100644 --- a/enterprise/coderd/prebuilds/metricscollector_test.go +++ b/enterprise/coderd/prebuilds/metricscollector_test.go @@ -231,7 +231,6 @@ func TestMetricsCollector(t *testing.T) { } // Force an update to the metrics state to allow the collector to collect fresh metrics. - // nolint:gocritic // Authz context needed to retrieve state. require.NoError(t, collector.UpdateState(dbauthz.AsPrebuildsOrchestrator(ctx), testutil.WaitLong)) metricsFamilies, err := registry.Gather() @@ -367,7 +366,6 @@ func TestMetricsCollector_DuplicateTemplateNames(t *testing.T) { "organization_name": defaultOrg.Name, } - // nolint:gocritic // Authz context needed to retrieve state. ctx = dbauthz.AsPrebuildsOrchestrator(ctx) // Then: metrics collect successfully. diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index c8304952781d1..65b03a7d6b864 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -352,6 +352,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) &api.AGPL.Auditor, api.AGPL.TemplateScheduleStore, api.AGPL.UserQuietHoursScheduleStore, + api.AGPL.UsageInserter, api.DeploymentValues, provisionerdserver.Options{ ExternalAuthConfigs: api.ExternalAuthConfigs, diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index a94a60ffff3c2..5797e978fa34c 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -682,7 +682,6 @@ func TestProvisionerDaemonServe(t *testing.T) { if tc.insertParams.Name != "" { tc.insertParams.OrganizationID = user.OrganizationID - // nolint:gocritic // test _, err := db.InsertProvisionerKey(dbauthz.AsSystemRestricted(ctx), tc.insertParams) require.NoError(t, err) } @@ -945,7 +944,6 @@ func TestGetProvisionerDaemons(t *testing.T) { daemonCreatedAt := time.Now() - //nolint:gocritic // We're not testing auth on the following in this test provisionerKey, err := db.InsertProvisionerKey(dbauthz.AsSystemRestricted(ctx), database.InsertProvisionerKeyParams{ Name: "Test Provisioner Key", ID: uuid.New(), @@ -956,7 +954,6 @@ func TestGetProvisionerDaemons(t *testing.T) { }) require.NoError(t, err, "should be able to create a provisioner key") - //nolint:gocritic // We're not testing auth on the following in this test pd, err := db.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ CreatedAt: daemonCreatedAt, Name: "Test Provisioner Daemon", diff --git a/enterprise/coderd/schedule/template_test.go b/enterprise/coderd/schedule/template_test.go index 70dc3084899ad..e764826f76922 100644 --- a/enterprise/coderd/schedule/template_test.go +++ b/enterprise/coderd/schedule/template_test.go @@ -719,7 +719,6 @@ func TestNotifications(t *testing.T) { // Lower the dormancy TTL to ensure the schedule recalculates deadlines and // triggers notifications. - // nolint:gocritic // Need an actor in the context. _, err = templateScheduleStore.Set(dbauthz.AsNotifier(ctx), db, template, agplschedule.TemplateScheduleOptions{ TimeTilDormant: timeTilDormant / 2, TimeTilDormantAutoDelete: timeTilDormant / 2, diff --git a/enterprise/coderd/usage/inserter.go b/enterprise/coderd/usage/inserter.go index 3320c25d454ce..e07b536e758bd 100644 --- a/enterprise/coderd/usage/inserter.go +++ b/enterprise/coderd/usage/inserter.go @@ -13,16 +13,17 @@ import ( "github.com/coder/quartz" ) -// Inserter accepts usage events and stores them in the database for publishing. -type Inserter struct { +// dbCollector collects usage events and stores them in the database for +// publishing. +type dbCollector struct { clock quartz.Clock } -var _ agplusage.Inserter = &Inserter{} +var _ agplusage.Inserter = &dbCollector{} -// NewInserter creates a new database-backed usage event inserter. -func NewInserter(opts ...InserterOptions) *Inserter { - c := &Inserter{ +// NewDBInserter creates a new database-backed usage event inserter. +func NewDBInserter(opts ...InserterOption) agplusage.Inserter { + c := &dbCollector{ clock: quartz.NewReal(), } for _, opt := range opts { @@ -31,17 +32,17 @@ func NewInserter(opts ...InserterOptions) *Inserter { return c } -type InserterOptions func(*Inserter) +type InserterOption func(*dbCollector) // InserterWithClock sets the quartz clock to use for the inserter. -func InserterWithClock(clock quartz.Clock) InserterOptions { - return func(c *Inserter) { +func InserterWithClock(clock quartz.Clock) InserterOption { + return func(c *dbCollector) { c.clock = clock } } // InsertDiscreteUsageEvent implements agplusage.Inserter. -func (c *Inserter) InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event agplusage.DiscreteEvent) error { +func (i *dbCollector) InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event agplusage.DiscreteEvent) error { if !event.EventType().IsDiscrete() { return xerrors.Errorf("event type %q is not a discrete event", event.EventType()) } @@ -61,6 +62,6 @@ func (c *Inserter) InsertDiscreteUsageEvent(ctx context.Context, tx database.Sto ID: uuid.New().String(), EventType: string(event.EventType()), EventData: jsonData, - CreatedAt: dbtime.Time(c.clock.Now()), + CreatedAt: dbtime.Time(i.clock.Now()), }) } diff --git a/enterprise/coderd/usage/inserter_test.go b/enterprise/coderd/usage/inserter_test.go index c5abd931cfaba..b7ced536aef3b 100644 --- a/enterprise/coderd/usage/inserter_test.go +++ b/enterprise/coderd/usage/inserter_test.go @@ -28,7 +28,7 @@ func TestInserter(t *testing.T) { ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) clock := quartz.NewMock(t) - inserter := usage.NewInserter(usage.InserterWithClock(clock)) + inserter := usage.NewDBInserter(usage.InserterWithClock(clock)) now := dbtime.Now() events := []struct { @@ -51,8 +51,8 @@ func TestInserter(t *testing.T) { for _, event := range events { eventJSON := jsoninate(t, event.event) - db.EXPECT().InsertUsageEvent(ctx, gomock.Any()).DoAndReturn( - func(ctx interface{}, params database.InsertUsageEventParams) error { + db.EXPECT().InsertUsageEvent(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx any, params database.InsertUsageEventParams) error { _, err := uuid.Parse(params.ID) assert.NoError(t, err) assert.Equal(t, string(event.event.EventType()), params.EventType) @@ -76,7 +76,7 @@ func TestInserter(t *testing.T) { db := dbmock.NewMockStore(ctrl) // We should get an error if the event is invalid. - inserter := usage.NewInserter() + inserter := usage.NewDBInserter() err := inserter.InsertDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ Count: 0, // invalid }) diff --git a/enterprise/coderd/usage/publisher.go b/enterprise/coderd/usage/publisher.go index e8722841160fb..8c0811c7727c8 100644 --- a/enterprise/coderd/usage/publisher.go +++ b/enterprise/coderd/usage/publisher.go @@ -15,11 +15,11 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/pproflabel" agplusage "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/cryptorand" - "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/quartz" ) @@ -49,17 +49,17 @@ type Publisher interface { } type tallymanPublisher struct { - ctx context.Context - ctxCancel context.CancelFunc - log slog.Logger - db database.Store - done chan struct{} + ctx context.Context + ctxCancel context.CancelFunc + log slog.Logger + db database.Store + licenseKeys map[string]ed25519.PublicKey + done chan struct{} // Configured with options: ingestURL string httpClient *http.Client clock quartz.Clock - licenseKeys map[string]ed25519.PublicKey initialDelay time.Duration } @@ -67,19 +67,21 @@ var _ Publisher = &tallymanPublisher{} // NewTallymanPublisher creates a Publisher that publishes usage events to // Coder's Tallyman service. -func NewTallymanPublisher(ctx context.Context, log slog.Logger, db database.Store, opts ...TallymanPublisherOption) Publisher { +func NewTallymanPublisher(ctx context.Context, log slog.Logger, db database.Store, keys map[string]ed25519.PublicKey, opts ...TallymanPublisherOption) Publisher { ctx, cancel := context.WithCancel(ctx) + ctx = dbauthz.AsUsagePublisher(ctx) //nolint:gocritic // we intentionally want to be able to process usage events + publisher := &tallymanPublisher{ - ctx: ctx, - ctxCancel: cancel, - log: log, - db: db, - done: make(chan struct{}), + ctx: ctx, + ctxCancel: cancel, + log: log, + db: db, + licenseKeys: keys, + done: make(chan struct{}), - ingestURL: tallymanIngestURLV1, - httpClient: http.DefaultClient, - clock: quartz.NewReal(), - licenseKeys: coderd.Keys, + ingestURL: tallymanIngestURLV1, + httpClient: http.DefaultClient, + clock: quartz.NewReal(), } for _, opt := range opts { opt(publisher) @@ -92,6 +94,9 @@ type TallymanPublisherOption func(*tallymanPublisher) // PublisherWithHTTPClient sets the HTTP client to use for publishing usage events. func PublisherWithHTTPClient(httpClient *http.Client) TallymanPublisherOption { return func(p *tallymanPublisher) { + if httpClient == nil { + httpClient = http.DefaultClient + } p.httpClient = httpClient } } @@ -103,14 +108,6 @@ func PublisherWithClock(clock quartz.Clock) TallymanPublisherOption { } } -// PublisherWithLicenseKeys sets the license public keys to use for license -// validation. -func PublisherWithLicenseKeys(keys map[string]ed25519.PublicKey) TallymanPublisherOption { - return func(p *tallymanPublisher) { - p.licenseKeys = keys - } -} - // PublisherWithIngestURL sets the ingest URL to use for publishing usage // events. func PublisherWithIngestURL(ingestURL string) TallymanPublisherOption { @@ -149,6 +146,10 @@ func (p *tallymanPublisher) Start() error { p.initialDelay = tallymanPublishInitialMinimumDelay + time.Duration(plusDelay) } + if len(p.licenseKeys) == 0 { + return xerrors.New("no license keys provided") + } + pproflabel.Go(ctx, pproflabel.Service(pproflabel.ServiceTallymanPublisher), func(ctx context.Context) { p.publishLoop(ctx, deploymentUUID) }) diff --git a/enterprise/coderd/usage/publisher_test.go b/enterprise/coderd/usage/publisher_test.go index a2a997b032ac0..7a17935a64a61 100644 --- a/enterprise/coderd/usage/publisher_test.go +++ b/enterprise/coderd/usage/publisher_test.go @@ -10,16 +10,20 @@ import ( "time" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "go.uber.org/mock/gomock" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/rbac" agplusage "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/usage" @@ -40,6 +44,7 @@ func TestIntegration(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) log := slogtest.Make(t, nil) db, _ := dbtestutil.NewDB(t) + clock := quartz.NewMock(t) deploymentID, licenseJWT := configureDeployment(ctx, t, db) now := time.Now() @@ -60,7 +65,7 @@ func TestIntegration(t *testing.T) { return handler(req) })) - inserter := usage.NewInserter( + inserter := usage.NewDBInserter( usage.InserterWithClock(clock), ) // Insert an old event that should never be published. @@ -80,10 +85,12 @@ func TestIntegration(t *testing.T) { require.NoErrorf(t, err, "collecting event %d", i) } - publisher := usage.NewTallymanPublisher(ctx, log, db, + // Wrap the publisher's DB in a dbauthz to ensure that the publisher has + // enough permissions. + authzDB := dbauthz.New(db, rbac.NewAuthorizer(prometheus.NewRegistry()), log, coderdtest.AccessControlStorePointer()) + publisher := usage.NewTallymanPublisher(ctx, log, authzDB, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), ) defer publisher.Close() @@ -212,10 +219,9 @@ func TestPublisherNoEligibleLicenses(t *testing.T) { } })) - publisher := usage.NewTallymanPublisher(ctx, log, db, + publisher := usage.NewTallymanPublisher(ctx, log, db, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), ) defer publisher.Close() @@ -283,14 +289,13 @@ func TestPublisherClaimExpiry(t *testing.T) { return tallymanAcceptAllHandler(req) })) - inserter := usage.NewInserter( + inserter := usage.NewDBInserter( usage.InserterWithClock(clock), ) - publisher := usage.NewTallymanPublisher(ctx, log, db, + publisher := usage.NewTallymanPublisher(ctx, log, db, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), usage.PublisherWithInitialDelay(17*time.Minute), ) defer publisher.Close() @@ -367,10 +372,9 @@ func TestPublisherMissingEvents(t *testing.T) { } })) - publisher := usage.NewTallymanPublisher(ctx, log, db, + publisher := usage.NewTallymanPublisher(ctx, log, db, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), ) // Expect the publisher to call SelectUsageEventsForPublishing, followed by @@ -510,10 +514,9 @@ func TestPublisherLicenseSelection(t *testing.T) { return tallymanAcceptAllHandler(req) })) - publisher := usage.NewTallymanPublisher(ctx, log, db, + publisher := usage.NewTallymanPublisher(ctx, log, db, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), ) defer publisher.Close() @@ -579,10 +582,9 @@ func TestPublisherTallymanError(t *testing.T) { } })) - publisher := usage.NewTallymanPublisher(ctx, log, db, + publisher := usage.NewTallymanPublisher(ctx, log, db, coderdenttest.Keys, usage.PublisherWithClock(clock), usage.PublisherWithIngestURL(ingestURL), - usage.PublisherWithLicenseKeys(coderdenttest.Keys), ) defer publisher.Close() diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 46207f319dbe1..fd4706a25e511 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -941,7 +941,6 @@ func TestGroupSync(t *testing.T) { require.NoError(t, err) } - // nolint:gocritic _, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ NewLoginType: database.LoginTypeOIDC, UserID: user.ID, diff --git a/enterprise/coderd/workspacequota_test.go b/enterprise/coderd/workspacequota_test.go index f49e135ad55b3..f39b090ca21b1 100644 --- a/enterprise/coderd/workspacequota_test.go +++ b/enterprise/coderd/workspacequota_test.go @@ -462,7 +462,6 @@ func TestWorkspaceSerialization(t *testing.T) { // +------------------------------+------------------+ // pq: could not serialize access due to concurrent update ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -520,7 +519,6 @@ func TestWorkspaceSerialization(t *testing.T) { // +------------------------------+------------------+ // Works! ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -589,7 +587,6 @@ func TestWorkspaceSerialization(t *testing.T) { // +---------------------+----------------------------------+ // pq: could not serialize access due to concurrent update ctx := testutil.Context(t, testutil.WaitShort) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -642,7 +639,6 @@ func TestWorkspaceSerialization(t *testing.T) { // | CommitTx() | | // +---------------------+----------------------------------+ ctx := testutil.Context(t, testutil.WaitShort) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -686,7 +682,6 @@ func TestWorkspaceSerialization(t *testing.T) { // +---------------------+----------------------------------+ // Works! ctx := testutil.Context(t, testutil.WaitShort) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) var err error @@ -741,7 +736,6 @@ func TestWorkspaceSerialization(t *testing.T) { // | | CommitTx() | // +---------------------+---------------------+ ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -799,7 +793,6 @@ func TestWorkspaceSerialization(t *testing.T) { // | | CommitTx() | // +---------------------+---------------------+ ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -860,7 +853,6 @@ func TestWorkspaceSerialization(t *testing.T) { // +---------------------+---------------------+ // pq: could not serialize access due to read/write dependencies among transactions ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // testing ctx = dbauthz.AsSystemRestricted(ctx) myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index dc44a8794e1c6..1cdcd9fb43144 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -570,7 +570,6 @@ func TestCreateUserWorkspace(t *testing.T) { return a }).Do() - // nolint:gocritic // this is a test ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(r.AgentToken)) require.NoError(t, err) @@ -1708,7 +1707,6 @@ func TestWorkspaceAutobuild(t *testing.T) { // We want to test the database nullifies the NextStartAt so we // make a raw DB call here. We pass in NextStartAt here so we // can test the database will nullify it and not us. - //nolint: gocritic // We need system context to modify this. err = db.UpdateWorkspaceAutostart(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAutostartParams{ ID: ws.ID, AutostartSchedule: sql.NullString{Valid: true, String: sched.String()}, @@ -2720,7 +2718,6 @@ func TestPrebuildUpdateLifecycleParams(t *testing.T) { }).Do() // Mark the prebuilt workspace's agent as ready so the prebuild can be claimed - // nolint:gocritic ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(workspaceBuild.AgentToken)) require.NoError(t, err) @@ -3722,7 +3719,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { require.Equal(t, ws.LatestBuild.MatchedProvisioners.Available, 0) // Verify that the provisioner daemon is registered in the database - //nolint:gocritic // unit testing daemons, err := db.GetProvisionerDaemons(dbauthz.AsSystemRestricted(ctx)) require.NoError(t, err) require.Equal(t, 1, len(daemons)) @@ -3758,7 +3754,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { ctx = testutil.Context(t, testutil.WaitLong) // Reset the context to avoid timeouts. - // nolint:gocritic // unit testing daemons, err := db.GetProvisionerDaemons(dbauthz.AsSystemRestricted(ctx)) require.NoError(t, err) require.Equal(t, len(daemons), 1) @@ -3768,8 +3763,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { require.NoError(t, err) // Simulate it's subsequent deletion from the database: - - // nolint:gocritic // unit testing _, err = db.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ Name: daemons[0].Name, OrganizationID: daemons[0].OrganizationID, @@ -3787,7 +3780,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { }, }) require.NoError(t, err) - // nolint:gocritic // unit testing err = db.DeleteOldProvisionerDaemons(dbauthz.AsSystemRestricted(ctx)) require.NoError(t, err) @@ -3798,7 +3790,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { require.Equal(t, workspace.LatestBuild.MatchedProvisioners.Count, 0) require.Equal(t, workspace.LatestBuild.MatchedProvisioners.Available, 0) - // nolint:gocritic // unit testing _, err = client.WorkspaceByOwnerAndName(dbauthz.As(ctx, userSubject), username, workspace.Name, codersdk.WorkspaceOptions{}) require.NoError(t, err) require.Equal(t, workspace.LatestBuild.Status, codersdk.WorkspaceStatusPending) @@ -3835,7 +3826,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { ctx = testutil.Context(t, testutil.WaitLong) // Reset the context to avoid timeouts. - // nolint:gocritic // unit testing daemons, err := db.GetProvisionerDaemons(dbauthz.AsSystemRestricted(ctx)) require.NoError(t, err) require.Equal(t, len(daemons), 1) @@ -3844,7 +3834,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { err = closer.Close() require.NoError(t, err) - // nolint:gocritic // unit testing _, err = db.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ Name: daemons[0].Name, OrganizationID: daemons[0].OrganizationID, diff --git a/scripts/rules.go b/scripts/rules.go index f15582d12a4bb..dce029a102d01 100644 --- a/scripts/rules.go +++ b/scripts/rules.go @@ -37,7 +37,9 @@ func dbauthzAuthorizationContext(m dsl.Matcher) { Where( m["c"].Type.Implements("context.Context") && // Only report on functions that start with "As". - m["f"].Text.Matches("^As"), + m["f"].Text.Matches("^As") && + // Ignore test usages of dbauthz contexts. + !m.File().Name.Matches(`_test\.go$`), ). // Instructions for fixing the lint error should be included on the dangerous function. Report("Using '$f' is dangerous and should be accompanied by a comment explaining why it's ok and a nolint.") From 1a601c30ad659dae763eb2573d0deeaf4287782d Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 20 Aug 2025 23:48:38 +1000 Subject: [PATCH 07/11] chore: move usage types to new package (#19103) --- .../provisionerdserver/provisionerdserver.go | 15 +- .../provisionerdserver_test.go | 7 +- coderd/usage/events.go | 82 ----------- coderd/usage/inserter.go | 5 +- coderd/usage/usagetypes/events.go | 129 ++++++++++++++++++ coderd/usage/usagetypes/events_test.go | 61 +++++++++ coderd/usage/usagetypes/tallyman.go | 70 ++++++++++ coderd/usage/usagetypes/tallyman_test.go | 85 ++++++++++++ enterprise/coderd/usage/inserter.go | 15 +- enterprise/coderd/usage/inserter_test.go | 24 ++-- enterprise/coderd/usage/publisher.go | 79 ++++------- enterprise/coderd/usage/publisher_test.go | 123 +++++++++-------- 12 files changed, 470 insertions(+), 225 deletions(-) delete mode 100644 coderd/usage/events.go create mode 100644 coderd/usage/usagetypes/events.go create mode 100644 coderd/usage/usagetypes/events_test.go create mode 100644 coderd/usage/usagetypes/tallyman.go create mode 100644 coderd/usage/usagetypes/tallyman_test.go diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 93573131a04f8..d7bc29aca3044 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -28,14 +28,6 @@ import ( protobuf "google.golang.org/protobuf/proto" "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/usage" - "github.com/coder/coder/v2/coderd/util/slice" - - "github.com/coder/coder/v2/codersdk/drpcsdk" - - "github.com/coder/quartz" - "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" @@ -49,13 +41,18 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" + "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/provisioner" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" + "github.com/coder/quartz" ) const ( @@ -2041,7 +2038,7 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro // Insert usage event for managed agents. usageInserter := s.UsageInserter.Load() if usageInserter != nil { - event := usage.DCManagedAgentsV1{ + event := usagetypes.DCManagedAgentsV1{ Count: 1, } err = (*usageInserter).InsertDiscreteUsageEvent(ctx, db, event) diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 8bb06eb52cd70..8baa7c99c30b9 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -48,6 +48,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -3044,7 +3045,7 @@ func TestCompleteJob(t *testing.T) { if tc.expectUsageEvent { // Check that a usage event was collected. require.Len(t, fakeUsageInserter.collectedEvents, 1) - require.Equal(t, usage.DCManagedAgentsV1{ + require.Equal(t, usagetypes.DCManagedAgentsV1{ Count: 1, }, fakeUsageInserter.collectedEvents[0]) } else { @@ -4226,7 +4227,7 @@ func (s *fakeStream) cancel() { } type fakeUsageInserter struct { - collectedEvents []usage.Event + collectedEvents []usagetypes.Event } var _ usage.Inserter = &fakeUsageInserter{} @@ -4239,7 +4240,7 @@ func newFakeUsageInserter() (*fakeUsageInserter, *atomic.Pointer[usage.Inserter] return fake, ptr } -func (f *fakeUsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usage.DiscreteEvent) error { +func (f *fakeUsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error { f.collectedEvents = append(f.collectedEvents, event) return nil } diff --git a/coderd/usage/events.go b/coderd/usage/events.go deleted file mode 100644 index f0910eefc2814..0000000000000 --- a/coderd/usage/events.go +++ /dev/null @@ -1,82 +0,0 @@ -package usage - -import ( - "strings" - - "golang.org/x/xerrors" -) - -// EventType is an enum of all usage event types. It mirrors the check -// constraint on the `event_type` column in the `usage_events` table. -type EventType string //nolint:revive - -const ( - UsageEventTypeDCManagedAgentsV1 EventType = "dc_managed_agents_v1" -) - -func (e EventType) Valid() bool { - switch e { - case UsageEventTypeDCManagedAgentsV1: - return true - default: - return false - } -} - -func (e EventType) IsDiscrete() bool { - return e.Valid() && strings.HasPrefix(string(e), "dc_") -} - -func (e EventType) IsHeartbeat() bool { - return e.Valid() && strings.HasPrefix(string(e), "hb_") -} - -// Event is a usage event that can be collected by the usage collector. -// -// Note that the following event types should not be updated once they are -// merged into the product. Please consult Dean before making any changes. -// -// Event types cannot be implemented outside of this package, as they are -// imported by the coder/tallyman repository. -type Event interface { - usageEvent() // to prevent external types from implementing this interface - EventType() EventType - Valid() error - Fields() map[string]any // fields to be marshaled and sent to tallyman/Metronome -} - -// DiscreteEvent is a usage event that is collected as a discrete event. -type DiscreteEvent interface { - Event - discreteUsageEvent() // marker method, also prevents external types from implementing this interface -} - -// DCManagedAgentsV1 is a discrete usage event for the number of managed agents. -// This event is sent in the following situations: -// - Once on first startup after usage tracking is added to the product with -// the count of all existing managed agents (count=N) -// - A new managed agent is created (count=1) -type DCManagedAgentsV1 struct { - Count uint64 `json:"count"` -} - -var _ DiscreteEvent = DCManagedAgentsV1{} - -func (DCManagedAgentsV1) usageEvent() {} -func (DCManagedAgentsV1) discreteUsageEvent() {} -func (DCManagedAgentsV1) EventType() EventType { - return UsageEventTypeDCManagedAgentsV1 -} - -func (e DCManagedAgentsV1) Valid() error { - if e.Count == 0 { - return xerrors.New("count must be greater than 0") - } - return nil -} - -func (e DCManagedAgentsV1) Fields() map[string]any { - return map[string]any{ - "count": e.Count, - } -} diff --git a/coderd/usage/inserter.go b/coderd/usage/inserter.go index 3a0e85f273abb..7a0f42daf4724 100644 --- a/coderd/usage/inserter.go +++ b/coderd/usage/inserter.go @@ -4,6 +4,7 @@ import ( "context" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/usage/usagetypes" ) // Inserter accepts usage events generated by the product. @@ -12,7 +13,7 @@ type Inserter interface { // within the given transaction. // The caller context must be authorized to create usage events in the // database. - InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event DiscreteEvent) error + InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event usagetypes.DiscreteEvent) error } // AGPLInserter is a no-op implementation of Inserter. @@ -26,6 +27,6 @@ func NewAGPLInserter() Inserter { // InsertDiscreteUsageEvent is a no-op implementation of // InsertDiscreteUsageEvent. -func (AGPLInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, _ DiscreteEvent) error { +func (AGPLInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, _ usagetypes.DiscreteEvent) error { return nil } diff --git a/coderd/usage/usagetypes/events.go b/coderd/usage/usagetypes/events.go new file mode 100644 index 0000000000000..a8558fc49090e --- /dev/null +++ b/coderd/usage/usagetypes/events.go @@ -0,0 +1,129 @@ +// Package usagetypes contains the types for usage events. These are kept in +// their own package to avoid importing any real code from coderd. +// +// Imports in this package should be limited to the standard library and the +// following packages ONLY: +// - github.com/google/uuid +// - golang.org/x/xerrors +// +// This package is imported by the Tallyman codebase. +package usagetypes + +// Please read the package documentation before adding imports. +import ( + "bytes" + "encoding/json" + "strings" + + "golang.org/x/xerrors" +) + +// UsageEventType is an enum of all usage event types. It mirrors the database +// type `usage_event_type`. +type UsageEventType string + +const ( + UsageEventTypeDCManagedAgentsV1 UsageEventType = "dc_managed_agents_v1" +) + +func (e UsageEventType) Valid() bool { + switch e { + case UsageEventTypeDCManagedAgentsV1: + return true + default: + return false + } +} + +func (e UsageEventType) IsDiscrete() bool { + return e.Valid() && strings.HasPrefix(string(e), "dc_") +} + +func (e UsageEventType) IsHeartbeat() bool { + return e.Valid() && strings.HasPrefix(string(e), "hb_") +} + +// ParseEvent parses the raw event data into the specified Go type. It fails if +// there is any unknown fields or extra data after the event. The returned event +// is validated. +func ParseEvent[T Event](data json.RawMessage) (T, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + + var event T + err := dec.Decode(&event) + if err != nil { + return event, xerrors.Errorf("unmarshal %T event: %w", event, err) + } + if dec.More() { + return event, xerrors.Errorf("extra data after %T event", event) + } + err = event.Valid() + if err != nil { + return event, xerrors.Errorf("invalid %T event: %w", event, err) + } + + return event, nil +} + +// ParseEventWithType parses the raw event data into the specified Go type. It +// fails if there is any unknown fields or extra data after the event. The +// returned event is validated. +func ParseEventWithType(eventType UsageEventType, data json.RawMessage) (Event, error) { + switch eventType { + case UsageEventTypeDCManagedAgentsV1: + return ParseEvent[DCManagedAgentsV1](data) + default: + return nil, xerrors.Errorf("unknown event type: %s", eventType) + } +} + +// Event is a usage event that can be collected by the usage collector. +// +// Note that the following event types should not be updated once they are +// merged into the product. Please consult Dean before making any changes. +// +// This type cannot be implemented outside of this package as it this package +// is the source of truth for the coder/tallyman repo. +type Event interface { + usageEvent() // to prevent external types from implementing this interface + EventType() UsageEventType + Valid() error + Fields() map[string]any // fields to be marshaled and sent to tallyman/Metronome +} + +// DiscreteEvent is a usage event that is collected as a discrete event. +type DiscreteEvent interface { + Event + discreteUsageEvent() // marker method, also prevents external types from implementing this interface +} + +// DCManagedAgentsV1 is a discrete usage event for the number of managed agents. +// This event is sent in the following situations: +// - Once on first startup after usage tracking is added to the product with +// the count of all existing managed agents (count=N) +// - A new managed agent is created (count=1) +type DCManagedAgentsV1 struct { + Count uint64 `json:"count"` +} + +var _ DiscreteEvent = DCManagedAgentsV1{} + +func (DCManagedAgentsV1) usageEvent() {} +func (DCManagedAgentsV1) discreteUsageEvent() {} +func (DCManagedAgentsV1) EventType() UsageEventType { + return UsageEventTypeDCManagedAgentsV1 +} + +func (e DCManagedAgentsV1) Valid() error { + if e.Count == 0 { + return xerrors.New("count must be greater than 0") + } + return nil +} + +func (e DCManagedAgentsV1) Fields() map[string]any { + return map[string]any{ + "count": e.Count, + } +} diff --git a/coderd/usage/usagetypes/events_test.go b/coderd/usage/usagetypes/events_test.go new file mode 100644 index 0000000000000..1e09aa07851c3 --- /dev/null +++ b/coderd/usage/usagetypes/events_test.go @@ -0,0 +1,61 @@ +package usagetypes_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +func TestParseEvent(t *testing.T) { + t.Parallel() + + t.Run("ExtraFields", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1, "extra": "field"}`)) + require.ErrorContains(t, err, "unmarshal usagetypes.DCManagedAgentsV1 event") + }) + + t.Run("ExtraData", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1}{"count": 2}`)) + require.ErrorContains(t, err, "extra data after usagetypes.DCManagedAgentsV1 event") + }) + + t.Run("DCManagedAgentsV1", func(t *testing.T) { + t.Parallel() + + event, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1}`)) + require.NoError(t, err) + require.Equal(t, usagetypes.DCManagedAgentsV1{Count: 1}, event) + require.Equal(t, map[string]any{"count": uint64(1)}, event.Fields()) + + _, err = usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": "invalid"}`)) + require.ErrorContains(t, err, "unmarshal usagetypes.DCManagedAgentsV1 event") + + _, err = usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{}`)) + require.ErrorContains(t, err, "invalid usagetypes.DCManagedAgentsV1 event: count must be greater than 0") + }) +} + +func TestParseEventWithType(t *testing.T) { + t.Parallel() + + t.Run("UnknownEvent", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEventWithType(usagetypes.UsageEventType("fake"), []byte(`{}`)) + require.ErrorContains(t, err, "unknown event type: fake") + }) + + t.Run("DCManagedAgentsV1", func(t *testing.T) { + t.Parallel() + + eventType := usagetypes.UsageEventTypeDCManagedAgentsV1 + event, err := usagetypes.ParseEventWithType(eventType, []byte(`{"count": 1}`)) + require.NoError(t, err) + require.Equal(t, usagetypes.DCManagedAgentsV1{Count: 1}, event) + require.Equal(t, eventType, event.EventType()) + require.Equal(t, map[string]any{"count": uint64(1)}, event.Fields()) + }) +} diff --git a/coderd/usage/usagetypes/tallyman.go b/coderd/usage/usagetypes/tallyman.go new file mode 100644 index 0000000000000..38358b7a6d518 --- /dev/null +++ b/coderd/usage/usagetypes/tallyman.go @@ -0,0 +1,70 @@ +package usagetypes + +// Please read the package documentation before adding imports. +import ( + "encoding/json" + "time" + + "golang.org/x/xerrors" +) + +const ( + TallymanCoderLicenseKeyHeader = "Coder-License-Key" + TallymanCoderDeploymentIDHeader = "Coder-Deployment-ID" +) + +// TallymanV1Response is a generic response with a message from the Tallyman +// API. It is typically returned when there is an error. +type TallymanV1Response struct { + Message string `json:"message"` +} + +// TallymanV1IngestRequest is a request to the Tallyman API to ingest usage +// events. +type TallymanV1IngestRequest struct { + Events []TallymanV1IngestEvent `json:"events"` +} + +// TallymanV1IngestEvent is an event to be ingested into the Tallyman API. +type TallymanV1IngestEvent struct { + ID string `json:"id"` + EventType UsageEventType `json:"event_type"` + EventData json.RawMessage `json:"event_data"` + CreatedAt time.Time `json:"created_at"` +} + +// Valid validates the TallymanV1IngestEvent. It does not validate the event +// body. +func (e TallymanV1IngestEvent) Valid() error { + if e.ID == "" { + return xerrors.New("id is required") + } + if !e.EventType.Valid() { + return xerrors.Errorf("event_type %q is invalid", e.EventType) + } + if e.CreatedAt.IsZero() { + return xerrors.New("created_at cannot be zero") + } + return nil +} + +// TallymanV1IngestResponse is a response from the Tallyman API to ingest usage +// events. +type TallymanV1IngestResponse struct { + AcceptedEvents []TallymanV1IngestAcceptedEvent `json:"accepted_events"` + RejectedEvents []TallymanV1IngestRejectedEvent `json:"rejected_events"` +} + +// TallymanV1IngestAcceptedEvent is an event that was accepted by the Tallyman +// API. +type TallymanV1IngestAcceptedEvent struct { + ID string `json:"id"` +} + +// TallymanV1IngestRejectedEvent is an event that was rejected by the Tallyman +// API. +type TallymanV1IngestRejectedEvent struct { + ID string `json:"id"` + Message string `json:"message"` + Permanent bool `json:"permanent"` +} diff --git a/coderd/usage/usagetypes/tallyman_test.go b/coderd/usage/usagetypes/tallyman_test.go new file mode 100644 index 0000000000000..f8f09446dff51 --- /dev/null +++ b/coderd/usage/usagetypes/tallyman_test.go @@ -0,0 +1,85 @@ +package usagetypes_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +func TestTallymanV1UsageEvent(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + event usagetypes.TallymanV1IngestEvent + errorMessage string + }{ + { + name: "OK", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + // EventData is not validated. + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: "", + }, + { + name: "NoID", + event: usagetypes.TallymanV1IngestEvent{ + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: "id is required", + }, + { + name: "NoEventType", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventType(""), + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: `event_type "" is invalid`, + }, + { + name: "UnknownEventType", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventType("unknown"), + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: `event_type "unknown" is invalid`, + }, + { + name: "NoCreatedAt", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + EventData: json.RawMessage{}, + CreatedAt: time.Time{}, + }, + errorMessage: "created_at cannot be zero", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := tc.event.Valid() + if tc.errorMessage == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.errorMessage) + } + }) + } +} diff --git a/enterprise/coderd/usage/inserter.go b/enterprise/coderd/usage/inserter.go index e07b536e758bd..f3566595a181f 100644 --- a/enterprise/coderd/usage/inserter.go +++ b/enterprise/coderd/usage/inserter.go @@ -10,20 +10,21 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/quartz" ) -// dbCollector collects usage events and stores them in the database for +// dbInserter collects usage events and stores them in the database for // publishing. -type dbCollector struct { +type dbInserter struct { clock quartz.Clock } -var _ agplusage.Inserter = &dbCollector{} +var _ agplusage.Inserter = &dbInserter{} // NewDBInserter creates a new database-backed usage event inserter. func NewDBInserter(opts ...InserterOption) agplusage.Inserter { - c := &dbCollector{ + c := &dbInserter{ clock: quartz.NewReal(), } for _, opt := range opts { @@ -32,17 +33,17 @@ func NewDBInserter(opts ...InserterOption) agplusage.Inserter { return c } -type InserterOption func(*dbCollector) +type InserterOption func(*dbInserter) // InserterWithClock sets the quartz clock to use for the inserter. func InserterWithClock(clock quartz.Clock) InserterOption { - return func(c *dbCollector) { + return func(c *dbInserter) { c.clock = clock } } // InsertDiscreteUsageEvent implements agplusage.Inserter. -func (i *dbCollector) InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event agplusage.DiscreteEvent) error { +func (i *dbInserter) InsertDiscreteUsageEvent(ctx context.Context, tx database.Store, event usagetypes.DiscreteEvent) error { if !event.EventType().IsDiscrete() { return xerrors.Errorf("event type %q is not a discrete event", event.EventType()) } diff --git a/enterprise/coderd/usage/inserter_test.go b/enterprise/coderd/usage/inserter_test.go index b7ced536aef3b..7ac915be7a5a8 100644 --- a/enterprise/coderd/usage/inserter_test.go +++ b/enterprise/coderd/usage/inserter_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" - agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" @@ -33,37 +33,37 @@ func TestInserter(t *testing.T) { now := dbtime.Now() events := []struct { time time.Time - event agplusage.DiscreteEvent + event usagetypes.DiscreteEvent }{ { time: now, - event: agplusage.DCManagedAgentsV1{ + event: usagetypes.DCManagedAgentsV1{ Count: 1, }, }, { time: now.Add(1 * time.Minute), - event: agplusage.DCManagedAgentsV1{ + event: usagetypes.DCManagedAgentsV1{ Count: 2, }, }, } - for _, event := range events { - eventJSON := jsoninate(t, event.event) + for _, e := range events { + eventJSON := jsoninate(t, e.event) db.EXPECT().InsertUsageEvent(gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx any, params database.InsertUsageEventParams) error { + func(ctx interface{}, params database.InsertUsageEventParams) error { _, err := uuid.Parse(params.ID) assert.NoError(t, err) - assert.Equal(t, string(event.event.EventType()), params.EventType) + assert.Equal(t, e.event.EventType(), usagetypes.UsageEventType(params.EventType)) assert.JSONEq(t, eventJSON, string(params.EventData)) - assert.Equal(t, event.time, params.CreatedAt) + assert.Equal(t, e.time, params.CreatedAt) return nil }, ).Times(1) - clock.Set(event.time) - err := inserter.InsertDiscreteUsageEvent(ctx, db, event.event) + clock.Set(e.time) + err := inserter.InsertDiscreteUsageEvent(ctx, db, e.event) require.NoError(t, err) } }) @@ -77,7 +77,7 @@ func TestInserter(t *testing.T) { // We should get an error if the event is invalid. inserter := usage.NewDBInserter() - err := inserter.InsertDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := inserter.InsertDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 0, // invalid }) assert.ErrorContains(t, err, `invalid "dc_managed_agents_v1" event: count must be greater than 0`) diff --git a/enterprise/coderd/usage/publisher.go b/enterprise/coderd/usage/publisher.go index 8c0811c7727c8..5c205ecd8c3b8 100644 --- a/enterprise/coderd/usage/publisher.go +++ b/enterprise/coderd/usage/publisher.go @@ -18,15 +18,13 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/pproflabel" - agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/cryptorand" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/quartz" ) const ( - CoderLicenseJWTHeader = "Coder-License-JWT" - tallymanURL = "https://tallyman-prod.coder.com" tallymanIngestURLV1 = tallymanURL + "/api/v1/events/ingest" @@ -217,20 +215,19 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U var ( eventIDs = make(map[string]struct{}) - tallymanReq = TallymanIngestRequestV1{ - DeploymentID: deploymentID, - Events: make([]TallymanIngestEventV1, 0, len(events)), + tallymanReq = usagetypes.TallymanV1IngestRequest{ + Events: make([]usagetypes.TallymanV1IngestEvent, 0, len(events)), } ) for _, event := range events { eventIDs[event.ID] = struct{}{} - eventType := agplusage.EventType(event.EventType) + eventType := usagetypes.UsageEventType(event.EventType) if !eventType.Valid() { // This should never happen due to the check constraint in the // database. return 0, xerrors.Errorf("event %q has an invalid event type %q", event.ID, event.EventType) } - tallymanReq.Events = append(tallymanReq.Events, TallymanIngestEventV1{ + tallymanReq.Events = append(tallymanReq.Events, usagetypes.TallymanV1IngestEvent{ ID: event.ID, EventType: eventType, EventData: event.EventData, @@ -243,17 +240,17 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U return 0, xerrors.Errorf("duplicate event IDs found in events for publishing") } - resp, err := p.sendPublishRequest(ctx, licenseJwt, tallymanReq) + resp, err := p.sendPublishRequest(ctx, deploymentID, licenseJwt, tallymanReq) allFailed := err != nil if err != nil { p.log.Warn(ctx, "failed to send publish request to tallyman", slog.F("count", len(events)), slog.Error(err)) // Fake a response with all events temporarily rejected. - resp = TallymanIngestResponseV1{ - AcceptedEvents: []TallymanIngestAcceptedEventV1{}, - RejectedEvents: make([]TallymanIngestRejectedEventV1, len(events)), + resp = usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: make([]usagetypes.TallymanV1IngestRejectedEvent, len(events)), } for i, event := range events { - resp.RejectedEvents[i] = TallymanIngestRejectedEventV1{ + resp.RejectedEvents[i] = usagetypes.TallymanV1IngestRejectedEvent{ ID: event.ID, Message: fmt.Sprintf("failed to publish to tallyman: %v", err), Permanent: false, @@ -267,8 +264,8 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U p.log.Warn(ctx, "tallyman returned a different number of events than we sent", slog.F("sent", len(events)), slog.F("accepted", len(resp.AcceptedEvents)), slog.F("rejected", len(resp.RejectedEvents))) } - acceptedEvents := make(map[string]*TallymanIngestAcceptedEventV1) - rejectedEvents := make(map[string]*TallymanIngestRejectedEventV1) + acceptedEvents := make(map[string]*usagetypes.TallymanV1IngestAcceptedEvent) + rejectedEvents := make(map[string]*usagetypes.TallymanV1IngestRejectedEvent) for _, event := range resp.AcceptedEvents { acceptedEvents[event.ID] = &event } @@ -389,37 +386,38 @@ func (p *tallymanPublisher) getBestLicenseJWT(ctx context.Context) (string, erro return bestLicense.Raw, nil } -func (p *tallymanPublisher) sendPublishRequest(ctx context.Context, licenseJwt string, req TallymanIngestRequestV1) (TallymanIngestResponseV1, error) { +func (p *tallymanPublisher) sendPublishRequest(ctx context.Context, deploymentID uuid.UUID, licenseJwt string, req usagetypes.TallymanV1IngestRequest) (usagetypes.TallymanV1IngestResponse, error) { body, err := json.Marshal(req) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } r, err := http.NewRequestWithContext(ctx, http.MethodPost, p.ingestURL, bytes.NewReader(body)) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } - r.Header.Set(CoderLicenseJWTHeader, licenseJwt) + r.Header.Set(usagetypes.TallymanCoderLicenseKeyHeader, licenseJwt) + r.Header.Set(usagetypes.TallymanCoderDeploymentIDHeader, deploymentID.String()) resp, err := p.httpClient.Do(r) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - var errBody TallymanErrorV1 + var errBody usagetypes.TallymanV1Response if err := json.NewDecoder(resp.Body).Decode(&errBody); err != nil { - errBody = TallymanErrorV1{ + errBody = usagetypes.TallymanV1Response{ Message: fmt.Sprintf("could not decode error response body: %v", err), } } - return TallymanIngestResponseV1{}, xerrors.Errorf("unexpected status code %v, error: %s", resp.StatusCode, errBody.Message) + return usagetypes.TallymanV1IngestResponse{}, xerrors.Errorf("unexpected status code %v, error: %s", resp.StatusCode, errBody.Message) } - var respBody TallymanIngestResponseV1 + var respBody usagetypes.TallymanV1IngestResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return TallymanIngestResponseV1{}, xerrors.Errorf("decode response body: %w", err) + return usagetypes.TallymanV1IngestResponse{}, xerrors.Errorf("decode response body: %w", err) } return respBody, nil @@ -431,34 +429,3 @@ func (p *tallymanPublisher) Close() error { <-p.done return nil } - -type TallymanErrorV1 struct { - Message string `json:"message"` -} - -type TallymanIngestRequestV1 struct { - DeploymentID uuid.UUID `json:"deployment_id"` - Events []TallymanIngestEventV1 `json:"events"` -} - -type TallymanIngestEventV1 struct { - ID string `json:"id"` - EventType agplusage.EventType `json:"event_type"` - EventData json.RawMessage `json:"event_data"` - CreatedAt time.Time `json:"created_at"` -} - -type TallymanIngestResponseV1 struct { - AcceptedEvents []TallymanIngestAcceptedEventV1 `json:"accepted_events"` - RejectedEvents []TallymanIngestRejectedEventV1 `json:"rejected_events"` -} - -type TallymanIngestAcceptedEventV1 struct { - ID string `json:"id"` -} - -type TallymanIngestRejectedEventV1 struct { - ID string `json:"id"` - Message string `json:"message"` - Permanent bool `json:"permanent"` -} diff --git a/enterprise/coderd/usage/publisher_test.go b/enterprise/coderd/usage/publisher_test.go index 7a17935a64a61..c104c9712e499 100644 --- a/enterprise/coderd/usage/publisher_test.go +++ b/enterprise/coderd/usage/publisher_test.go @@ -24,7 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/rbac" - agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/coder/v2/testutil" @@ -51,16 +51,15 @@ func TestIntegration(t *testing.T) { var ( calls int - handler func(req usage.TallymanIngestRequestV1) any + handler func(req usagetypes.TallymanV1IngestRequest) any ) - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { calls++ t.Logf("tallyman backend received call %d", calls) - assert.Equal(t, deploymentID, req.DeploymentID) if handler == nil { t.Errorf("handler is nil") - return usage.TallymanIngestResponseV1{} + return usagetypes.TallymanV1IngestResponse{} } return handler(req) })) @@ -70,7 +69,7 @@ func TestIntegration(t *testing.T) { ) // Insert an old event that should never be published. clock.Set(now.Add(-31 * 24 * time.Hour)) - err := inserter.InsertDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := inserter.InsertDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 31, }) require.NoError(t, err) @@ -79,7 +78,7 @@ func TestIntegration(t *testing.T) { clock.Set(now.Add(1 * time.Second)) for i := 0; i < eventCount; i++ { clock.Advance(time.Second) - err := inserter.InsertDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := inserter.InsertDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: uint64(i + 1), // nolint:gosec // these numbers are tiny and will not overflow }) require.NoErrorf(t, err, "collecting event %d", i) @@ -117,33 +116,33 @@ func TestIntegration(t *testing.T) { // first event, temporarily reject the second, and permanently reject the // third. var temporarilyRejectedEventID string - handler = func(req usage.TallymanIngestRequestV1) any { + handler = func(req usagetypes.TallymanV1IngestRequest) any { // On the first call, accept the first event, temporarily reject the // second, and permanently reject the third. - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, 1) - rejectedEvents := make([]usage.TallymanIngestRejectedEventV1, 2) + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, 1) + rejectedEvents := make([]usagetypes.TallymanV1IngestRejectedEvent, 2) if assert.Len(t, req.Events, eventCount) { - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, }), string(req.Events[0].EventData), "event data did not match for event %d", 0) acceptedEvents[0].ID = req.Events[0].ID temporarilyRejectedEventID = req.Events[1].ID - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 2, }), string(req.Events[1].EventData), "event data did not match for event %d", 1) rejectedEvents[0].ID = req.Events[1].ID rejectedEvents[0].Message = "temporarily rejected" rejectedEvents[0].Permanent = false - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 3, }), string(req.Events[2].EventData), "event data did not match for event %d", 2) rejectedEvents[1].ID = req.Events[2].ID rejectedEvents[1].Message = "permanently rejected" rejectedEvents[1].Permanent = true } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, RejectedEvents: rejectedEvents, } @@ -162,16 +161,16 @@ func TestIntegration(t *testing.T) { // Set the handler for the next publish call. This call should only include // the temporarily rejected event from earlier. This time we'll accept it. - handler = func(req usage.TallymanIngestRequestV1) any { + handler = func(req usagetypes.TallymanV1IngestRequest) any { assert.Len(t, req.Events, 1) - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, len(req.Events)) + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, len(req.Events)) for i, event := range req.Events { assert.Equal(t, temporarilyRejectedEventID, event.ID) acceptedEvents[i].ID = event.ID } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } } @@ -211,11 +210,11 @@ func TestPublisherNoEligibleLicenses(t *testing.T) { db.EXPECT().GetDeploymentID(gomock.Any()).Return(deploymentID.String(), nil).Times(1) var calls int - ingestURL := fakeServer(t, tallymanHandler(t, "", func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), "", func(req usagetypes.TallymanV1IngestRequest) any { calls++ - return usage.TallymanIngestResponseV1{ - AcceptedEvents: []usage.TallymanIngestAcceptedEventV1{}, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + return usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } })) @@ -280,11 +279,11 @@ func TestPublisherClaimExpiry(t *testing.T) { log := slogtest.Make(t, nil) db, _ := dbtestutil.NewDB(t) clock := quartz.NewMock(t) - _, licenseJWT := configureDeployment(ctx, t, db) + deploymentID, licenseJWT := configureDeployment(ctx, t, db) now := time.Now() var calls int - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { calls++ return tallymanAcceptAllHandler(req) })) @@ -303,7 +302,7 @@ func TestPublisherClaimExpiry(t *testing.T) { // Create an event that was claimed 1h-18m ago. The ticker has a forced // delay of 17m in this test. clock.Set(now) - err := inserter.InsertDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := inserter.InsertDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 1, }) require.NoError(t, err) @@ -358,17 +357,17 @@ func TestPublisherMissingEvents(t *testing.T) { log := slogtest.Make(t, nil) ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) - _, licenseJWT := configureMockDeployment(t, db) + deploymentID, licenseJWT := configureMockDeployment(t, db) clock := quartz.NewMock(t) now := time.Now() clock.Set(now) var calls int - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { calls++ - return usage.TallymanIngestResponseV1{ - AcceptedEvents: []usage.TallymanIngestAcceptedEventV1{}, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + return usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } })) @@ -382,8 +381,8 @@ func TestPublisherMissingEvents(t *testing.T) { events := []database.UsageEvent{ { ID: uuid.New().String(), - EventType: string(agplusage.UsageEventTypeDCManagedAgentsV1), - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventType: string(usagetypes.UsageEventTypeDCManagedAgentsV1), + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), CreatedAt: now, @@ -508,9 +507,8 @@ func TestPublisherLicenseSelection(t *testing.T) { }, nil) called := false - ingestURL := fakeServer(t, tallymanHandler(t, expectedLicense, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), expectedLicense, func(req usagetypes.TallymanV1IngestRequest) any { called = true - assert.Equal(t, deploymentID, req.DeploymentID) return tallymanAcceptAllHandler(req) })) @@ -536,8 +534,8 @@ func TestPublisherLicenseSelection(t *testing.T) { events := []database.UsageEvent{ { ID: uuid.New().String(), - EventType: string(agplusage.UsageEventTypeDCManagedAgentsV1), - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventType: string(usagetypes.UsageEventTypeDCManagedAgentsV1), + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), }, @@ -572,12 +570,12 @@ func TestPublisherTallymanError(t *testing.T) { now := time.Now() clock.Set(now) - _, licenseJWT := configureMockDeployment(t, db) + deploymentID, licenseJWT := configureMockDeployment(t, db) const errorMessage = "tallyman error" var calls int - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { calls++ - return usage.TallymanErrorV1{ + return usagetypes.TallymanV1Response{ Message: errorMessage, } })) @@ -604,8 +602,8 @@ func TestPublisherTallymanError(t *testing.T) { events := []database.UsageEvent{ { ID: uuid.New().String(), - EventType: string(agplusage.UsageEventTypeDCManagedAgentsV1), - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventType: string(usagetypes.UsageEventTypeDCManagedAgentsV1), + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), }, @@ -632,7 +630,7 @@ func TestPublisherTallymanError(t *testing.T) { func jsoninate(t *testing.T, v any) string { t.Helper() - if e, ok := v.(agplusage.Event); ok { + if e, ok := v.(usagetypes.Event); ok { v = e.Fields() } buf, err := json.Marshal(v) @@ -688,44 +686,61 @@ func fakeServer(t *testing.T, handler http.Handler) string { return server.URL } -func tallymanHandler(t *testing.T, expectLicenseJWT string, handler func(req usage.TallymanIngestRequestV1) any) http.Handler { +func tallymanHandler(t *testing.T, expectDeploymentID string, expectLicenseJWT string, handler func(req usagetypes.TallymanV1IngestRequest) any) http.Handler { t.Helper() return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { t.Helper() - licenseJWT := r.Header.Get(usage.CoderLicenseJWTHeader) + licenseJWT := r.Header.Get(usagetypes.TallymanCoderLicenseKeyHeader) if expectLicenseJWT != "" && !assert.Equal(t, expectLicenseJWT, licenseJWT, "license JWT in request did not match") { rw.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(rw).Encode(usage.TallymanErrorV1{ + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ Message: "license JWT in request did not match", }) - require.NoError(t, err) return } - var req usage.TallymanIngestRequestV1 + deploymentID := r.Header.Get(usagetypes.TallymanCoderDeploymentIDHeader) + if expectDeploymentID != "" && !assert.Equal(t, expectDeploymentID, deploymentID, "deployment ID in request did not match") { + rw.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ + Message: "deployment ID in request did not match", + }) + return + } + + var req usagetypes.TallymanV1IngestRequest err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) + if !assert.NoError(t, err, "could not decode request body") { + rw.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ + Message: "could not decode request body", + }) + return + } resp := handler(req) switch resp.(type) { - case usage.TallymanErrorV1: + case usagetypes.TallymanV1Response: rw.WriteHeader(http.StatusInternalServerError) default: rw.WriteHeader(http.StatusOK) } err = json.NewEncoder(rw).Encode(resp) - require.NoError(t, err) + if !assert.NoError(t, err, "could not encode response body") { + rw.WriteHeader(http.StatusInternalServerError) + return + } }) } -func tallymanAcceptAllHandler(req usage.TallymanIngestRequestV1) usage.TallymanIngestResponseV1 { - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, len(req.Events)) +func tallymanAcceptAllHandler(req usagetypes.TallymanV1IngestRequest) usagetypes.TallymanV1IngestResponse { + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, len(req.Events)) for i, event := range req.Events { acceptedEvents[i].ID = event.ID } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } } From 5b08f8b4a007645b8e5bd476ee016b97462e1242 Mon Sep 17 00:00:00 2001 From: Susana Ferreira Date: Wed, 20 Aug 2025 14:58:00 +0100 Subject: [PATCH 08/11] fix: change createWorkspace to use dbtime.Time (#19414) The `createWorkspace` function was updated to use an injected Clock, which makes it possible to mock time in tests: https://github.com/coder/coder/pull/19264/files#diff-46f90baab52ea3ad914acbde30d656dbc8e46f5918d19bc056c445a1dc502482R1130 For database operations, however, it is recommended to use `dbtime.Time` since it rounds to the microsecond, the smallest unit of precision supported by Postgres. --- coderd/workspaces.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index b2b2610ff1349..e998aeb894c13 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -636,7 +636,7 @@ func createWorkspace( ) // Use injected Clock to allow time mocking in tests - now := api.Clock.Now() + now := dbtime.Time(api.Clock.Now()) templateVersionPresetID := req.TemplateVersionPresetID From fc7f53ffce80a0b20bfd90af8e0843bb90d8eb65 Mon Sep 17 00:00:00 2001 From: Jakub Domeracki Date: Wed, 20 Aug 2025 16:00:10 +0200 Subject: [PATCH 09/11] chore: update monaco-editor to resolve DOMPurify CVEs #19445 (#19446) https://github.com/coder/coder/issues/19445 --- site/package.json | 4 ++-- site/pnpm-lock.yaml | 37 +++++++++++++++++-------------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/site/package.json b/site/package.json index bb061511e1619..5693fc5d55220 100644 --- a/site/package.json +++ b/site/package.json @@ -47,7 +47,7 @@ "@fontsource/ibm-plex-mono": "5.1.1", "@fontsource/jetbrains-mono": "5.2.5", "@fontsource/source-code-pro": "5.2.5", - "@monaco-editor/react": "4.6.0", + "@monaco-editor/react": "4.7.0", "@mui/icons-material": "5.16.14", "@mui/material": "5.16.14", "@mui/system": "5.16.14", @@ -93,7 +93,7 @@ "jszip": "3.10.1", "lodash": "4.17.21", "lucide-react": "0.474.0", - "monaco-editor": "0.52.0", + "monaco-editor": "0.52.2", "pretty-bytes": "6.1.1", "react": "18.3.1", "react-color": "2.19.3", diff --git a/site/pnpm-lock.yaml b/site/pnpm-lock.yaml index 99ef8ac44af6d..31a8857901845 100644 --- a/site/pnpm-lock.yaml +++ b/site/pnpm-lock.yaml @@ -54,8 +54,8 @@ importers: specifier: 5.2.5 version: 5.2.5 '@monaco-editor/react': - specifier: 4.6.0 - version: 4.6.0(monaco-editor@0.52.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + specifier: 4.7.0 + version: 4.7.0(monaco-editor@0.52.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) '@mui/icons-material': specifier: 5.16.14 version: 5.16.14(@mui/material@5.16.14(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@emotion/styled@11.14.0(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(@types/react@18.3.12)(react@18.3.1) @@ -192,8 +192,8 @@ importers: specifier: 0.474.0 version: 0.474.0(react@18.3.1) monaco-editor: - specifier: 0.52.0 - version: 0.52.0 + specifier: 0.52.2 + version: 0.52.2 pretty-bytes: specifier: 6.1.1 version: 6.1.1 @@ -1260,17 +1260,15 @@ packages: '@mjackson/multipart-parser@0.6.3': resolution: {integrity: sha512-aQhySnM6OpAYMMG+m7LEygYye99hB1md/Cy1AFE0yD5hfNW+X4JDu7oNVY9Gc6IW8PZ45D1rjFLDIUdnkXmwrA==, tarball: https://registry.npmjs.org/@mjackson/multipart-parser/-/multipart-parser-0.6.3.tgz} - '@monaco-editor/loader@1.4.0': - resolution: {integrity: sha512-00ioBig0x642hytVspPl7DbQyaSWRaolYie/UFNjoTdvoKPzo6xrXLhTk9ixgIKcLH5b5vDOjVNiGyY+uDCUlg==, tarball: https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.4.0.tgz} - peerDependencies: - monaco-editor: '>= 0.21.0 < 1' + '@monaco-editor/loader@1.5.0': + resolution: {integrity: sha512-hKoGSM+7aAc7eRTRjpqAZucPmoNOC4UUbknb/VNoTkEIkCPhqV8LfbsgM1webRM7S/z21eHEx9Fkwx8Z/C/+Xw==, tarball: https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz} - '@monaco-editor/react@4.6.0': - resolution: {integrity: sha512-RFkU9/i7cN2bsq/iTkurMWOEErmYcY6JiQI3Jn+WeR/FGISH8JbHERjpS9oRuSOPvDMJI0Z8nJeKkbOs9sBYQw==, tarball: https://registry.npmjs.org/@monaco-editor/react/-/react-4.6.0.tgz} + '@monaco-editor/react@4.7.0': + resolution: {integrity: sha512-cyzXQCtO47ydzxpQtCGSQGOC8Gk3ZUeBXFAxD+CWXYFo5OqZyZUonFl0DwUlTyAfRHntBfw2p3w4s9R6oe1eCA==, tarball: https://registry.npmjs.org/@monaco-editor/react/-/react-4.7.0.tgz} peerDependencies: monaco-editor: '>= 0.25.0 < 1' - react: ^16.8.0 || ^17.0.0 || ^18.0.0 - react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 '@mswjs/interceptors@0.35.9': resolution: {integrity: sha512-SSnyl/4ni/2ViHKkiZb8eajA/eN1DNFaHjhGiLUdZvDz6PKF4COSf/17xqSz64nOo2Ia29SA6B2KNCsyCbVmaQ==, tarball: https://registry.npmjs.org/@mswjs/interceptors/-/interceptors-0.35.9.tgz} @@ -4759,8 +4757,8 @@ packages: resolution: {integrity: sha512-qxBgB7Qa2sEQgHFjj0dSigq7fX4k6Saisd5Nelwp2q8mlbAFh5dHV9JTTlF8viYJLSSWgMCZFUom8PJcMNBoJw==, tarball: https://registry.npmjs.org/mock-socket/-/mock-socket-9.3.1.tgz} engines: {node: '>= 8'} - monaco-editor@0.52.0: - resolution: {integrity: sha512-OeWhNpABLCeTqubfqLMXGsqf6OmPU6pHM85kF3dhy6kq5hnhuVS1p3VrEW/XhWHc71P2tHyS5JFySD8mgs1crw==, tarball: https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.0.tgz} + monaco-editor@0.52.2: + resolution: {integrity: sha512-GEQWEZmfkOGLdd3XK8ryrfWz3AIP8YymVXiPHEdewrUq7mh0qrKrfHLNCXcbB6sTnMLnOZ3ztSiKcciFUkIJwQ==, tarball: https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.2.tgz} moo-color@1.0.3: resolution: {integrity: sha512-i/+ZKXMDf6aqYtBhuOcej71YSlbjT3wCO/4H1j8rPvxDJEifdwgg5MaFyu6iYAT8GBZJg2z0dkgK4YMzvURALQ==, tarball: https://registry.npmjs.org/moo-color/-/moo-color-1.0.3.tgz} @@ -7240,15 +7238,14 @@ snapshots: dependencies: '@mjackson/headers': 0.5.1 - '@monaco-editor/loader@1.4.0(monaco-editor@0.52.0)': + '@monaco-editor/loader@1.5.0': dependencies: - monaco-editor: 0.52.0 state-local: 1.0.7 - '@monaco-editor/react@4.6.0(monaco-editor@0.52.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + '@monaco-editor/react@4.7.0(monaco-editor@0.52.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': dependencies: - '@monaco-editor/loader': 1.4.0(monaco-editor@0.52.0) - monaco-editor: 0.52.0 + '@monaco-editor/loader': 1.5.0 + monaco-editor: 0.52.2 react: 18.3.1 react-dom: 18.3.1(react@18.3.1) @@ -11340,7 +11337,7 @@ snapshots: mock-socket@9.3.1: {} - monaco-editor@0.52.0: {} + monaco-editor@0.52.2: {} moo-color@1.0.3: dependencies: From d536b91bfc31f7d7320b0cc2fc7170e1a4bfd77d Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Wed, 20 Aug 2025 16:10:18 +0200 Subject: [PATCH 10/11] chore(coderd/database/dbauthz): migrate more tests to mocked db (#19300) Related to https://github.com/coder/internal/issues/869 --------- Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com> Co-authored-by: Steven Masley --- coderd/database/dbauthz/dbauthz_test.go | 451 +++++++++++------------- 1 file changed, 200 insertions(+), 251 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index ad444d1025514..971335c34019b 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "reflect" - "strings" "testing" "time" @@ -750,13 +749,11 @@ func (s *MethodTestSuite) TestProvisionerJob() { } func (s *MethodTestSuite) TestLicense() { - s.Run("GetLicenses", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) - check.Args().Asserts(l, policy.ActionRead). - Returns([]database.License{l}) + s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + a := database.License{ID: 1} + b := database.License{ID: 2} + dbm.EXPECT().GetLicenses(gomock.Any()).Return([]database.License{a, b}, nil).AnyTimes() + check.Args().Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns([]database.License{a, b}) })) s.Run("GetUnexpiredLicenses", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { l := database.License{ @@ -770,80 +767,73 @@ func (s *MethodTestSuite) TestLicense() { check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead). Returns([]database.License{l}) })) - s.Run("InsertLicense", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertLicenseParams{}). - Asserts(rbac.ResourceLicense, policy.ActionCreate) + s.Run("InsertLicense", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().InsertLicense(gomock.Any(), database.InsertLicenseParams{}).Return(database.License{}, nil).AnyTimes() + check.Args(database.InsertLicenseParams{}).Asserts(rbac.ResourceLicense, policy.ActionCreate) })) - s.Run("UpsertLogoURL", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertLogoURL(gomock.Any(), "value").Return(nil).AnyTimes() check.Args("value").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("UpsertAnnouncementBanners", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpsertAnnouncementBanners", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertAnnouncementBanners(gomock.Any(), "value").Return(nil).AnyTimes() check.Args("value").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetLicenseByID", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) - check.Args(l.ID).Asserts(l, policy.ActionRead).Returns(l) + s.Run("GetLicenseByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + l := database.License{ID: 1} + dbm.EXPECT().GetLicenseByID(gomock.Any(), int32(1)).Return(l, nil).AnyTimes() + check.Args(int32(1)).Asserts(l, policy.ActionRead).Returns(l) })) - s.Run("DeleteLicense", s.Subtest(func(db database.Store, check *expects) { - l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ - UUID: uuid.New(), - }) - require.NoError(s.T(), err) + s.Run("DeleteLicense", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + l := database.License{ID: 1} + dbm.EXPECT().GetLicenseByID(gomock.Any(), l.ID).Return(l, nil).AnyTimes() + dbm.EXPECT().DeleteLicense(gomock.Any(), l.ID).Return(int32(1), nil).AnyTimes() check.Args(l.ID).Asserts(l, policy.ActionDelete) })) - s.Run("GetDeploymentID", s.Subtest(func(db database.Store, check *expects) { - db.InsertDeploymentID(context.Background(), "value") + s.Run("GetDeploymentID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetDeploymentID(gomock.Any()).Return("value", nil).AnyTimes() check.Args().Asserts().Returns("value") })) - s.Run("GetDefaultProxyConfig", s.Subtest(func(db database.Store, check *expects) { - check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{ - DisplayName: "Default", - IconUrl: "/emojis/1f3e1.png", - }) + s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes() + check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}) })) - s.Run("GetLogoURL", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertLogoURL(context.Background(), "value") - require.NoError(s.T(), err) + s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes() check.Args().Asserts().Returns("value") })) - s.Run("GetAnnouncementBanners", s.Subtest(func(db database.Store, check *expects) { - err := db.UpsertAnnouncementBanners(context.Background(), "value") - require.NoError(s.T(), err) + s.Run("GetAnnouncementBanners", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetAnnouncementBanners(gomock.Any()).Return("value", nil).AnyTimes() check.Args().Asserts().Returns("value") })) - s.Run("GetManagedAgentCount", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetManagedAgentCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { start := dbtime.Now() end := start.Add(time.Hour) - check.Args(database.GetManagedAgentCountParams{ - StartTime: start, - EndTime: end, - }).Asserts(rbac.ResourceWorkspace, policy.ActionRead).Returns(int64(0)) + dbm.EXPECT().GetManagedAgentCount(gomock.Any(), database.GetManagedAgentCountParams{StartTime: start, EndTime: end}).Return(int64(0), nil).AnyTimes() + check.Args(database.GetManagedAgentCountParams{StartTime: start, EndTime: end}).Asserts(rbac.ResourceWorkspace, policy.ActionRead).Returns(int64(0)) })) } func (s *MethodTestSuite) TestOrganization() { - s.Run("Deployment/OIDCClaimFields", s.Subtest(func(db database.Store, check *expects) { + s.Run("Deployment/OIDCClaimFields", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().OIDCClaimFields(gomock.Any(), uuid.Nil).Return([]string{}, nil).AnyTimes() check.Args(uuid.Nil).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{}) })) - s.Run("Organization/OIDCClaimFields", s.Subtest(func(db database.Store, check *expects) { + s.Run("Organization/OIDCClaimFields", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { id := uuid.New() + dbm.EXPECT().OIDCClaimFields(gomock.Any(), id).Return([]string{}, nil).AnyTimes() check.Args(id).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) })) - s.Run("Deployment/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.OIDCClaimFieldValuesParams{ - ClaimField: "claim-field", - OrganizationID: uuid.Nil, - }).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{}) + s.Run("Deployment/OIDCClaimFieldValues", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.OIDCClaimFieldValuesParams{ClaimField: "claim-field", OrganizationID: uuid.Nil} + dbm.EXPECT().OIDCClaimFieldValues(gomock.Any(), arg).Return([]string{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{}) })) - s.Run("Organization/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { + s.Run("Organization/OIDCClaimFieldValues", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { id := uuid.New() - check.Args(database.OIDCClaimFieldValuesParams{ - ClaimField: "claim-field", - OrganizationID: id, - }).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) + arg := database.OIDCClaimFieldValuesParams{ClaimField: "claim-field", OrganizationID: id} + dbm.EXPECT().OIDCClaimFieldValues(gomock.Any(), arg).Return([]string{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) })) s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) @@ -1150,41 +1140,43 @@ func (s *MethodTestSuite) TestOrganization() { } func (s *MethodTestSuite) TestWorkspaceProxy() { - s.Run("InsertWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertWorkspaceProxyParams{ - ID: uuid.New(), - }).Asserts(rbac.ResourceWorkspaceProxy, policy.ActionCreate) - })) - s.Run("RegisterWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(database.RegisterWorkspaceProxyParams{ - ID: p.ID, - }).Asserts(p, policy.ActionUpdate) - })) - s.Run("GetWorkspaceProxyByID", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + s.Run("InsertWorkspaceProxy", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertWorkspaceProxyParams{ID: uuid.New()} + dbm.EXPECT().InsertWorkspaceProxy(gomock.Any(), arg).Return(database.WorkspaceProxy{}, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceWorkspaceProxy, policy.ActionCreate) + })) + s.Run("RegisterWorkspaceProxy", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxyByID(gomock.Any(), p.ID).Return(p, nil).AnyTimes() + dbm.EXPECT().RegisterWorkspaceProxy(gomock.Any(), database.RegisterWorkspaceProxyParams{ID: p.ID}).Return(p, nil).AnyTimes() + check.Args(database.RegisterWorkspaceProxyParams{ID: p.ID}).Asserts(p, policy.ActionUpdate) + })) + s.Run("GetWorkspaceProxyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxyByID(gomock.Any(), p.ID).Return(p, nil).AnyTimes() check.Args(p.ID).Asserts(p, policy.ActionRead).Returns(p) })) - s.Run("GetWorkspaceProxyByName", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + s.Run("GetWorkspaceProxyByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxyByName(gomock.Any(), p.Name).Return(p, nil).AnyTimes() check.Args(p.Name).Asserts(p, policy.ActionRead).Returns(p) })) - s.Run("UpdateWorkspaceProxyDeleted", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(database.UpdateWorkspaceProxyDeletedParams{ - ID: p.ID, - Deleted: true, - }).Asserts(p, policy.ActionDelete) - })) - s.Run("UpdateWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { - p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - check.Args(database.UpdateWorkspaceProxyParams{ - ID: p.ID, - }).Asserts(p, policy.ActionUpdate) - })) - s.Run("GetWorkspaceProxies", s.Subtest(func(db database.Store, check *expects) { - p1, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) - p2, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + s.Run("UpdateWorkspaceProxyDeleted", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxyByID(gomock.Any(), p.ID).Return(p, nil).AnyTimes() + dbm.EXPECT().UpdateWorkspaceProxyDeleted(gomock.Any(), database.UpdateWorkspaceProxyDeletedParams{ID: p.ID, Deleted: true}).Return(nil).AnyTimes() + check.Args(database.UpdateWorkspaceProxyDeletedParams{ID: p.ID, Deleted: true}).Asserts(p, policy.ActionDelete) + })) + s.Run("UpdateWorkspaceProxy", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxyByID(gomock.Any(), p.ID).Return(p, nil).AnyTimes() + dbm.EXPECT().UpdateWorkspaceProxy(gomock.Any(), database.UpdateWorkspaceProxyParams{ID: p.ID}).Return(p, nil).AnyTimes() + check.Args(database.UpdateWorkspaceProxyParams{ID: p.ID}).Asserts(p, policy.ActionUpdate) + })) + s.Run("GetWorkspaceProxies", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + p1 := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + p2 := testutil.Fake(s.T(), faker, database.WorkspaceProxy{}) + dbm.EXPECT().GetWorkspaceProxies(gomock.Any()).Return([]database.WorkspaceProxy{p1, p2}, nil).AnyTimes() check.Args().Asserts(p1, policy.ActionRead, p2, policy.ActionRead).Returns(slice.New(p1, p2)) })) } @@ -3345,73 +3337,50 @@ func (s *MethodTestSuite) TestWorkspacePortSharing() { } func (s *MethodTestSuite) TestProvisionerKeys() { - s.Run("InsertProvisionerKey", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := database.ProvisionerKey{ - ID: uuid.New(), - CreatedAt: dbtestutil.NowInDefaultTimezone(), - OrganizationID: org.ID, - Name: strings.ToLower(coderdtest.RandomName(s.T())), - HashedSecret: []byte(coderdtest.RandomName(s.T())), - } - //nolint:gosimple // casting is not a simplification - check.Args(database.InsertProvisionerKeyParams{ - ID: pk.ID, - CreatedAt: pk.CreatedAt, - OrganizationID: pk.OrganizationID, - Name: pk.Name, - HashedSecret: pk.HashedSecret, - }).Asserts(pk, policy.ActionCreate).Returns(pk) - })) - s.Run("GetProvisionerKeyByID", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) + s.Run("InsertProvisionerKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + arg := database.InsertProvisionerKeyParams{ID: pk.ID, CreatedAt: pk.CreatedAt, OrganizationID: pk.OrganizationID, Name: pk.Name, HashedSecret: pk.HashedSecret} + dbm.EXPECT().InsertProvisionerKey(gomock.Any(), arg).Return(pk, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceProvisionerDaemon.InOrg(org.ID).WithID(pk.ID), policy.ActionCreate).Returns(pk) + })) + s.Run("GetProvisionerKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + dbm.EXPECT().GetProvisionerKeyByID(gomock.Any(), pk.ID).Return(pk, nil).AnyTimes() check.Args(pk.ID).Asserts(pk, policy.ActionRead).Returns(pk) })) - s.Run("GetProvisionerKeyByHashedSecret", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID, HashedSecret: []byte("foo")}) + s.Run("GetProvisionerKeyByHashedSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID, HashedSecret: []byte("foo")}) + dbm.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), []byte("foo")).Return(pk, nil).AnyTimes() check.Args([]byte("foo")).Asserts(pk, policy.ActionRead).Returns(pk) })) - s.Run("GetProvisionerKeyByName", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) - check.Args(database.GetProvisionerKeyByNameParams{ - OrganizationID: org.ID, - Name: pk.Name, - }).Asserts(pk, policy.ActionRead).Returns(pk) - })) - s.Run("ListProvisionerKeysByOrganization", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) - pks := []database.ProvisionerKey{ - { - ID: pk.ID, - CreatedAt: pk.CreatedAt, - OrganizationID: pk.OrganizationID, - Name: pk.Name, - HashedSecret: pk.HashedSecret, - }, - } - check.Args(org.ID).Asserts(pk, policy.ActionRead).Returns(pks) - })) - s.Run("ListProvisionerKeysByOrganizationExcludeReserved", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) - pks := []database.ProvisionerKey{ - { - ID: pk.ID, - CreatedAt: pk.CreatedAt, - OrganizationID: pk.OrganizationID, - Name: pk.Name, - HashedSecret: pk.HashedSecret, - }, - } - check.Args(org.ID).Asserts(pk, policy.ActionRead).Returns(pks) - })) - s.Run("DeleteProvisionerKey", s.Subtest(func(db database.Store, check *expects) { - org := dbgen.Organization(s.T(), db, database.Organization{}) - pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) + s.Run("GetProvisionerKeyByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + arg := database.GetProvisionerKeyByNameParams{OrganizationID: org.ID, Name: pk.Name} + dbm.EXPECT().GetProvisionerKeyByName(gomock.Any(), arg).Return(pk, nil).AnyTimes() + check.Args(arg).Asserts(pk, policy.ActionRead).Returns(pk) + })) + s.Run("ListProvisionerKeysByOrganization", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + a := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + b := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + dbm.EXPECT().ListProvisionerKeysByOrganization(gomock.Any(), org.ID).Return([]database.ProvisionerKey{a, b}, nil).AnyTimes() + check.Args(org.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns([]database.ProvisionerKey{a, b}) + })) + s.Run("ListProvisionerKeysByOrganizationExcludeReserved", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + dbm.EXPECT().ListProvisionerKeysByOrganizationExcludeReserved(gomock.Any(), org.ID).Return([]database.ProvisionerKey{pk}, nil).AnyTimes() + check.Args(org.ID).Asserts(pk, policy.ActionRead).Returns([]database.ProvisionerKey{pk}) + })) + s.Run("DeleteProvisionerKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + org := testutil.Fake(s.T(), faker, database.Organization{}) + pk := testutil.Fake(s.T(), faker, database.ProvisionerKey{OrganizationID: org.ID}) + dbm.EXPECT().GetProvisionerKeyByID(gomock.Any(), pk.ID).Return(pk, nil).AnyTimes() + dbm.EXPECT().DeleteProvisionerKey(gomock.Any(), pk.ID).Return(nil).AnyTimes() check.Args(pk.ID).Asserts(pk, policy.ActionDelete).Returns() })) } @@ -3665,21 +3634,20 @@ func (s *MethodTestSuite) TestTailnetFunctions() { } func (s *MethodTestSuite) TestDBCrypt() { - s.Run("GetDBCryptKeys", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetDBCryptKeys", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetDBCryptKeys(gomock.Any()).Return([]database.DBCryptKey{}, nil).AnyTimes() check.Args(). Asserts(rbac.ResourceSystem, policy.ActionRead). Returns([]database.DBCryptKey{}) })) - s.Run("InsertDBCryptKey", s.Subtest(func(db database.Store, check *expects) { + s.Run("InsertDBCryptKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().InsertDBCryptKey(gomock.Any(), database.InsertDBCryptKeyParams{}).Return(nil).AnyTimes() check.Args(database.InsertDBCryptKeyParams{}). Asserts(rbac.ResourceSystem, policy.ActionCreate). Returns() })) - s.Run("RevokeDBCryptKey", s.Subtest(func(db database.Store, check *expects) { - err := db.InsertDBCryptKey(context.Background(), database.InsertDBCryptKeyParams{ - ActiveKeyDigest: "revoke me", - }) - s.NoError(err) + s.Run("RevokeDBCryptKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().RevokeDBCryptKey(gomock.Any(), "revoke me").Return(nil).AnyTimes() check.Args("revoke me"). Asserts(rbac.ResourceSystem, policy.ActionUpdate). Returns() @@ -3687,56 +3655,44 @@ func (s *MethodTestSuite) TestDBCrypt() { } func (s *MethodTestSuite) TestCryptoKeys() { - s.Run("GetCryptoKeys", s.Subtest(func(db database.Store, check *expects) { + s.Run("GetCryptoKeys", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetCryptoKeys(gomock.Any()).Return([]database.CryptoKey{}, nil).AnyTimes() check.Args(). Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) - s.Run("InsertCryptoKey", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.InsertCryptoKeyParams{ - Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, - }). + s.Run("InsertCryptoKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.InsertCryptoKeyParams{Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey} + dbm.EXPECT().InsertCryptoKey(gomock.Any(), arg).Return(database.CryptoKey{}, nil).AnyTimes() + check.Args(arg). Asserts(rbac.ResourceCryptoKey, policy.ActionCreate) })) - s.Run("DeleteCryptoKey", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, - Sequence: 4, - }) - check.Args(database.DeleteCryptoKeyParams{ - Feature: key.Feature, - Sequence: key.Sequence, - }).Asserts(rbac.ResourceCryptoKey, policy.ActionDelete) - })) - s.Run("GetCryptoKeyByFeatureAndSequence", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, - Sequence: 4, - }) - check.Args(database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: key.Feature, - Sequence: key.Sequence, - }).Asserts(rbac.ResourceCryptoKey, policy.ActionRead).Returns(key) - })) - s.Run("GetLatestCryptoKeyByFeature", s.Subtest(func(db database.Store, check *expects) { - dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, - Sequence: 4, - }) - check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) - })) - s.Run("UpdateCryptoKeyDeletesAt", s.Subtest(func(db database.Store, check *expects) { - key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, - Sequence: 4, - }) - check.Args(database.UpdateCryptoKeyDeletesAtParams{ - Feature: key.Feature, - Sequence: key.Sequence, - DeletesAt: sql.NullTime{Time: time.Now(), Valid: true}, - }).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate) - })) - s.Run("GetCryptoKeysByFeature", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey). + s.Run("DeleteCryptoKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.CryptoKey{Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4}) + arg := database.DeleteCryptoKeyParams{Feature: key.Feature, Sequence: key.Sequence} + dbm.EXPECT().DeleteCryptoKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceCryptoKey, policy.ActionDelete) + })) + s.Run("GetCryptoKeyByFeatureAndSequence", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.CryptoKey{Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4}) + arg := database.GetCryptoKeyByFeatureAndSequenceParams{Feature: key.Feature, Sequence: key.Sequence} + dbm.EXPECT().GetCryptoKeyByFeatureAndSequence(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceCryptoKey, policy.ActionRead).Returns(key) + })) + s.Run("GetLatestCryptoKeyByFeature", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + feature := database.CryptoKeyFeatureWorkspaceAppsAPIKey + dbm.EXPECT().GetLatestCryptoKeyByFeature(gomock.Any(), feature).Return(database.CryptoKey{}, nil).AnyTimes() + check.Args(feature).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) + })) + s.Run("UpdateCryptoKeyDeletesAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.CryptoKey{Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4}) + arg := database.UpdateCryptoKeyDeletesAtParams{Feature: key.Feature, Sequence: key.Sequence, DeletesAt: sql.NullTime{Time: time.Now(), Valid: true}} + dbm.EXPECT().UpdateCryptoKeyDeletesAt(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate) + })) + s.Run("GetCryptoKeysByFeature", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + feature := database.CryptoKeyFeatureWorkspaceAppsAPIKey + dbm.EXPECT().GetCryptoKeysByFeature(gomock.Any(), feature).Return([]database.CryptoKey{}, nil).AnyTimes() + check.Args(feature). Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) } @@ -5638,63 +5594,56 @@ func (s *MethodTestSuite) TestAuthorizePrebuiltWorkspace() { } func (s *MethodTestSuite) TestUserSecrets() { - s.Run("GetUserSecretByUserIDAndName", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - userSecret := dbgen.UserSecret(s.T(), db, database.UserSecret{ - UserID: user.ID, - }) - arg := database.GetUserSecretByUserIDAndNameParams{ - UserID: user.ID, - Name: userSecret.Name, - } + s.Run("GetUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID}) + arg := database.GetUserSecretByUserIDAndNameParams{UserID: user.ID, Name: secret.Name} + dbm.EXPECT().GetUserSecretByUserIDAndName(gomock.Any(), arg).Return(secret, nil).AnyTimes() check.Args(arg). - Asserts(rbac.ResourceUserSecret.WithOwner(arg.UserID.String()), policy.ActionRead). - Returns(userSecret) - })) - s.Run("GetUserSecret", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - userSecret := dbgen.UserSecret(s.T(), db, database.UserSecret{ - UserID: user.ID, - }) - check.Args(userSecret.ID). - Asserts(userSecret, policy.ActionRead). - Returns(userSecret) - })) - s.Run("ListUserSecrets", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - userSecret := dbgen.UserSecret(s.T(), db, database.UserSecret{ - UserID: user.ID, - }) + Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead). + Returns(secret) + })) + s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + secret := testutil.Fake(s.T(), faker, database.UserSecret{}) + dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes() + check.Args(secret.ID). + Asserts(secret, policy.ActionRead). + Returns(secret) + })) + s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID}) + dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes() check.Args(user.ID). Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead). - Returns([]database.UserSecret{userSecret}) + Returns([]database.UserSecret{secret}) })) - s.Run("CreateUserSecret", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - arg := database.CreateUserSecretParams{ - UserID: user.ID, - } + s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + user := testutil.Fake(s.T(), faker, database.User{}) + arg := database.CreateUserSecretParams{UserID: user.ID} + ret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID}) + dbm.EXPECT().CreateUserSecret(gomock.Any(), arg).Return(ret, nil).AnyTimes() check.Args(arg). - Asserts(rbac.ResourceUserSecret.WithOwner(arg.UserID.String()), policy.ActionCreate) - })) - s.Run("UpdateUserSecret", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - userSecret := dbgen.UserSecret(s.T(), db, database.UserSecret{ - UserID: user.ID, - }) - arg := database.UpdateUserSecretParams{ - ID: userSecret.ID, - } + Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate). + Returns(ret) + })) + s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + secret := testutil.Fake(s.T(), faker, database.UserSecret{}) + updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID}) + arg := database.UpdateUserSecretParams{ID: secret.ID} + dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes() + dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes() check.Args(arg). - Asserts(userSecret, policy.ActionUpdate) - })) - s.Run("DeleteUserSecret", s.Subtest(func(db database.Store, check *expects) { - user := dbgen.User(s.T(), db, database.User{}) - userSecret := dbgen.UserSecret(s.T(), db, database.UserSecret{ - UserID: user.ID, - }) - check.Args(userSecret.ID). - Asserts(userSecret, policy.ActionRead, userSecret, policy.ActionDelete) + Asserts(secret, policy.ActionUpdate). + Returns(updated) + })) + s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + secret := testutil.Fake(s.T(), faker, database.UserSecret{}) + dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes() + dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes() + check.Args(secret.ID). + Asserts(secret, policy.ActionRead, secret, policy.ActionDelete). + Returns() })) } From a19dfa9a0a10d106de3c4b91d9ff43214c1feca5 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 20 Aug 2025 09:13:40 -0500 Subject: [PATCH 11/11] docs: add generative ai contribution guidelines (#19427) Initial language that gives us something to point to if needed. --- docs/about/contributing/AI_CONTRIBUTING.md | 32 ++++++++++++++++++++++ docs/about/contributing/CONTRIBUTING.md | 5 ++++ docs/images/icons/ai_intelligence.svg | 1 + docs/manifest.json | 6 ++++ 4 files changed, 44 insertions(+) create mode 100644 docs/about/contributing/AI_CONTRIBUTING.md create mode 100644 docs/images/icons/ai_intelligence.svg diff --git a/docs/about/contributing/AI_CONTRIBUTING.md b/docs/about/contributing/AI_CONTRIBUTING.md new file mode 100644 index 0000000000000..8771528f0c1ce --- /dev/null +++ b/docs/about/contributing/AI_CONTRIBUTING.md @@ -0,0 +1,32 @@ +# AI Contribution Guidelines + +This document defines rules for contributions where an AI system is the primary author of the code (i.e., most of the pull request was generated by AI). +It applies to all Coder repositories and is a supplement to the [existing contributing guidelines](./CONTRIBUTING.md), not a replacement. + +For minor AI-assisted edits, suggestions, or completions where the human contributor is clearly the primary author, these rules do not apply — standard contributing guidelines are sufficient. + +## Disclosure + +Contributors must **disclose AI involvement** in the pull request description whenever these guidelines apply. + +## Human Ownership & Attribution + +- All pull requests must be opened under **user accounts linked to a human**, and not an application ("bot account"). +- Contributors are personally accountable for the content of their PRs, regardless of how it was generated. + +## Verification & Evidence + +All AI-assisted contributions require **manual verification**. +Contributions without verification evidence will be rejected. + +- Test your changes yourself. Don’t assume AI is correct. +- Provide screenshots showing that the change works as intended. + - For visual/UI changes: include before/after screenshots. + - For CLI or backend changes: include terminal or api output. + +## Why These Rules Exist + +Traditionally, maintainers assumed that producing a pull request required more effort than reviewing it. +With AI-assisted tools, the balance has shifted: generating code is often faster than reviewing it. + +Our guidelines exist to safeguard maintainers’ time, uphold contributor accountability, and preserve the overall quality of the project. diff --git a/docs/about/contributing/CONTRIBUTING.md b/docs/about/contributing/CONTRIBUTING.md index 7eedebb146dc5..98243d3790f77 100644 --- a/docs/about/contributing/CONTRIBUTING.md +++ b/docs/about/contributing/CONTRIBUTING.md @@ -236,6 +236,11 @@ Breaking changes can be triggered in two ways: [`release/breaking`](https://github.com/coder/coder/issues?q=sort%3Aupdated-desc+label%3Arelease%2Fbreaking) label to a PR that has, or will be, merged into `main`. +### Generative AI + +Using AI to help with contributions is acceptable, but only if the [AI Contribution Guidelines](./AI_CONTRIBUTING.md) +are followed. If most of your PR was generated by AI, please read and comply with these rules before submitting. + ### Security > [!CAUTION] diff --git a/docs/images/icons/ai_intelligence.svg b/docs/images/icons/ai_intelligence.svg new file mode 100644 index 0000000000000..bcef647bf3c3a --- /dev/null +++ b/docs/images/icons/ai_intelligence.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/manifest.json b/docs/manifest.json index bd08ccfe372e6..4a382da8ec25a 100644 --- a/docs/manifest.json +++ b/docs/manifest.json @@ -76,6 +76,12 @@ "description": "Security vulnerability disclosure policy", "path": "./about/contributing/SECURITY.md", "icon_path": "./images/icons/lock.svg" + }, + { + "title": "AI Contribution Guidelines", + "description": "Guidelines for AI-generated contributions.", + "path": "./about/contributing/AI_CONTRIBUTING.md", + "icon_path": "./images/icons/ai_intelligence.svg" } ] }