Skip to content

fix: race panic in test/go/postgres #1752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import (
"path/filepath"
"time"

"github.com/coder/coder/provisioner/echo"

"github.com/briandowns/spinner"
"github.com/coreos/go-systemd/daemon"
"github.com/google/go-github/v43/github"
Expand Down Expand Up @@ -50,6 +48,7 @@ import (
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisioner/terraform"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionersdk"
Expand Down Expand Up @@ -229,18 +228,17 @@ func server() *cobra.Command {
URLs: []string{stunServer},
})
}
options := &coderd.Options{
options := (&coderd.Options{
AccessURL: accessURLParsed,
ICEServers: iceServers,
Logger: logger.Named("coderd"),
Database: databasefake.New(),
Pubsub: database.NewPubsubInMemory(),
GoogleTokenValidator: validator,
SecureAuthCookie: secureAuthCookie,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TURNServer: turnServer,
TracerProvider: tracerProvider,
}
}).SetLogger(logger.Named("coderd"))

if oauth2GithubClientSecret != "" {
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations)
Expand Down
4 changes: 2 additions & 2 deletions coderd/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ func (api *api) Authorize(rw http.ResponseWriter, r *http.Request, action rbac.A

// Log the errors for debugging
internalError := new(rbac.UnauthorizedError)
logger := api.Logger
logger := api.Logger()
if xerrors.As(err, internalError) {
logger = api.Logger.With(slog.F("internal", internalError.Internal()))
logger = api.Logger().With(slog.F("internal", internalError.Internal()))
}
// Log information for debugging. This will be very helpful
// in the early days
Expand Down
25 changes: 19 additions & 6 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ import (
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/pion/webrtc/v3"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"

sdktrace "go.opentelemetry.io/otel/sdk/trace"

"cdr.dev/slog"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/awsidentity"
Expand All @@ -35,7 +35,7 @@ import (
// Options are requires parameters for Coder to start.
type Options struct {
AccessURL *url.URL
Logger slog.Logger
logger atomic.Value
Database database.Store
Pubsub database.Pubsub

Expand All @@ -57,6 +57,7 @@ type Options struct {
}

type CoderD interface {
SetLogger(logger slog.Logger)
Handler() http.Handler
CloseWait()

Expand Down Expand Up @@ -116,7 +117,7 @@ func newRouter(options *Options, a *api) chi.Router {
r.Use(
// Specific routes can specify smaller limits.
httpmw.RateLimitPerMinute(options.APIRateLimit),
debugLogRequest(a.Logger),
debugLogRequest(a.Logger()),
)
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(w, http.StatusOK, httpapi.Response{
Expand Down Expand Up @@ -337,8 +338,6 @@ func newRouter(options *Options, a *api) chi.Router {
})
})

var _ = xerrors.New("test")

r.NotFound(site.DefaultHandler().ServeHTTP)
return r
}
Expand All @@ -362,6 +361,20 @@ func (c *coderD) Handler() http.Handler {
return c.router
}

func (o *Options) Logger() slog.Logger {
logger, _ := o.logger.Load().(slog.Logger)
return logger
}

func (c *coderD) SetLogger(logger slog.Logger) {
c.options.SetLogger(logger)
}

func (o *Options) SetLogger(logger slog.Logger) *Options {
o.logger.Store(logger)
return o
}

// API contains all route handlers. Only HTTP handlers should
// be added to this struct for code clarity.
type api struct {
Expand Down
13 changes: 8 additions & 5 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"testing"
"time"

"github.com/coder/coder/coderd/rbac"

"cloud.google.com/go/compute/metadata"
"github.com/fullsailor/pkcs7"
"github.com/golang-jwt/jwt"
Expand All @@ -46,6 +44,7 @@ import (
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
Expand Down Expand Up @@ -82,13 +81,15 @@ func NewWithServer(t *testing.T, options *Options) (*httptest.Server, *codersdk.
if options == nil {
options = &Options{}
}

if options.GoogleTokenValidator == nil {
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(cancelFunc)
var err error
options.GoogleTokenValidator, err = idtoken.NewValidator(ctx, option.WithoutAuthentication())
require.NoError(t, err)
}

if options.AutobuildTicker == nil {
ticker := make(chan time.Time)
options.AutobuildTicker = ticker
Expand Down Expand Up @@ -144,10 +145,9 @@ func NewWithServer(t *testing.T, options *Options) (*httptest.Server, *codersdk.
require.NoError(t, err)

// We set the handler after server creation for the access URL.
coderDaemon := coderd.New(&coderd.Options{
coderDaemon := coderd.New((&coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
AccessURL: serverURL,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Database: db,
Pubsub: pubsub,

Expand All @@ -159,16 +159,19 @@ func NewWithServer(t *testing.T, options *Options) (*httptest.Server, *codersdk.
TURNServer: turnServer,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
})
}).SetLogger(slogtest.Make(t, nil).Leveled(slog.LevelDebug)))
srv.Config.Handler = coderDaemon.Handler()
if options.IncludeProvisionerD {
// This is automatically closed.
_ = NewProvisionerDaemon(t, coderDaemon)
}

t.Cleanup(func() {
cancelFunc()
_ = turnServer.Close()
srv.Close()
coderDaemon.CloseWait()
coderDaemon.SetLogger(slog.Logger{})
})

return srv, codersdk.New(serverURL), coderDaemon
Expand Down
4 changes: 4 additions & 0 deletions coderd/database/databasefake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ type fakeQuerier struct {
workspaces []database.Workspace
}

func (*fakeQuerier) Close() error {
return nil
}

// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
return fn(q)
Expand Down
5 changes: 5 additions & 0 deletions coderd/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Store interface {
querier

InTx(func(Store) error) error
Close() error
}

// DBTX represents a database connection or transaction.
Expand All @@ -45,6 +46,10 @@ type sqlQuerier struct {
db DBTX
}

func (q *sqlQuerier) Close() error {
return q.sdb.Close()
}

// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(function func(Store) error) error {
if q.sdb == nil {
Expand Down
9 changes: 4 additions & 5 deletions coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"storj.io/drpc/drpcserver"

"cdr.dev/slog"

"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/parameter"
Expand Down Expand Up @@ -75,23 +74,23 @@ func (c *coderD) ListenProvisionerDaemon(ctx context.Context) (client proto.DRPC
Database: c.options.Database,
Pubsub: c.options.Pubsub,
Provisioners: daemon.Provisioners,
Logger: c.options.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
Logger: c.options.Logger().Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
})
if err != nil {
return nil, err
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || ctx.Err() != nil {
return
}
c.options.Logger.Debug(ctx, "drpc server error", slog.Error(err))
c.options.Logger().Debug(ctx, "drpc server error", slog.Error(err))
},
})
go func() {
err = server.Serve(ctx, serverSession)
if err != nil && !xerrors.Is(err, io.EOF) {
c.options.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
c.options.Logger().Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
}
// close the sessions so we don't leak goroutines serving them.
_ = clientSession.Close()
Expand Down
6 changes: 3 additions & 3 deletions coderd/provisionerjobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (api *api) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
var logs []database.ProvisionerJobLog
err := json.Unmarshal(message, &logs)
if err != nil {
api.Logger.Warn(r.Context(), fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
api.Logger().Warn(r.Context(), fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
return
}

Expand All @@ -106,7 +106,7 @@ func (api *api) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
default:
// If this overflows users could miss logs streaming. This can happen
// if a database request takes a long amount of time, and we get a lot of logs.
api.Logger.Warn(r.Context(), "provisioner job log overflowing channel")
api.Logger().Warn(r.Context(), "provisioner job log overflowing channel")
}
}
})
Expand Down Expand Up @@ -168,7 +168,7 @@ func (api *api) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
case <-ticker.C:
job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
if err != nil {
api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
api.Logger().Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
continue
}
if job.CompletedAt.Valid {
Expand Down
12 changes: 6 additions & 6 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
}
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Logger: api.Logger().Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
if err != nil {
Expand Down Expand Up @@ -173,7 +173,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
Logger: api.Logger().Named("peerbroker-proxy-listen"),
})
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
Expand Down Expand Up @@ -241,7 +241,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}

api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
api.Logger().Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))

ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
Expand Down Expand Up @@ -314,12 +314,12 @@ func (api *api) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
_ = wsConn.Close(websocket.StatusNormalClosure, "")
}()
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
api.Logger().Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
select {
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
case <-r.Context().Done():
}
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
api.Logger().Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
}

// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
Expand Down Expand Up @@ -400,7 +400,7 @@ func (api *api) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
go func() {
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
ChannelID: agentID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Logger: api.Logger().Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
_ = client.Close()
Expand Down
2 changes: 1 addition & 1 deletion coderd/workspaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ func (api *api) watchWorkspace(rw http.ResponseWriter, r *http.Request) {
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
api.Logger.Warn(r.Context(), "accept websocket connection", slog.Error(err))
api.Logger().Warn(r.Context(), "accept websocket connection", slog.Error(err))
return
}
defer c.Close(websocket.StatusInternalError, "internal error")
Expand Down
2 changes: 2 additions & 0 deletions coderd/workspaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,15 @@ func TestPostWorkspacesByOrganization(t *testing.T) {

func TestWorkspacesByOrganization(t *testing.T) {
t.Parallel()

t.Run("ListEmpty", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_, err := client.WorkspacesByOrganization(context.Background(), user.OrganizationID)
require.NoError(t, err)
})

t.Run("List", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
Expand Down