diff --git a/.gitignore b/.gitignore index d633f94583ec9..8d29eff1048d1 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ result # Zed .zed_server + +# dlv debug binaries for go tests +__debug_bin* diff --git a/Makefile b/Makefile index e8cdcd3a3a1ba..4ada1cd6d488c 100644 --- a/Makefile +++ b/Makefile @@ -581,7 +581,8 @@ GEN_FILES := \ $(TAILNETTEST_MOCKS) \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ - agent/agentcontainers/dcspec/dcspec_gen.go + agent/agentcontainers/dcspec/dcspec_gen.go \ + coderd/httpmw/loggermw/loggermock/loggermock.go # all gen targets should be added here and to gen/mark-fresh gen: gen/db gen/golden-files $(GEN_FILES) @@ -630,6 +631,7 @@ gen/mark-fresh: coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ + coderd/httpmw/loggermw/loggermock/loggermock.go \ " for file in $$files; do @@ -669,6 +671,10 @@ agent/agentcontainers/acmock/acmock.go: agent/agentcontainers/containers.go go generate ./agent/agentcontainers/acmock/ touch "$@" +coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.go + go generate ./coderd/httpmw/loggermw/loggermock/ + touch "$@" + agent/agentcontainers/dcspec/dcspec_gen.go: \ node_modules/.installed \ agent/agentcontainers/dcspec/devContainer.base.schema.json \ diff --git a/agent/agent.go b/agent/agent.go index 4f07eec69db95..df962384305d5 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -907,7 +907,7 @@ func (a *agent) run() (retErr error) { defer func() { cErr := aAPI.DRPCConn().Close() if cErr != nil { - a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(err)) + a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(cErr)) } }() diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d6f8f72dc5f23..7b8e024136dff 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1913,7 +1913,9 @@ Expire-Date: 0 tpty.WriteLine("gpg --list-keys && echo gpg-''-listkeys-command-done") listKeysOutput := tpty.ExpectMatch("gpg--listkeys-command-done") require.Contains(t, listKeysOutput, "[ultimate] Coder Test ") - require.Contains(t, listKeysOutput, "[ultimate] Dean Sheather (work key) ") + // It's fine that this key is expired. We're just testing that the key trust + // gets synced properly. + require.Contains(t, listKeysOutput, "[ expired] Dean Sheather (work key) ") // Try to sign something. This demonstrates that the forwarding is // working as expected, since the workspace doesn't have access to the diff --git a/coderd/coderd.go b/coderd/coderd.go index c9a0f741afd1f..f5b0bcbefc48c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -64,6 +64,7 @@ import ( "github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/metricscache" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/portsharing" @@ -665,10 +666,11 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.TailnetCoordinator.Store(&options.TailnetCoordinator) dialer := &InmemTailnetDialer{ - CoordPtr: &api.TailnetCoordinator, - DERPFn: api.DERPMap, - Logger: options.Logger, - ClientID: uuid.New(), + CoordPtr: &api.TailnetCoordinator, + DERPFn: api.DERPMap, + Logger: options.Logger, + ClientID: uuid.New(), + DatabaseHealthCheck: api.Database, } stn, err := NewServerTailnet(api.ctx, options.Logger, @@ -800,7 +802,7 @@ func New(options *Options) *API { tracing.Middleware(api.TracerProvider), httpmw.AttachRequestID, httpmw.ExtractRealIP(api.RealIPConfig), - httpmw.Logger(api.Logger), + loggermw.Logger(api.Logger), singleSlashMW, rolestore.CustomRoleMW, prometheusMW, @@ -1147,64 +1149,74 @@ func New(options *Options) *API { r.Get("/", api.AssignableSiteRoles) }) r.Route("/{user}", func(r chi.Router) { - r.Use(httpmw.ExtractUserParam(options.Database)) - r.Post("/convert-login", api.postConvertLoginType) - r.Delete("/", api.deleteUser) - r.Get("/", api.userByName) - r.Get("/autofill-parameters", api.userAutofillParameters) - r.Get("/login-type", api.userLoginType) - r.Put("/profile", api.putUserProfile) - r.Route("/status", func(r chi.Router) { - r.Put("/suspend", api.putSuspendUserAccount()) - r.Put("/activate", api.putActivateUserAccount()) + r.Group(func(r chi.Router) { + r.Use(httpmw.ExtractUserParamOptional(options.Database)) + // Creating workspaces does not require permissions on the user, only the + // organization member. This endpoint should match the authz story of + // postWorkspacesByOrganization + r.Post("/workspaces", api.postUserWorkspaces) }) - r.Get("/appearance", api.userAppearanceSettings) - r.Put("/appearance", api.putUserAppearanceSettings) - r.Route("/password", func(r chi.Router) { - r.Use(httpmw.RateLimit(options.LoginRateLimit, time.Minute)) - r.Put("/", api.putUserPassword) - }) - // These roles apply to the site wide permissions. - r.Put("/roles", api.putUserRoles) - r.Get("/roles", api.userRoles) - - r.Route("/keys", func(r chi.Router) { - r.Post("/", api.postAPIKey) - r.Route("/tokens", func(r chi.Router) { - r.Post("/", api.postToken) - r.Get("/", api.tokens) - r.Get("/tokenconfig", api.tokenConfig) - r.Route("/{keyname}", func(r chi.Router) { - r.Get("/", api.apiKeyByName) - }) + + r.Group(func(r chi.Router) { + r.Use(httpmw.ExtractUserParam(options.Database)) + + r.Post("/convert-login", api.postConvertLoginType) + r.Delete("/", api.deleteUser) + r.Get("/", api.userByName) + r.Get("/autofill-parameters", api.userAutofillParameters) + r.Get("/login-type", api.userLoginType) + r.Put("/profile", api.putUserProfile) + r.Route("/status", func(r chi.Router) { + r.Put("/suspend", api.putSuspendUserAccount()) + r.Put("/activate", api.putActivateUserAccount()) }) - r.Route("/{keyid}", func(r chi.Router) { - r.Get("/", api.apiKeyByID) - r.Delete("/", api.deleteAPIKey) + r.Get("/appearance", api.userAppearanceSettings) + r.Put("/appearance", api.putUserAppearanceSettings) + r.Route("/password", func(r chi.Router) { + r.Use(httpmw.RateLimit(options.LoginRateLimit, time.Minute)) + r.Put("/", api.putUserPassword) + }) + // These roles apply to the site wide permissions. + r.Put("/roles", api.putUserRoles) + r.Get("/roles", api.userRoles) + + r.Route("/keys", func(r chi.Router) { + r.Post("/", api.postAPIKey) + r.Route("/tokens", func(r chi.Router) { + r.Post("/", api.postToken) + r.Get("/", api.tokens) + r.Get("/tokenconfig", api.tokenConfig) + r.Route("/{keyname}", func(r chi.Router) { + r.Get("/", api.apiKeyByName) + }) + }) + r.Route("/{keyid}", func(r chi.Router) { + r.Get("/", api.apiKeyByID) + r.Delete("/", api.deleteAPIKey) + }) }) - }) - r.Route("/organizations", func(r chi.Router) { - r.Get("/", api.organizationsByUser) - r.Get("/{organizationname}", api.organizationByUserAndName) - }) - r.Post("/workspaces", api.postUserWorkspaces) - r.Route("/workspace/{workspacename}", func(r chi.Router) { - r.Get("/", api.workspaceByOwnerAndName) - r.Get("/builds/{buildnumber}", api.workspaceBuildByBuildNumber) - }) - r.Get("/gitsshkey", api.gitSSHKey) - r.Put("/gitsshkey", api.regenerateGitSSHKey) - r.Route("/notifications", func(r chi.Router) { - r.Route("/preferences", func(r chi.Router) { - r.Get("/", api.userNotificationPreferences) - r.Put("/", api.putUserNotificationPreferences) + r.Route("/organizations", func(r chi.Router) { + r.Get("/", api.organizationsByUser) + r.Get("/{organizationname}", api.organizationByUserAndName) + }) + r.Route("/workspace/{workspacename}", func(r chi.Router) { + r.Get("/", api.workspaceByOwnerAndName) + r.Get("/builds/{buildnumber}", api.workspaceBuildByBuildNumber) + }) + r.Get("/gitsshkey", api.gitSSHKey) + r.Put("/gitsshkey", api.regenerateGitSSHKey) + r.Route("/notifications", func(r chi.Router) { + r.Route("/preferences", func(r chi.Router) { + r.Get("/", api.userNotificationPreferences) + r.Put("/", api.putUserNotificationPreferences) + }) + }) + r.Route("/webpush", func(r chi.Router) { + r.Post("/subscription", api.postUserWebpushSubscription) + r.Delete("/subscription", api.deleteUserWebpushSubscription) + r.Post("/test", api.postUserPushNotificationTest) }) - }) - r.Route("/webpush", func(r chi.Router) { - r.Post("/subscription", api.postUserWebpushSubscription) - r.Delete("/subscription", api.deleteUserWebpushSubscription) - r.Post("/test", api.postUserPushNotificationTest) }) }) }) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index af52f7fc70f53..279405c4e6a21 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -81,7 +81,7 @@ func AssertRBAC(t *testing.T, api *coderd.API, client *codersdk.Client) RBACAsse // Note that duplicate rbac calls are handled by the rbac.Cacher(), but // will be recorded twice. So AllCalls() returns calls regardless if they // were returned from the cached or not. -func (a RBACAsserter) AllCalls() []AuthCall { +func (a RBACAsserter) AllCalls() AuthCalls { return a.Recorder.AllCalls(&a.Subject) } @@ -140,8 +140,11 @@ func (a RBACAsserter) Reset() RBACAsserter { return a } +type AuthCalls []AuthCall + type AuthCall struct { rbac.AuthCall + Err error asserted bool // callers is a small stack trace for debugging. @@ -252,7 +255,7 @@ func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did } // recordAuthorize is the internal method that records the Authorize() call. -func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action policy.Action, object rbac.Object) { +func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action policy.Action, object rbac.Object, authzErr error) { r.Lock() defer r.Unlock() @@ -262,6 +265,7 @@ func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action polic Action: action, Object: object, }, + Err: authzErr, callers: []string{ // This is a decent stack trace for debugging. // Some dbauthz calls are a bit nested, so we skip a few. @@ -288,11 +292,12 @@ func caller(skip int) string { } func (r *RecordingAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action policy.Action, object rbac.Object) error { - r.recordAuthorize(subject, action, object) if r.Wrapped == nil { panic("Developer error: RecordingAuthorizer.Wrapped is nil") } - return r.Wrapped.Authorize(ctx, subject, action, object) + authzErr := r.Wrapped.Authorize(ctx, subject, action, object) + r.recordAuthorize(subject, action, object, authzErr) + return authzErr } func (r *RecordingAuthorizer) Prepare(ctx context.Context, subject rbac.Subject, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) { @@ -339,10 +344,11 @@ func (s *PreparedRecorder) Authorize(ctx context.Context, object rbac.Object) er s.rw.Lock() defer s.rw.Unlock() + authzErr := s.prepped.Authorize(ctx, object) if !s.usingSQL { - s.rec.recordAuthorize(s.subject, s.action, object) + s.rec.recordAuthorize(s.subject, s.action, object, authzErr) } - return s.prepped.Authorize(ctx, object) + return authzErr } func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.ConvertConfig) (string, error) { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 7ab078d32ad4f..ed1242f674d0f 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -24,6 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/provisionersdk" @@ -162,6 +163,7 @@ func ActorFromContext(ctx context.Context) (rbac.Subject, bool) { var ( subjectProvisionerd = rbac.Subject{ + Type: rbac.SubjectTypeProvisionerd, FriendlyName: "Provisioner Daemon", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -196,6 +198,7 @@ var ( }.WithCachedASTValue() subjectAutostart = rbac.Subject{ + Type: rbac.SubjectTypeAutostart, FriendlyName: "Autostart", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -219,6 +222,7 @@ var ( // See unhanger package. subjectHangDetector = rbac.Subject{ + Type: rbac.SubjectTypeHangDetector, FriendlyName: "Hang Detector", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -239,6 +243,7 @@ var ( // See cryptokeys package. subjectCryptoKeyRotator = rbac.Subject{ + Type: rbac.SubjectTypeCryptoKeyRotator, FriendlyName: "Crypto Key Rotator", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -257,6 +262,7 @@ var ( // See cryptokeys package. subjectCryptoKeyReader = rbac.Subject{ + Type: rbac.SubjectTypeCryptoKeyReader, FriendlyName: "Crypto Key Reader", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -274,6 +280,7 @@ var ( }.WithCachedASTValue() subjectNotifier = rbac.Subject{ + Type: rbac.SubjectTypeNotifier, FriendlyName: "Notifier", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -294,6 +301,7 @@ var ( }.WithCachedASTValue() subjectResourceMonitor = rbac.Subject{ + Type: rbac.SubjectTypeResourceMonitor, FriendlyName: "Resource Monitor", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -312,6 +320,7 @@ var ( }.WithCachedASTValue() subjectSystemRestricted = rbac.Subject{ + Type: rbac.SubjectTypeSystemRestricted, FriendlyName: "System", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -346,6 +355,7 @@ var ( }.WithCachedASTValue() subjectSystemReadProvisionerDaemons = rbac.Subject{ + Type: rbac.SubjectTypeSystemReadProvisionerDaemons, FriendlyName: "Provisioner Daemons Reader", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -366,53 +376,53 @@ var ( // AsProvisionerd returns a context with an actor that has permissions required // for provisionerd to function. func AsProvisionerd(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectProvisionerd) + return As(ctx, subjectProvisionerd) } // AsAutostart returns a context with an actor that has permissions required // for autostart to function. func AsAutostart(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectAutostart) + return As(ctx, subjectAutostart) } // AsHangDetector returns a context with an actor that has permissions required // for unhanger.Detector to function. func AsHangDetector(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectHangDetector) + return As(ctx, subjectHangDetector) } // AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys. func AsKeyRotator(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator) + return As(ctx, subjectCryptoKeyRotator) } // AsKeyReader returns a context with an actor that has permissions required for reading crypto keys. func AsKeyReader(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader) + return As(ctx, subjectCryptoKeyReader) } // AsNotifier returns a context with an actor that has permissions required for // creating/reading/updating/deleting notifications. func AsNotifier(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectNotifier) + return As(ctx, subjectNotifier) } // AsResourceMonitor returns a context with an actor that has permissions required for // updating resource monitors. func AsResourceMonitor(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectResourceMonitor) + return As(ctx, subjectResourceMonitor) } // AsSystemRestricted returns a context with an actor that has permissions // required for various system operations (login, logout, metrics cache). func AsSystemRestricted(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectSystemRestricted) + return As(ctx, subjectSystemRestricted) } // AsSystemReadProvisionerDaemons returns a context with an actor that has permissions // to read provisioner daemons. func AsSystemReadProvisionerDaemons(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectSystemReadProvisionerDaemons) + return As(ctx, subjectSystemReadProvisionerDaemons) } var AsRemoveActor = rbac.Subject{ @@ -430,6 +440,9 @@ func As(ctx context.Context, actor rbac.Subject) context.Context { // should be removed from the context. return context.WithValue(ctx, authContextKey{}, nil) } + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithAuthContext(actor) + } return context.WithValue(ctx, authContextKey{}, actor) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 59d717531324a..81004abcd8a50 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -11628,10 +11628,10 @@ func (q *sqlQuerier) GetActiveUserCount(ctx context.Context, includeSystem bool) const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT - -- username is returned just to help for logging purposes + -- username and email are returned just to help for logging purposes -- status is used to enforce 'suspended' users, as all roles are ignored -- when suspended. - id, username, status, + id, username, status, email, -- All user roles, including their org roles. array_cat( -- All users are members @@ -11672,6 +11672,7 @@ type GetAuthorizationUserRolesRow struct { ID uuid.UUID `db:"id" json:"id"` Username string `db:"username" json:"username"` Status UserStatus `db:"status" json:"status"` + Email string `db:"email" json:"email"` Roles []string `db:"roles" json:"roles"` Groups []string `db:"groups" json:"groups"` } @@ -11685,6 +11686,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. &i.ID, &i.Username, &i.Status, + &i.Email, pq.Array(&i.Roles), pq.Array(&i.Groups), ) diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index c4304cfc3e60e..ebf0e10b5d61c 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -271,10 +271,10 @@ WHERE -- This function returns roles for authorization purposes. Implied member roles -- are included. SELECT - -- username is returned just to help for logging purposes + -- username and email are returned just to help for logging purposes -- status is used to enforce 'suspended' users, as all roles are ignored -- when suspended. - id, username, status, + id, username, status, email, -- All user roles, including their org roles. array_cat( -- All users are members diff --git a/coderd/httpapi/noop.go b/coderd/httpapi/noop.go new file mode 100644 index 0000000000000..52a0f5dd4d8a4 --- /dev/null +++ b/coderd/httpapi/noop.go @@ -0,0 +1,10 @@ +package httpapi + +import "net/http" + +// NoopResponseWriter is a response writer that does nothing. +type NoopResponseWriter struct{} + +func (NoopResponseWriter) Header() http.Header { return http.Header{} } +func (NoopResponseWriter) Write(p []byte) (int, error) { return len(p), nil } +func (NoopResponseWriter) WriteHeader(int) {} diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 1574affa30b65..d614b37a3d897 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -465,7 +465,9 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s } actor := rbac.Subject{ + Type: rbac.SubjectTypeUser, FriendlyName: roles.Username, + Email: roles.Email, ID: userID.String(), Roles: rbacRoles, Groups: roles.Groups, diff --git a/coderd/httpmw/logger.go b/coderd/httpmw/logger.go deleted file mode 100644 index 79e95cf859d8e..0000000000000 --- a/coderd/httpmw/logger.go +++ /dev/null @@ -1,76 +0,0 @@ -package httpmw - -import ( - "context" - "fmt" - "net/http" - "time" - - "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/httpapi" - "github.com/coder/coder/v2/coderd/tracing" -) - -func Logger(log slog.Logger) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - start := time.Now() - - sw, ok := rw.(*tracing.StatusWriter) - if !ok { - panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw)) - } - - httplog := log.With( - slog.F("host", httpapi.RequestHost(r)), - slog.F("path", r.URL.Path), - slog.F("proto", r.Proto), - slog.F("remote_addr", r.RemoteAddr), - // Include the start timestamp in the log so that we have the - // source of truth. There is at least a theoretical chance that - // there can be a delay between `next.ServeHTTP` ending and us - // actually logging the request. This can also be useful when - // filtering logs that started at a certain time (compared to - // trying to compute the value). - slog.F("start", start), - ) - - next.ServeHTTP(sw, r) - - end := time.Now() - - // Don't log successful health check requests. - if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK { - return - } - - httplog = httplog.With( - slog.F("took", end.Sub(start)), - slog.F("status_code", sw.Status), - slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)), - ) - - // For status codes 400 and higher we - // want to log the response body. - if sw.Status >= http.StatusInternalServerError { - httplog = httplog.With( - slog.F("response_body", string(sw.ResponseBody())), - ) - } - - // We should not log at level ERROR for 5xx status codes because 5xx - // includes proxy errors etc. It also causes slogtest to fail - // instantly without an error message by default. - logLevelFn := httplog.Debug - if sw.Status >= http.StatusInternalServerError { - logLevelFn = httplog.Warn - } - - // We already capture most of this information in the span (minus - // the response body which we don't want to capture anyways). - tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) { - logLevelFn(ctx, r.Method) - }) - }) - } -} diff --git a/coderd/httpmw/loggermw/logger.go b/coderd/httpmw/loggermw/logger.go new file mode 100644 index 0000000000000..9eeb07a5f10e5 --- /dev/null +++ b/coderd/httpmw/loggermw/logger.go @@ -0,0 +1,203 @@ +package loggermw + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/tracing" +) + +func Logger(log slog.Logger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + start := time.Now() + + sw, ok := rw.(*tracing.StatusWriter) + if !ok { + panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw)) + } + + httplog := log.With( + slog.F("host", httpapi.RequestHost(r)), + slog.F("path", r.URL.Path), + slog.F("proto", r.Proto), + slog.F("remote_addr", r.RemoteAddr), + // Include the start timestamp in the log so that we have the + // source of truth. There is at least a theoretical chance that + // there can be a delay between `next.ServeHTTP` ending and us + // actually logging the request. This can also be useful when + // filtering logs that started at a certain time (compared to + // trying to compute the value). + slog.F("start", start), + ) + + logContext := NewRequestLogger(httplog, r.Method, start) + + ctx := WithRequestLogger(r.Context(), logContext) + + next.ServeHTTP(sw, r.WithContext(ctx)) + + // Don't log successful health check requests. + if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK { + return + } + + // For status codes 500 and higher we + // want to log the response body. + if sw.Status >= http.StatusInternalServerError { + logContext.WithFields( + slog.F("response_body", string(sw.ResponseBody())), + ) + } + + logContext.WriteLog(r.Context(), sw.Status) + }) + } +} + +type RequestLogger interface { + WithFields(fields ...slog.Field) + WriteLog(ctx context.Context, status int) + WithAuthContext(actor rbac.Subject) +} + +type SlogRequestLogger struct { + log slog.Logger + written bool + message string + start time.Time + // Protects actors map for concurrent writes. + mu sync.RWMutex + actors map[rbac.SubjectType]rbac.Subject +} + +var _ RequestLogger = &SlogRequestLogger{} + +func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger { + return &SlogRequestLogger{ + log: log, + written: false, + message: message, + start: start, + actors: make(map[rbac.SubjectType]rbac.Subject), + } +} + +func (c *SlogRequestLogger) WithFields(fields ...slog.Field) { + c.log = c.log.With(fields...) +} + +func (c *SlogRequestLogger) WithAuthContext(actor rbac.Subject) { + c.mu.Lock() + defer c.mu.Unlock() + c.actors[actor.Type] = actor +} + +func (c *SlogRequestLogger) addAuthContextFields() { + c.mu.RLock() + defer c.mu.RUnlock() + + usr, ok := c.actors[rbac.SubjectTypeUser] + if ok { + c.log = c.log.With( + slog.F("requestor_id", usr.ID), + slog.F("requestor_name", usr.FriendlyName), + slog.F("requestor_email", usr.Email), + ) + } else { + // If there is no user, we log the requestor name for the first + // actor in a defined order. + for _, v := range actorLogOrder { + subj, ok := c.actors[v] + if !ok { + continue + } + c.log = c.log.With( + slog.F("requestor_name", subj.FriendlyName), + ) + break + } + } +} + +var actorLogOrder = []rbac.SubjectType{ + rbac.SubjectTypeAutostart, + rbac.SubjectTypeCryptoKeyReader, + rbac.SubjectTypeCryptoKeyRotator, + rbac.SubjectTypeHangDetector, + rbac.SubjectTypeNotifier, + rbac.SubjectTypePrebuildsOrchestrator, + rbac.SubjectTypeProvisionerd, + rbac.SubjectTypeResourceMonitor, + rbac.SubjectTypeSystemReadProvisionerDaemons, + rbac.SubjectTypeSystemRestricted, +} + +func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) { + if c.written { + return + } + c.written = true + end := time.Now() + + // Right before we write the log, we try to find the user in the actors + // and add the fields to the log. + c.addAuthContextFields() + + logger := c.log.With( + slog.F("took", end.Sub(c.start)), + slog.F("status_code", status), + slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)), + ) + + // If the request is routed, add the route parameters to the log. + if chiCtx := chi.RouteContext(ctx); chiCtx != nil { + urlParams := chiCtx.URLParams + routeParamsFields := make([]slog.Field, 0, len(urlParams.Keys)) + + for k, v := range urlParams.Keys { + if urlParams.Values[k] != "" { + routeParamsFields = append(routeParamsFields, slog.F("params_"+v, urlParams.Values[k])) + } + } + + if len(routeParamsFields) > 0 { + logger = logger.With(routeParamsFields...) + } + } + + // We already capture most of this information in the span (minus + // the response body which we don't want to capture anyways). + tracing.RunWithoutSpan(ctx, func(ctx context.Context) { + // We should not log at level ERROR for 5xx status codes because 5xx + // includes proxy errors etc. It also causes slogtest to fail + // instantly without an error message by default. + if status >= http.StatusInternalServerError { + logger.Warn(ctx, c.message) + } else { + logger.Debug(ctx, c.message) + } + }) +} + +type logContextKey struct{} + +func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context { + return context.WithValue(ctx, logContextKey{}, rl) +} + +func RequestLoggerFromContext(ctx context.Context) RequestLogger { + val := ctx.Value(logContextKey{}) + if logCtx, ok := val.(RequestLogger); ok { + return logCtx + } + return nil +} diff --git a/coderd/httpmw/loggermw/logger_internal_test.go b/coderd/httpmw/loggermw/logger_internal_test.go new file mode 100644 index 0000000000000..e88f8a69c178e --- /dev/null +++ b/coderd/httpmw/loggermw/logger_internal_test.go @@ -0,0 +1,311 @@ +package loggermw + +import ( + "context" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +func TestRequestLogger_WriteLog(t *testing.T) { + t.Parallel() + ctx := context.Background() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + logCtx := NewRequestLogger(logger, "GET", time.Now()) + + // Add custom fields + logCtx.WithFields( + slog.F("custom_field", "custom_value"), + ) + + // Write log for 200 status + logCtx.WriteLog(ctx, http.StatusOK) + + require.Len(t, sink.entries, 1, "log was written twice") + + require.Equal(t, sink.entries[0].Message, "GET") + + require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value") + + // Attempt to write again (should be skipped). + logCtx.WriteLog(ctx, http.StatusInternalServerError) + + require.Len(t, sink.entries, 1, "log was written twice") +} + +func TestLoggerMiddleware_SingleRequest(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + // Create a test handler to simulate an HTTP request + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("OK")) + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // Create a test HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil) + require.NoError(t, err, "failed to create request") + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + // Serve the request + wrappedHandler.ServeHTTP(sw, req) + + require.Len(t, sink.entries, 1, "log was written twice") + + require.Equal(t, sink.entries[0].Message, "GET") + + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Check that the log contains the expected fields + requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"} + for _, field := range requiredFields { + _, exists := fieldsMap[field] + require.True(t, exists, "field %q is missing in log fields", field) + } + + require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields") + + // Check value of the status code + require.Equal(t, fieldsMap["status_code"], http.StatusOK) +} + +func TestLoggerMiddleware_WebSocket(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + sink := &fakeSink{ + newEntries: make(chan slog.SinkEntry, 2), + } + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + done := make(chan struct{}) + wg := sync.WaitGroup{} + // Create a test handler to simulate a WebSocket connection + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, nil) + if !assert.NoError(t, err, "failed to accept websocket") { + return + } + defer conn.Close(websocket.StatusGoingAway, "") + + requestLgr := RequestLoggerFromContext(r.Context()) + requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols) + // Block so we can be sure the end of the middleware isn't being called. + wg.Wait() + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // RequestLogger expects the ResponseWriter to be *tracing.StatusWriter + customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + defer close(done) + sw := &tracing.StatusWriter{ResponseWriter: rw} + wrappedHandler.ServeHTTP(sw, r) + }) + + srv := httptest.NewServer(customHandler) + defer srv.Close() + wg.Add(1) + // nolint: bodyclose + conn, _, err := websocket.Dial(ctx, srv.URL, nil) + require.NoError(t, err, "failed to dial WebSocket") + defer conn.Close(websocket.StatusNormalClosure, "") + + // Wait for the log from within the handler + newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries) + require.Equal(t, newEntry.Message, "GET") + + // Signal the websocket handler to return (and read to handle the close frame) + wg.Done() + _, _, err = conn.Read(ctx) + require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error") + + // Wait for the request to finish completely and verify we only logged once + _ = testutil.RequireRecvCtx(ctx, t, done) + require.Len(t, sink.entries, 1, "log was written twice") +} + +func TestRequestLogger_HTTPRouteParams(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("workspace", "test-workspace") + chiCtx.URLParams.Add("agent", "test-agent") + + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + + // Create a test handler to simulate an HTTP request + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("OK")) + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // Create a test HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path/}", nil) + require.NoError(t, err, "failed to create request") + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + // Serve the request + wrappedHandler.ServeHTTP(sw, req) + + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Check that the log contains the expected fields + requiredFields := []string{"workspace", "agent"} + for _, field := range requiredFields { + _, exists := fieldsMap["params_"+field] + require.True(t, exists, "field %q is missing in log fields", field) + } +} + +func TestRequestLogger_RouteParamsLogging(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params map[string]string + expectedFields []string + }{ + { + name: "EmptyParams", + params: map[string]string{}, + expectedFields: []string{}, + }, + { + name: "SingleParam", + params: map[string]string{ + "workspace": "test-workspace", + }, + expectedFields: []string{"params_workspace"}, + }, + { + name: "MultipleParams", + params: map[string]string{ + "workspace": "test-workspace", + "agent": "test-agent", + "user": "test-user", + }, + expectedFields: []string{"params_workspace", "params_agent", "params_user"}, + }, + { + name: "EmptyValueParam", + params: map[string]string{ + "workspace": "test-workspace", + "agent": "", + }, + expectedFields: []string{"params_workspace"}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + // Create a route context with the test parameters + chiCtx := chi.NewRouteContext() + for key, value := range tt.params { + chiCtx.URLParams.Add(key, value) + } + + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + logCtx := NewRequestLogger(logger, "GET", time.Now()) + + // Write the log + logCtx.WriteLog(ctx, http.StatusOK) + + require.Len(t, sink.entries, 1, "expected exactly one log entry") + + // Convert fields to map for easier checking + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Verify expected fields are present + for _, field := range tt.expectedFields { + value, exists := fieldsMap[field] + require.True(t, exists, "field %q should be present in log", field) + require.Equal(t, tt.params[strings.TrimPrefix(field, "params_")], value, "field %q has incorrect value", field) + } + + // Verify no unexpected fields are present + for field := range fieldsMap { + if field == "took" || field == "status_code" || field == "latency_ms" { + continue // Skip standard fields + } + require.True(t, slices.Contains(tt.expectedFields, field), "unexpected field %q in log", field) + } + }) + } +} + +type fakeSink struct { + entries []slog.SinkEntry + newEntries chan slog.SinkEntry +} + +func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) { + s.entries = append(s.entries, e) + if s.newEntries != nil { + select { + case s.newEntries <- e: + default: + } + } +} + +func (*fakeSink) Sync() {} diff --git a/coderd/httpmw/loggermw/loggermock/loggermock.go b/coderd/httpmw/loggermw/loggermock/loggermock.go new file mode 100644 index 0000000000000..008f862107ae6 --- /dev/null +++ b/coderd/httpmw/loggermw/loggermock/loggermock.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/coderd/httpmw/loggermw (interfaces: RequestLogger) +// +// Generated by this command: +// +// mockgen -destination=loggermock/loggermock.go -package=loggermock . RequestLogger +// + +// Package loggermock is a generated GoMock package. +package loggermock + +import ( + context "context" + reflect "reflect" + + slog "cdr.dev/slog" + rbac "github.com/coder/coder/v2/coderd/rbac" + gomock "go.uber.org/mock/gomock" +) + +// MockRequestLogger is a mock of RequestLogger interface. +type MockRequestLogger struct { + ctrl *gomock.Controller + recorder *MockRequestLoggerMockRecorder + isgomock struct{} +} + +// MockRequestLoggerMockRecorder is the mock recorder for MockRequestLogger. +type MockRequestLoggerMockRecorder struct { + mock *MockRequestLogger +} + +// NewMockRequestLogger creates a new mock instance. +func NewMockRequestLogger(ctrl *gomock.Controller) *MockRequestLogger { + mock := &MockRequestLogger{ctrl: ctrl} + mock.recorder = &MockRequestLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestLogger) EXPECT() *MockRequestLoggerMockRecorder { + return m.recorder +} + +// WithAuthContext mocks base method. +func (m *MockRequestLogger) WithAuthContext(actor rbac.Subject) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "WithAuthContext", actor) +} + +// WithAuthContext indicates an expected call of WithAuthContext. +func (mr *MockRequestLoggerMockRecorder) WithAuthContext(actor any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithAuthContext", reflect.TypeOf((*MockRequestLogger)(nil).WithAuthContext), actor) +} + +// WithFields mocks base method. +func (m *MockRequestLogger) WithFields(fields ...slog.Field) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range fields { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "WithFields", varargs...) +} + +// WithFields indicates an expected call of WithFields. +func (mr *MockRequestLoggerMockRecorder) WithFields(fields ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockRequestLogger)(nil).WithFields), fields...) +} + +// WriteLog mocks base method. +func (m *MockRequestLogger) WriteLog(ctx context.Context, status int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "WriteLog", ctx, status) +} + +// WriteLog indicates an expected call of WriteLog. +func (mr *MockRequestLoggerMockRecorder) WriteLog(ctx, status any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLog", reflect.TypeOf((*MockRequestLogger)(nil).WriteLog), ctx, status) +} diff --git a/coderd/httpmw/organizationparam.go b/coderd/httpmw/organizationparam.go index 18938ec1e792d..782a0d37e1985 100644 --- a/coderd/httpmw/organizationparam.go +++ b/coderd/httpmw/organizationparam.go @@ -117,7 +117,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H // very important that we do not add the User object to the request context or otherwise // leak it to the API handler. // nolint:gocritic - user, ok := extractUserContext(dbauthz.AsSystemRestricted(ctx), db, rw, r) + user, ok := ExtractUserContext(dbauthz.AsSystemRestricted(ctx), db, rw, r) if !ok { return } diff --git a/coderd/httpmw/prometheus.go b/coderd/httpmw/prometheus.go index b96be84e879e3..8b7b33381c74d 100644 --- a/coderd/httpmw/prometheus.go +++ b/coderd/httpmw/prometheus.go @@ -3,6 +3,7 @@ package httpmw import ( "net/http" "strconv" + "strings" "time" "github.com/go-chi/chi/v5" @@ -22,18 +23,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler Name: "requests_processed_total", Help: "The total number of processed API requests", }, []string{"code", "method", "path"}) - requestsConcurrent := factory.NewGauge(prometheus.GaugeOpts{ + requestsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "api", Name: "concurrent_requests", Help: "The number of concurrent API requests.", - }) - websocketsConcurrent := factory.NewGauge(prometheus.GaugeOpts{ + }, []string{"method", "path"}) + websocketsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "api", Name: "concurrent_websockets", Help: "The total number of concurrent API websockets.", - }) + }, []string{"path"}) websocketsDist := factory.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "coderd", Subsystem: "api", @@ -61,7 +62,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler var ( start = time.Now() method = r.Method - rctx = chi.RouteContext(r.Context()) ) sw, ok := w.(*tracing.StatusWriter) @@ -72,16 +72,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler var ( dist *prometheus.HistogramVec distOpts []string + path = getRoutePattern(r) ) + // We want to count WebSockets separately. if httpapi.IsWebsocketUpgrade(r) { - websocketsConcurrent.Inc() - defer websocketsConcurrent.Dec() + websocketsConcurrent.WithLabelValues(path).Inc() + defer websocketsConcurrent.WithLabelValues(path).Dec() dist = websocketsDist } else { - requestsConcurrent.Inc() - defer requestsConcurrent.Dec() + requestsConcurrent.WithLabelValues(method, path).Inc() + defer requestsConcurrent.WithLabelValues(method, path).Dec() dist = requestsDist distOpts = []string{method} @@ -89,7 +91,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler next.ServeHTTP(w, r) - path := rctx.RoutePattern() distOpts = append(distOpts, path) statusStr := strconv.Itoa(sw.Status) @@ -98,3 +99,34 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler }) } } + +func getRoutePattern(r *http.Request) string { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return "" + } + + if pattern := rctx.RoutePattern(); pattern != "" { + // Pattern is already available + return pattern + } + + routePath := r.URL.Path + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } + + tctx := chi.NewRouteContext() + routes := rctx.Routes + if routes != nil && !routes.Match(tctx, r.Method, routePath) { + // No matching pattern. /api/* requests will be matched as "UNKNOWN" + // All other ones will be matched as "STATIC". + if strings.HasPrefix(routePath, "/api/") { + return "UNKNOWN" + } + return "STATIC" + } + + // tctx has the updated pattern, since Match mutates it + return tctx.RoutePattern() +} diff --git a/coderd/httpmw/prometheus_test.go b/coderd/httpmw/prometheus_test.go index a51eea5d00312..d40558f5ca5e7 100644 --- a/coderd/httpmw/prometheus_test.go +++ b/coderd/httpmw/prometheus_test.go @@ -8,14 +8,19 @@ import ( "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus" + cm "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" ) func TestPrometheus(t *testing.T) { t.Parallel() + t.Run("All", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) @@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) { require.NoError(t, err) require.Greater(t, len(metrics), 0) }) + + t.Run("Concurrent", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + reg := prometheus.NewRegistry() + promMW := httpmw.Prometheus(reg) + + // Create a test handler to simulate a WebSocket connection + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, nil) + if !assert.NoError(t, err, "failed to accept websocket") { + return + } + defer conn.Close(websocket.StatusGoingAway, "") + }) + + wrappedHandler := promMW(testHandler) + + r := chi.NewRouter() + r.Use(tracing.StatusWriterMiddleware, promMW) + r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) { + wrappedHandler.ServeHTTP(rw, r) + }) + + srv := httptest.NewServer(r) + defer srv.Close() + // nolint: bodyclose + conn, _, err := websocket.Dial(ctx, srv.URL+"/api/v2/build/1/logs", nil) + require.NoError(t, err, "failed to dial WebSocket") + defer conn.Close(websocket.StatusNormalClosure, "") + + metrics, err := reg.Gather() + require.NoError(t, err) + require.Greater(t, len(metrics), 0) + metricLabels := getMetricLabels(metrics) + + concurrentWebsockets, ok := metricLabels["coderd_api_concurrent_websockets"] + require.True(t, ok, "coderd_api_concurrent_websockets metric not found") + require.Equal(t, "/api/v2/build/{build}/logs", concurrentWebsockets["path"]) + }) + + t.Run("UserRoute", func(t *testing.T) { + t.Parallel() + reg := prometheus.NewRegistry() + promMW := httpmw.Prometheus(reg) + + r := chi.NewRouter() + r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {}) + + req := httptest.NewRequest("GET", "/api/v2/users/john", nil) + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + r.ServeHTTP(sw, req) + + metrics, err := reg.Gather() + require.NoError(t, err) + require.Greater(t, len(metrics), 0) + metricLabels := getMetricLabels(metrics) + + reqProcessed, ok := metricLabels["coderd_api_requests_processed_total"] + require.True(t, ok, "coderd_api_requests_processed_total metric not found") + require.Equal(t, "/api/v2/users/{user}", reqProcessed["path"]) + require.Equal(t, "GET", reqProcessed["method"]) + + concurrentRequests, ok := metricLabels["coderd_api_concurrent_requests"] + require.True(t, ok, "coderd_api_concurrent_requests metric not found") + require.Equal(t, "/api/v2/users/{user}", concurrentRequests["path"]) + require.Equal(t, "GET", concurrentRequests["method"]) + }) +} + +func getMetricLabels(metrics []*cm.MetricFamily) map[string]map[string]string { + metricLabels := map[string]map[string]string{} + for _, metricFamily := range metrics { + metricName := metricFamily.GetName() + metricLabels[metricName] = map[string]string{} + for _, metric := range metricFamily.GetMetric() { + for _, labelPair := range metric.GetLabel() { + metricLabels[metricName][labelPair.GetName()] = labelPair.GetValue() + } + } + } + return metricLabels } diff --git a/coderd/httpmw/userparam.go b/coderd/httpmw/userparam.go index 03bff9bbb5596..2fbcc458489f9 100644 --- a/coderd/httpmw/userparam.go +++ b/coderd/httpmw/userparam.go @@ -31,13 +31,18 @@ func UserParam(r *http.Request) database.User { return user } +func UserParamOptional(r *http.Request) (database.User, bool) { + user, ok := r.Context().Value(userParamContextKey{}).(database.User) + return user, ok +} + // ExtractUserParam extracts a user from an ID/username in the {user} URL // parameter. func ExtractUserParam(db database.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - user, ok := extractUserContext(ctx, db, rw, r) + user, ok := ExtractUserContext(ctx, db, rw, r) if !ok { // response already handled return @@ -48,15 +53,31 @@ func ExtractUserParam(db database.Store) func(http.Handler) http.Handler { } } -// extractUserContext queries the database for the parameterized `{user}` from the request URL. -func extractUserContext(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (user database.User, ok bool) { +// ExtractUserParamOptional does not fail if no user is present. +func ExtractUserParamOptional(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + user, ok := ExtractUserContext(ctx, db, &httpapi.NoopResponseWriter{}, r) + if ok { + ctx = context.WithValue(ctx, userParamContextKey{}, user) + } + + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} + +// ExtractUserContext queries the database for the parameterized `{user}` from the request URL. +func ExtractUserContext(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (user database.User, ok bool) { // userQuery is either a uuid, a username, or 'me' userQuery := chi.URLParam(r, "user") if userQuery == "" { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "\"user\" must be provided.", }) - return database.User{}, true + return database.User{}, false } if userQuery == "me" { diff --git a/coderd/httpmw/workspaceagentparam.go b/coderd/httpmw/workspaceagentparam.go index a47ce3c377ae0..434e057c0eccc 100644 --- a/coderd/httpmw/workspaceagentparam.go +++ b/coderd/httpmw/workspaceagentparam.go @@ -6,8 +6,11 @@ import ( "github.com/go-chi/chi/v5" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/codersdk" ) @@ -81,6 +84,14 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent) chi.RouteContext(ctx).URLParams.Add("workspace", build.WorkspaceID.String()) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields( + slog.F("workspace_name", resource.Name), + slog.F("agent_name", agent.Name), + ) + } + next.ServeHTTP(rw, r.WithContext(ctx)) }) } diff --git a/coderd/httpmw/workspaceparam.go b/coderd/httpmw/workspaceparam.go index 21e8dcfd62863..0c4e4f77354fc 100644 --- a/coderd/httpmw/workspaceparam.go +++ b/coderd/httpmw/workspaceparam.go @@ -9,8 +9,11 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/codersdk" ) @@ -48,6 +51,11 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler { } ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields(slog.F("workspace_name", workspace.Name)) + } + next.ServeHTTP(rw, r.WithContext(ctx)) }) } @@ -154,6 +162,13 @@ func ExtractWorkspaceAndAgentParam(db database.Store) func(http.Handler) http.Ha ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace) ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields( + slog.F("workspace_name", workspace.Name), + slog.F("agent_name", agent.Name), + ) + } next.ServeHTTP(rw, r.WithContext(ctx)) }) } diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 4524284260359..b85ce1b749e28 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -268,6 +268,9 @@ func (s *GroupSyncSettings) Set(v string) error { } func (s *GroupSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]uuid.UUID) + } return runtimeconfig.JSONString(s) } diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 87fd9af5e935d..22f5f6e257d3e 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -168,6 +168,9 @@ func (s *OrganizationSyncSettings) Set(v string) error { } func (s *OrganizationSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]uuid.UUID) + } return runtimeconfig.JSONString(s) } diff --git a/coderd/idpsync/role.go b/coderd/idpsync/role.go index 54ec787661826..c21e7c99c4614 100644 --- a/coderd/idpsync/role.go +++ b/coderd/idpsync/role.go @@ -286,5 +286,8 @@ func (s *RoleSyncSettings) Set(v string) error { } func (s *RoleSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]string) + } return runtimeconfig.JSONString(s) } diff --git a/coderd/inboxnotifications.go b/coderd/inboxnotifications.go index 6da047241d790..bc357bf2e35f2 100644 --- a/coderd/inboxnotifications.go +++ b/coderd/inboxnotifications.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/pubsub" markdown "github.com/coder/coder/v2/coderd/render" @@ -219,6 +220,9 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) encoder := wsjson.NewEncoder[codersdk.GetInboxNotificationResponse](conn, websocket.MessageText) defer encoder.Close(websocket.StatusNormalClosure) + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + for { select { case <-ctx.Done(): diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index bcf344fc56c3f..46976e814483d 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -1415,13 +1415,15 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) return nil, xerrors.Errorf("update template version external auth providers: %w", err) } - err = s.Database.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{ - JobID: jobID, - CachedPlan: jobType.TemplateImport.Plan, - UpdatedAt: now, - }) - if err != nil { - return nil, xerrors.Errorf("insert template version terraform data: %w", err) + if len(jobType.TemplateImport.Plan) > 0 { + err := s.Database.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{ + JobID: jobID, + CachedPlan: jobType.TemplateImport.Plan, + UpdatedAt: now, + }) + if err != nil { + return nil, xerrors.Errorf("insert template version terraform data: %w", err) + } } err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 47963798f4d32..6d75227a14ccd 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/slice" @@ -554,6 +555,9 @@ func (f *logFollower) follow() { return } + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(f.ctx).WriteLog(f.ctx, http.StatusAccepted) + // no need to wait if the job is done if f.complete { return diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index af5a7d66a6f4c..f3bc2eb1dea99 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -19,6 +19,8 @@ import ( "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw/loggermock" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" @@ -305,11 +307,16 @@ func Test_logFollower_EndOfLogs(t *testing.T) { JobStatus: database.ProvisionerJobStatusRunning, } + mockLogger := loggermock.NewMockRequestLogger(ctrl) + mockLogger.EXPECT().WriteLog(gomock.Any(), http.StatusAccepted).Times(1) + ctx = loggermw.WithRequestLogger(ctx, mockLogger) + // we need an HTTP server to get a websocket srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0) uut.follow() })) + defer srv.Close() // job was incomplete when we create the logFollower, and still incomplete when it queries diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index aaba7d6eae3af..02268e052d2e2 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -57,6 +57,23 @@ func hashAuthorizeCall(actor Subject, action policy.Action, object Object) [32]b return hashOut } +// SubjectType represents the type of subject in the RBAC system. +type SubjectType string + +const ( + SubjectTypeUser SubjectType = "user" + SubjectTypeProvisionerd SubjectType = "provisionerd" + SubjectTypeAutostart SubjectType = "autostart" + SubjectTypeHangDetector SubjectType = "hang_detector" + SubjectTypeResourceMonitor SubjectType = "resource_monitor" + SubjectTypeCryptoKeyRotator SubjectType = "crypto_key_rotator" + SubjectTypeCryptoKeyReader SubjectType = "crypto_key_reader" + SubjectTypePrebuildsOrchestrator SubjectType = "prebuilds_orchestrator" + SubjectTypeSystemReadProvisionerDaemons SubjectType = "system_read_provisioner_daemons" + SubjectTypeSystemRestricted SubjectType = "system_restricted" + SubjectTypeNotifier SubjectType = "notifier" +) + // Subject is a struct that contains all the elements of a subject in an rbac // authorize. type Subject struct { @@ -66,6 +83,14 @@ type Subject struct { // external workspace proxy or other service type actor. FriendlyName string + // Email is entirely optional and is used for logging and debugging + // It is not used in any functional way. + Email string + + // Type indicates what kind of subject this is (user, system, provisioner, etc.) + // It is not used in any functional way, only for logging. + Type SubjectType + ID string Roles ExpandableRoles Groups []string diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index 4f42de94a4c52..9beef03dd8f9a 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -1,10 +1,14 @@ package rbac import ( + "fmt" + "strings" + "github.com/google/uuid" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/rbac/policy" + cstrings "github.com/coder/coder/v2/coderd/util/strings" ) // ResourceUserObject is a helper function to create a user object for authz checks. @@ -37,6 +41,25 @@ type Object struct { ACLGroupList map[string][]policy.Action ` json:"acl_group_list"` } +// String is not perfect, but decent enough for human display +func (z Object) String() string { + var parts []string + if z.OrgID != "" { + parts = append(parts, fmt.Sprintf("org:%s", cstrings.Truncate(z.OrgID, 4))) + } + if z.Owner != "" { + parts = append(parts, fmt.Sprintf("owner:%s", cstrings.Truncate(z.Owner, 4))) + } + parts = append(parts, z.Type) + if z.ID != "" { + parts = append(parts, fmt.Sprintf("id:%s", cstrings.Truncate(z.ID, 4))) + } + if len(z.ACLGroupList) > 0 || len(z.ACLUserList) > 0 { + parts = append(parts, fmt.Sprintf("acl:%d", len(z.ACLUserList)+len(z.ACLGroupList))) + } + return strings.Join(parts, ".") +} + // ValidAction checks if the action is valid for the given object type. func (z Object) ValidAction(action policy.Action) error { perms, ok := policy.RBACPermissions[z.Type] diff --git a/coderd/tailnet.go b/coderd/tailnet.go index b06219db40a78..cfdc667f4da0f 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -24,9 +24,11 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/site" "github.com/coder/coder/v2/tailnet" @@ -534,6 +536,10 @@ func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer tra return m } +type Pinger interface { + Ping(context.Context) (time.Duration, error) +} + // InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap // service running in the same memory space. type InmemTailnetDialer struct { @@ -541,9 +547,17 @@ type InmemTailnetDialer struct { DERPFn func() *tailcfg.DERPMap Logger slog.Logger ClientID uuid.UUID + // DatabaseHealthCheck is used to validate that the store is reachable. + DatabaseHealthCheck Pinger } -func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { +func (a *InmemTailnetDialer) Dial(ctx context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { + if a.DatabaseHealthCheck != nil { + if _, err := a.DatabaseHealthCheck.Ping(ctx); err != nil { + return tailnet.ControlProtocolClients{}, xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err) + } + } + coord := a.CoordPtr.Load() if coord == nil { return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized") diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index b0aaaedc769c0..28265404c3eae 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -11,6 +11,7 @@ import ( "strconv" "sync/atomic" "testing" + "time" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" @@ -18,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" "tailscale.com/tailcfg" "github.com/coder/coder/v2/agent" @@ -25,6 +27,7 @@ import ( "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/tailnet" @@ -365,6 +368,44 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { }) } +func TestDialFailure(t *testing.T) { + t.Parallel() + + // Setup. + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + // Given: a tailnet coordinator. + coord := tailnet.NewCoordinator(logger) + t.Cleanup(func() { + _ = coord.Close() + }) + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + + // Given: a fake DB healthchecker which will always fail. + fch := &failingHealthcheck{} + + // When: dialing the in-memory coordinator. + dialer := &coderd.InmemTailnetDialer{ + CoordPtr: &coordPtr, + Logger: logger, + ClientID: uuid.UUID{5}, + DatabaseHealthCheck: fch, + } + _, err := dialer.Dial(ctx, nil) + + // Then: the error returned reflects the database has failed its healthcheck. + require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable) +} + +type failingHealthcheck struct{} + +func (failingHealthcheck) Ping(context.Context) (time.Duration, error) { + // Simulate a database connection error. + return 0, xerrors.New("oops") +} + type wrappedListener struct { net.Listener dials int32 diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 3ed880d40970f..5bfefdfda2a3e 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -33,6 +33,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" @@ -555,6 +556,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { t := time.NewTicker(recheckInterval) defer t.Stop() + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + go func() { defer func() { logger.Debug(ctx, "end log streaming loop") @@ -928,6 +932,9 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) defer encoder.Close(websocket.StatusGoingAway) + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? t := time.NewTicker(api.AgentConnectionUpdateFrequency) @@ -989,6 +996,16 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + // Ensure the database is reachable before proceeding. + _, err := api.Database.Ping(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: codersdk.DatabaseNotReachable, + Detail: err.Error(), + }) + return + } + // This route accepts user API key auth and workspace proxy auth. The moon actor has // full permissions so should be able to pass this authz check. workspace := httpmw.WorkspaceParam(r) @@ -1293,6 +1310,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ sendTicker := time.NewTicker(sendInterval) defer sendTicker.Stop() + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + // Send initial metadata. sendMetadata() diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 84862d9c400c9..8b91abc07206b 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -406,31 +406,84 @@ func (api *API) postUserWorkspaces(rw http.ResponseWriter, r *http.Request) { ctx = r.Context() apiKey = httpmw.APIKey(r) auditor = api.Auditor.Load() - user = httpmw.UserParam(r) ) + var req codersdk.CreateWorkspaceRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + var owner workspaceOwner + // This user fetch is an optimization path for the most common case of creating a + // workspace for 'Me'. + // + // This is also required to allow `owners` to create workspaces for users + // that are not in an organization. + user, ok := httpmw.UserParamOptional(r) + if ok { + owner = workspaceOwner{ + ID: user.ID, + Username: user.Username, + AvatarURL: user.AvatarURL, + } + } else { + // A workspace can still be created if the caller can read the organization + // member. The organization is required, which can be sourced from the + // template. + // + // TODO: This code gets called twice for each workspace build request. + // This is inefficient and costs at most 2 extra RTTs to the DB. + // This can be optimized. It exists as it is now for code simplicity. + // The most common case is to create a workspace for 'Me'. Which does + // not enter this code branch. + template, ok := requestTemplate(ctx, rw, req, api.Database) + if !ok { + return + } + + // We need to fetch the original user as a system user to fetch the + // user_id. 'ExtractUserContext' handles all cases like usernames, + // 'Me', etc. + // nolint:gocritic // The user_id needs to be fetched. This handles all those cases. + user, ok := httpmw.ExtractUserContext(dbauthz.AsSystemRestricted(ctx), api.Database, rw, r) + if !ok { + return + } + + organizationMember, err := database.ExpectOne(api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: template.OrganizationID, + UserID: user.ID, + IncludeSystem: false, + })) + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching organization member.", + Detail: err.Error(), + }) + return + } + owner = workspaceOwner{ + ID: organizationMember.OrganizationMember.UserID, + Username: organizationMember.Username, + AvatarURL: organizationMember.AvatarURL, + } + } + aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{ Audit: *auditor, Log: api.Logger, Request: r, Action: database.AuditActionCreate, AdditionalFields: audit.AdditionalFields{ - WorkspaceOwner: user.Username, + WorkspaceOwner: owner.Username, }, }) defer commitAudit() - - var req codersdk.CreateWorkspaceRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - owner := workspaceOwner{ - ID: user.ID, - Username: user.Username, - AvatarURL: user.AvatarURL, - } createWorkspace(ctx, aReq, apiKey.UserID, api, owner, req, rw, r) } @@ -450,65 +503,8 @@ func createWorkspace( rw http.ResponseWriter, r *http.Request, ) { - // If we were given a `TemplateVersionID`, we need to determine the `TemplateID` from it. - templateID := req.TemplateID - if templateID == uuid.Nil { - templateVersion, err := api.Database.GetTemplateVersionByID(ctx, req.TemplateVersionID) - if httpapi.Is404Error(err) { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf("Template version %q doesn't exist.", templateID.String()), - Validations: []codersdk.ValidationError{{ - Field: "template_version_id", - Detail: "template not found", - }}, - }) - return - } - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching template version.", - Detail: err.Error(), - }) - return - } - if templateVersion.Archived { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Archived template versions cannot be used to make a workspace.", - Validations: []codersdk.ValidationError{ - { - Field: "template_version_id", - Detail: "template version archived", - }, - }, - }) - return - } - - templateID = templateVersion.TemplateID.UUID - } - - template, err := api.Database.GetTemplateByID(ctx, templateID) - if httpapi.Is404Error(err) { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf("Template %q doesn't exist.", templateID.String()), - Validations: []codersdk.ValidationError{{ - Field: "template_id", - Detail: "template not found", - }}, - }) - return - } - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching template.", - Detail: err.Error(), - }) - return - } - if template.Deleted { - httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ - Message: fmt.Sprintf("Template %q has been deleted!", template.Name), - }) + template, ok := requestTemplate(ctx, rw, req, api.Database) + if !ok { return } @@ -776,6 +772,72 @@ func createWorkspace( httpapi.Write(ctx, rw, http.StatusCreated, w) } +func requestTemplate(ctx context.Context, rw http.ResponseWriter, req codersdk.CreateWorkspaceRequest, db database.Store) (database.Template, bool) { + // If we were given a `TemplateVersionID`, we need to determine the `TemplateID` from it. + templateID := req.TemplateID + + if templateID == uuid.Nil { + templateVersion, err := db.GetTemplateVersionByID(ctx, req.TemplateVersionID) + if httpapi.Is404Error(err) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Template version %q doesn't exist.", req.TemplateVersionID), + Validations: []codersdk.ValidationError{{ + Field: "template_version_id", + Detail: "template not found", + }}, + }) + return database.Template{}, false + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching template version.", + Detail: err.Error(), + }) + return database.Template{}, false + } + if templateVersion.Archived { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Archived template versions cannot be used to make a workspace.", + Validations: []codersdk.ValidationError{ + { + Field: "template_version_id", + Detail: "template version archived", + }, + }, + }) + return database.Template{}, false + } + + templateID = templateVersion.TemplateID.UUID + } + + template, err := db.GetTemplateByID(ctx, templateID) + if httpapi.Is404Error(err) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Template %q doesn't exist.", templateID), + Validations: []codersdk.ValidationError{{ + Field: "template_id", + Detail: "template not found", + }}, + }) + return database.Template{}, false + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching template.", + Detail: err.Error(), + }) + return database.Template{}, false + } + if template.Deleted { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: fmt.Sprintf("Template %q has been deleted!", template.Name), + }) + return database.Template{}, false + } + return template, true +} + func (api *API) notifyWorkspaceCreated( ctx context.Context, receiverID uuid.UUID, diff --git a/codersdk/database.go b/codersdk/database.go new file mode 100644 index 0000000000000..1a33da6362e0d --- /dev/null +++ b/codersdk/database.go @@ -0,0 +1,7 @@ +package codersdk + +import "golang.org/x/xerrors" + +const DatabaseNotReachable = "database not reachable" + +var ErrDatabaseNotReachable = xerrors.New(DatabaseNotReachable) diff --git a/codersdk/workspacesdk/dialer.go b/codersdk/workspacesdk/dialer.go index 23d618761b807..71cac0c5f04b1 100644 --- a/codersdk/workspacesdk/dialer.go +++ b/codersdk/workspacesdk/dialer.go @@ -11,17 +11,19 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/websocket" + "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/websocket" ) var permanentErrorStatuses = []int{ - http.StatusConflict, // returned if client/agent connections disabled (browser only) - http.StatusBadRequest, // returned if API mismatch - http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist + http.StatusConflict, // returned if client/agent connections disabled (browser only) + http.StatusBadRequest, // returned if API mismatch + http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist + http.StatusInternalServerError, // returned if database is not reachable, } type WebsocketDialer struct { @@ -89,6 +91,11 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl "Ensure your client release version (%s, different than the API version) matches the server release version", buildinfo.Version()) } + + if sdkErr.Message == codersdk.DatabaseNotReachable && + sdkErr.StatusCode() == http.StatusInternalServerError { + err = xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err) + } } w.connected <- err return tailnet.ControlProtocolClients{}, err diff --git a/codersdk/workspacesdk/workspacesdk_test.go b/codersdk/workspacesdk/workspacesdk_test.go index 317db4471319f..e7ccd96e208fa 100644 --- a/codersdk/workspacesdk/workspacesdk_test.go +++ b/codersdk/workspacesdk/workspacesdk_test.go @@ -1,13 +1,21 @@ package workspacesdk_test import ( + "net/http" + "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" + "github.com/coder/websocket" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" ) func TestWorkspaceRewriteDERPMap(t *testing.T) { @@ -37,3 +45,30 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) { require.Equal(t, "coconuts.org", node.HostName) require.Equal(t, 44558, node.DERPPort) } + +func TestWorkspaceDialerFailure(t *testing.T) { + t.Parallel() + + // Setup. + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + // Given: a mock HTTP server which mimicks an unreachable database when calling the coordination endpoint. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{ + Message: codersdk.DatabaseNotReachable, + Detail: "oops", + }) + })) + t.Cleanup(srv.Close) + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + + // When: calling the coordination endpoint. + dialer := workspacesdk.NewWebsocketDialer(logger, u, &websocket.DialOptions{}) + _, err = dialer.Dial(ctx, nil) + + // Then: an error indicating a database issue is returned, to conditionalize the behavior of the caller. + require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable) +} diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 5b0f0ca197743..6ffa15851214d 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -24,6 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" @@ -376,6 +377,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) logger.Debug(ctx, "drpc server error", slog.Error(err)) }, }) + + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + err = server.Serve(ctx, session) srvCancel() logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err)) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index eedd6f1bcfa1c..72859c5460fa7 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" agplschedule "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/util/ptr" @@ -245,7 +246,131 @@ func TestCreateWorkspace(t *testing.T) { func TestCreateUserWorkspace(t *testing.T) { t.Parallel() + // Create a custom role that can create workspaces for another user. + t.Run("ForAnotherUser", func(t *testing.T) { + t.Parallel() + + owner, first := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureCustomRoles: 1, + codersdk.FeatureTemplateRBAC: 1, + }, + }, + }) + ctx := testutil.Context(t, testutil.WaitShort) + //nolint:gocritic // using owner to setup roles + r, err := owner.CreateOrganizationRole(ctx, codersdk.Role{ + Name: "creator", + OrganizationID: first.OrganizationID.String(), + DisplayName: "Creator", + OrganizationPermissions: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ + codersdk.ResourceWorkspace: {codersdk.ActionCreate, codersdk.ActionWorkspaceStart, codersdk.ActionUpdate, codersdk.ActionRead}, + codersdk.ResourceOrganizationMember: {codersdk.ActionRead}, + }), + }) + require.NoError(t, err) + + // use admin for setting up test + admin, adminID := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, rbac.RoleTemplateAdmin()) + + // try the test action with this user & custom role + creator, _ := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, rbac.RoleMember(), rbac.RoleIdentifier{ + Name: r.Name, + OrganizationID: first.OrganizationID, + }) + + version := coderdtest.CreateTemplateVersion(t, admin, first.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, admin, version.ID) + template := coderdtest.CreateTemplate(t, admin, first.OrganizationID, version.ID) + + ctx = testutil.Context(t, testutil.WaitLong*1000) // Reset the context to avoid timeouts. + + _, err = creator.CreateUserWorkspace(ctx, adminID.ID.String(), codersdk.CreateWorkspaceRequest{ + TemplateID: template.ID, + Name: "workspace", + }) + require.NoError(t, err) + }) + + // Asserting some authz calls when creating a workspace. + t.Run("AuthzStory", func(t *testing.T) { + t.Parallel() + owner, _, api, first := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureCustomRoles: 1, + codersdk.FeatureTemplateRBAC: 1, + }, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong*2000) + defer cancel() + + //nolint:gocritic // using owner to setup roles + creatorRole, err := owner.CreateOrganizationRole(ctx, codersdk.Role{ + Name: "creator", + OrganizationID: first.OrganizationID.String(), + OrganizationPermissions: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{ + codersdk.ResourceWorkspace: {codersdk.ActionCreate, codersdk.ActionWorkspaceStart, codersdk.ActionUpdate, codersdk.ActionRead}, + codersdk.ResourceOrganizationMember: {codersdk.ActionRead}, + }), + }) + require.NoError(t, err) + + version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID) + template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID) + _, userID := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID) + creator, _ := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, rbac.RoleIdentifier{ + Name: creatorRole.Name, + OrganizationID: first.OrganizationID, + }) + + // Create a workspace with the current api using an org admin. + authz := coderdtest.AssertRBAC(t, api.AGPL, creator) + authz.Reset() // Reset all previous checks done in setup. + _, err = creator.CreateUserWorkspace(ctx, userID.ID.String(), codersdk.CreateWorkspaceRequest{ + TemplateID: template.ID, + Name: "test-user", + }) + require.NoError(t, err) + + // Assert all authz properties + t.Run("OnlyOrganizationAuthzCalls", func(t *testing.T) { + // Creating workspaces is an organization action. So organization + // permissions should be sufficient to complete the action. + for _, call := range authz.AllCalls() { + if call.Action == policy.ActionRead && + call.Object.Equal(rbac.ResourceUser.WithOwner(userID.ID.String()).WithID(userID.ID)) { + // User read checks are called. If they fail, ignore them. + if call.Err != nil { + continue + } + } + + if call.Object.Type == rbac.ResourceDeploymentConfig.Type { + continue // Ignore + } + + assert.Falsef(t, call.Object.OrgID == "", + "call %q for object %q has no organization set. Site authz calls not expected here", + call.Action, call.Object.String(), + ) + } + }) + }) + t.Run("NoTemplateAccess", func(t *testing.T) { + // NoTemplateAccess intentionally does not use provisioners. The template + // version will be stuck in 'pending' forever. t.Parallel() client, first := coderdenttest.New(t, &coderdenttest.Options{ diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index 9108283513e4f..0b434c767f53a 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" @@ -336,7 +337,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { tracing.Middleware(s.TracerProvider), httpmw.AttachRequestID, httpmw.ExtractRealIP(s.Options.RealIPConfig), - httpmw.Logger(s.Logger), + loggermw.Logger(s.Logger), prometheusMW, corsMW, diff --git a/examples/templates/docker-devcontainer/main.tf b/examples/templates/docker-devcontainer/main.tf index d0f328ea46f38..52877214caa7c 100644 --- a/examples/templates/docker-devcontainer/main.tf +++ b/examples/templates/docker-devcontainer/main.tf @@ -2,7 +2,7 @@ terraform { required_providers { coder = { source = "coder/coder" - version = "~> 1.0.0" + version = "~> 2.0" } docker = { source = "kreuzwerker/docker" @@ -340,11 +340,11 @@ module "jetbrains_gateway" { source = "registry.coder.com/modules/jetbrains-gateway/coder" # JetBrains IDEs to make available for the user to select - jetbrains_ides = ["IU", "PY", "WS", "PS", "RD", "CL", "GO", "RM"] + jetbrains_ides = ["IU", "PS", "WS", "PY", "CL", "GO", "RM", "RD", "RR"] default = "IU" # Default folder to open when starting a JetBrains IDE - folder = "/home/coder" + folder = "/workspaces" # This ensures that the latest version of the module gets downloaded, you can also pin the module version to prevent breaking changes in production. version = ">= 1.0.0" diff --git a/examples/templates/kubernetes-devcontainer/main.tf b/examples/templates/kubernetes-devcontainer/main.tf index c9a86f08df6d2..69e53565d3c78 100644 --- a/examples/templates/kubernetes-devcontainer/main.tf +++ b/examples/templates/kubernetes-devcontainer/main.tf @@ -2,7 +2,7 @@ terraform { required_providers { coder = { source = "coder/coder" - version = "~> 1.0.0" + version = "~> 2.0" } kubernetes = { source = "hashicorp/kubernetes" diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index b461bc593ee36..9adf9951fa488 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -20,12 +20,13 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/retry" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" - "github.com/coder/retry" ) // Dialer represents the function to create a daemon client connection. @@ -290,7 +291,7 @@ func (p *Server) acquireLoop() { defer p.wg.Done() defer func() { close(p.acquireDoneCh) }() ctx := p.closeContext - for { + for retrier := retry.New(10*time.Millisecond, 1*time.Second); retrier.Wait(ctx); { if p.acquireExit() { return } @@ -299,7 +300,17 @@ func (p *Server) acquireLoop() { p.opts.Logger.Debug(ctx, "shut down before client (re) connected") return } - p.acquireAndRunOne(client) + err := p.acquireAndRunOne(client) + if err != nil && ctx.Err() == nil { // Only log if context is not done. + // Short-circuit: don't wait for the retry delay to exit, if required. + if p.acquireExit() { + return + } + p.opts.Logger.Warn(ctx, "failed to acquire job, retrying", slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds())), slog.Error(err)) + } else { + // Reset the retrier after each successful acquisition. + retrier.Reset() + } } } @@ -318,7 +329,7 @@ func (p *Server) acquireExit() bool { return false } -func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { +func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) error { ctx := p.closeContext p.opts.Logger.Debug(ctx, "start of acquireAndRunOne") job, err := p.acquireGraceful(client) @@ -327,15 +338,15 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { if errors.Is(err, context.Canceled) || errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, fasthttputil.ErrInmemoryListenerClosed) { - return + return err } p.opts.Logger.Warn(ctx, "provisionerd was unable to acquire job", slog.Error(err)) - return + return xerrors.Errorf("failed to acquire job: %w", err) } if job.JobId == "" { p.opts.Logger.Debug(ctx, "acquire job successfully canceled") - return + return nil } if len(job.TraceMetadata) > 0 { @@ -390,9 +401,9 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error), }) if err != nil { - p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err)) + p.opts.Logger.Error(ctx, "failed to report provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err)) } - return + return xerrors.Errorf("failed to report provisioner job failed: %w", err) } p.mutex.Lock() @@ -416,6 +427,7 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { p.mutex.Lock() p.activeJob = nil p.mutex.Unlock() + return nil } // acquireGraceful attempts to acquire a job from the server, handling canceling the acquisition if we gracefully shut diff --git a/scripts/release/check_commit_metadata.sh b/scripts/release/check_commit_metadata.sh index f53de8e107430..1368425d00639 100755 --- a/scripts/release/check_commit_metadata.sh +++ b/scripts/release/check_commit_metadata.sh @@ -118,6 +118,23 @@ main() { title2=${parts2[*]:2} fi + # Handle cherry-pick bot, it turns "chore: foo bar (#42)" to + # "chore: foo bar (cherry-pick #42) (#43)". + if [[ ${title1} == *"(cherry-pick #"* ]]; then + title1=${title1%" ("*} + pr=${title1##*#} + pr=${pr%)} + title1=${title1%" ("*} + title1="${title1} (#${pr})"$'\n' + fi + if [[ ${title2} == *"(cherry-pick #"* ]]; then + title2=${title2%" ("*} + pr=${title2##*#} + pr=${pr%)} + title2=${title2%" ("*} + title2="${title2} (#${pr})"$'\n' + fi + if [[ ${title1} != "${title2}" ]]; then log "Invariant failed, cherry-picked commits have different titles: \"${title1%$'\n'}\" != \"${title2%$'\n'}\", attempting to check commit body for cherry-pick information..." diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index ab8e58d4574f4..7d5d717ff6f92 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -589,6 +589,9 @@ export interface DangerousConfig { readonly allow_all_cors: boolean; } +// From codersdk/database.go +export const DatabaseNotReachable = "database not reachable"; + // From healthsdk/healthsdk.go export interface DatabaseReport extends BaseReport { readonly healthy: boolean; diff --git a/site/src/pages/TemplatePage/TemplateVersionsPage/VersionRow.tsx b/site/src/pages/TemplatePage/TemplateVersionsPage/VersionRow.tsx index bd8e7e846a011..e41ac97ec6217 100644 --- a/site/src/pages/TemplatePage/TemplateVersionsPage/VersionRow.tsx +++ b/site/src/pages/TemplatePage/TemplateVersionsPage/VersionRow.tsx @@ -33,7 +33,6 @@ export const VersionRow: FC = ({ }); const jobStatus = version.job.status; - const showActions = onPromoteClick || onArchiveClick; return ( = ({ )} - {showActions && jobStatus === "failed" ? ( + {jobStatus === "failed" && onArchiveClick && ( - ) : ( + )} + + {jobStatus === "succeeded" && onPromoteClick && (