diff --git a/cli/server.go b/cli/server.go index d7dea720978e9..c47cf8271de9e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -68,6 +68,7 @@ import ( "github.com/coder/coder/coderd/database/dbmetrics" "github.com/coder/coder/coderd/database/dbpurge" "github.com/coder/coder/coderd/database/migrations" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/devtunnel" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" @@ -463,7 +464,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. Logger: logger.Named("coderd"), Database: dbfake.New(), DERPMap: derpMap, - Pubsub: database.NewPubsubInMemory(), + Pubsub: pubsub.NewInMemory(), CacheDir: cacheDir, GoogleTokenValidator: googleTokenValidator, GitAuthConfigs: gitAuthConfigs, @@ -589,7 +590,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. if cfg.InMemoryDatabase { // This is only used for testing. options.Database = dbmetrics.New(dbfake.New(), options.PrometheusRegistry) - options.Pubsub = database.NewPubsubInMemory() + options.Pubsub = pubsub.NewInMemory() } else { sqlDB, err := connectToPostgres(ctx, logger, sqlDriver, cfg.PostgresURL.String()) if err != nil { @@ -600,7 +601,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. }() options.Database = dbmetrics.New(database.New(sqlDB), options.PrometheusRegistry) - options.Pubsub, err = database.NewPubsub(ctx, sqlDB, cfg.PostgresURL.String()) + options.Pubsub, err = pubsub.New(ctx, sqlDB, cfg.PostgresURL.String()) if err != nil { return xerrors.Errorf("create pubsub: %w", err) } diff --git a/coderd/coderd.go b/coderd/coderd.go index 64cc157c723d4..82a7d36e80551 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -48,6 +48,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/database/dbmetrics" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/healthcheck" @@ -95,7 +96,7 @@ type Options struct { AppHostnameRegex *regexp.Regexp Logger slog.Logger Database database.Store - Pubsub database.Pubsub + Pubsub pubsub.Pubsub // CacheDir is used for caching files served by the API. CacheDir string diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index b6623fbd6f942..56984cab13d88 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -59,6 +59,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/healthcheck" @@ -130,7 +131,7 @@ type Options struct { // It should only be used in cases where multiple Coder // test instances are running against the same database. Database database.Store - Pubsub database.Pubsub + Pubsub pubsub.Pubsub ConfigSSH codersdk.SSHConfigResponse diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 7726a7174861c..932e4aaf4739a 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -11,13 +11,14 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/database/postgres" + "github.com/coder/coder/coderd/database/pubsub" ) -func NewDB(t testing.TB) (database.Store, database.Pubsub) { +func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) { t.Helper() db := dbfake.New() - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() if os.Getenv("DB") != "" { connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") if connectionURL == "" { @@ -36,12 +37,12 @@ func NewDB(t testing.TB) (database.Store, database.Pubsub) { }) db = database.New(sqlDB) - pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL) + ps, err = pubsub.New(context.Background(), sqlDB, connectionURL) require.NoError(t, err) t.Cleanup(func() { - _ = pubsub.Close() + _ = ps.Close() }) } - return db, pubsub + return db, ps } diff --git a/coderd/database/pubsub.go b/coderd/database/pubsub/pubsub.go similarity index 92% rename from coderd/database/pubsub.go rename to coderd/database/pubsub/pubsub.go index 6a6d1f2f07751..f661e885c2848 100644 --- a/coderd/database/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -1,4 +1,4 @@ -package database +package pubsub import ( "context" @@ -48,7 +48,7 @@ type msgOrErr struct { type msgQueue struct { ctx context.Context cond *sync.Cond - q [PubsubBufferSize]msgOrErr + q [BufferSize]msgOrErr front int size int closed bool @@ -82,7 +82,7 @@ func (q *msgQueue) run() { return } item := q.q[q.front] - q.front = (q.front + 1) % PubsubBufferSize + q.front = (q.front + 1) % BufferSize q.size-- q.cond.L.Unlock() @@ -111,20 +111,20 @@ func (q *msgQueue) enqueue(msg []byte) { q.cond.L.Lock() defer q.cond.L.Unlock() - if q.size == PubsubBufferSize { + if q.size == BufferSize { // queue is full, so we're going to drop the msg we got called with. // We also need to record that messages are being dropped, which we // do at the last message in the queue. This potentially makes us // lose 2 messages instead of one, but it's more important at this // point to warn the subscriber that they're losing messages so they // can do something about it. - back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize + back := (q.front + BufferSize - 1) % BufferSize q.q[back].msg = nil q.q[back].err = ErrDroppedMessages return } // queue is not full, insert the message - next := (q.front + q.size) % PubsubBufferSize + next := (q.front + q.size) % BufferSize q.q[next].msg = msg q.q[next].err = nil q.size++ @@ -143,17 +143,17 @@ func (q *msgQueue) dropped() { q.cond.L.Lock() defer q.cond.L.Unlock() - if q.size == PubsubBufferSize { + if q.size == BufferSize { // queue is full, but we need to record that messages are being dropped, // which we do at the last message in the queue. This potentially drops // another message, but it's more important for the subscriber to know. - back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize + back := (q.front + BufferSize - 1) % BufferSize q.q[back].msg = nil q.q[back].err = ErrDroppedMessages return } // queue is not full, insert the error - next := (q.front + q.size) % PubsubBufferSize + next := (q.front + q.size) % BufferSize q.q[next].msg = nil q.q[next].err = ErrDroppedMessages q.size++ @@ -171,9 +171,9 @@ type pgPubsub struct { queues map[string]map[uuid.UUID]*msgQueue } -// PubsubBufferSize is the maximum number of unhandled messages we will buffer +// BufferSize is the maximum number of unhandled messages we will buffer // for a subscriber before dropping messages. -const PubsubBufferSize = 2048 +const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { @@ -295,8 +295,8 @@ func (p *pgPubsub) recordReconnect() { } } -// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection. -func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) { +// New creates a new Pubsub implementation using a PostgreSQL connection. +func New(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) { // Creates a new listener using pq. errCh := make(chan error) listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) { diff --git a/coderd/database/pubsub_internal_test.go b/coderd/database/pubsub/pubsub_internal_test.go similarity index 94% rename from coderd/database/pubsub_internal_test.go rename to coderd/database/pubsub/pubsub_internal_test.go index 31c50ce172176..adfa70286dbe0 100644 --- a/coderd/database/pubsub_internal_test.go +++ b/coderd/database/pubsub/pubsub_internal_test.go @@ -1,4 +1,4 @@ -package database +package pubsub import ( "context" @@ -26,7 +26,7 @@ func Test_msgQueue_ListenerWithError(t *testing.T) { // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned // when we wrap around the end of the circular buffer. This tests that we correctly handle // the wrapping and aren't dequeueing misaligned data. - cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring + cycles := (BufferSize / 5) * 2 // almost twice around the ring for j := 0; j < cycles; j++ { for i := 0; i < 4; i++ { uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) @@ -75,7 +75,7 @@ func Test_msgQueue_Listener(t *testing.T) { // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned // when we wrap around the end of the circular buffer. This tests that we correctly handle // the wrapping and aren't dequeueing misaligned data. - cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring + cycles := (BufferSize / 5) * 2 // almost twice around the ring for j := 0; j < cycles; j++ { for i := 0; i < 4; i++ { uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) @@ -119,7 +119,7 @@ func Test_msgQueue_Full(t *testing.T) { // we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks // but only after we've dequeued a message, and then another extra because we want to exceed // the capacity, not just reach it. - for i := 0; i < PubsubBufferSize+2; i++ { + for i := 0; i < BufferSize+2; i++ { uut.enqueue([]byte(fmt.Sprintf("%d", i))) // ensure the first dequeue has happened before proceeding, so that this function isn't racing // against the goroutine that dequeues items. @@ -136,5 +136,5 @@ func Test_msgQueue_Full(t *testing.T) { // Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last // message we send doesn't get queued, AND, it bumps a message out of the queue to make room // for the error, so we read 2 less than we sent. - require.Equal(t, PubsubBufferSize, n) + require.Equal(t, BufferSize, n) } diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub/pubsub_memory.go similarity index 97% rename from coderd/database/pubsub_memory.go rename to coderd/database/pubsub/pubsub_memory.go index 0ab4684c80a3f..ec4c26a4f01e0 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub/pubsub_memory.go @@ -1,4 +1,4 @@ -package database +package pubsub import ( "context" @@ -87,7 +87,7 @@ func (*memoryPubsub) Close() error { return nil } -func NewPubsubInMemory() Pubsub { +func NewInMemory() Pubsub { return &memoryPubsub{ listeners: make(map[string]map[uuid.UUID]genericListener), } diff --git a/coderd/database/pubsub_memory_test.go b/coderd/database/pubsub/pubsub_memory_test.go similarity index 89% rename from coderd/database/pubsub_memory_test.go rename to coderd/database/pubsub/pubsub_memory_test.go index 7856880d856c2..80553c8fa73da 100644 --- a/coderd/database/pubsub_memory_test.go +++ b/coderd/database/pubsub/pubsub_memory_test.go @@ -1,4 +1,4 @@ -package database_test +package pubsub_test import ( "context" @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/pubsub" ) func TestPubsubMemory(t *testing.T) { @@ -16,7 +16,7 @@ func TestPubsubMemory(t *testing.T) { t.Run("Legacy", func(t *testing.T) { t.Parallel() - pubsub := database.NewPubsubInMemory() + pubsub := pubsub.NewInMemory() event := "test" data := "testing" messageChannel := make(chan []byte) @@ -36,7 +36,7 @@ func TestPubsubMemory(t *testing.T) { t.Run("WithErr", func(t *testing.T) { t.Parallel() - pubsub := database.NewPubsubInMemory() + pubsub := pubsub.NewInMemory() event := "test" data := "testing" messageChannel := make(chan []byte) diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go similarity index 86% rename from coderd/database/pubsub_test.go rename to coderd/database/pubsub/pubsub_test.go index 60fb1821af55d..d1f80fa5a1aed 100644 --- a/coderd/database/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -1,6 +1,6 @@ //go:build linux -package database_test +package pubsub_test import ( "context" @@ -15,8 +15,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/postgres" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/testutil" ) @@ -39,7 +39,7 @@ func TestPubsub(t *testing.T) { db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + pubsub, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) defer pubsub.Close() event := "test" @@ -67,7 +67,7 @@ func TestPubsub(t *testing.T) { db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + pubsub, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) defer pubsub.Close() cancelFunc() @@ -82,7 +82,7 @@ func TestPubsub(t *testing.T) { db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + pubsub, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) defer pubsub.Close() @@ -114,7 +114,7 @@ func TestPubsub(t *testing.T) { db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + pubsub, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) defer pubsub.Close() @@ -171,12 +171,12 @@ func TestPubsub_ordering(t *testing.T) { db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + ps, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) - defer pubsub.Close() + defer ps.Close() event := "test" messageChannel := make(chan []byte, 100) - cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + cancelSub, err := ps.Subscribe(event, func(ctx context.Context, message []byte) { // sleep a random amount of time to simulate handlers taking different amount of time // to process, depending on the message // nolint: gosec @@ -187,7 +187,7 @@ func TestPubsub_ordering(t *testing.T) { require.NoError(t, err) defer cancelSub() for i := 0; i < 100; i++ { - err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i))) + err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) assert.NoError(t, err) } for i := 0; i < 100; i++ { @@ -219,14 +219,14 @@ func TestPubsub_Disconnect(t *testing.T) { ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancelFunc() - pubsub, err := database.NewPubsub(ctx, db, connectionURL) + ps, err := pubsub.New(ctx, db, connectionURL) require.NoError(t, err) - defer pubsub.Close() + defer ps.Close() event := "test" // buffer responses so that when the test completes, goroutines don't get blocked & leak - errors := make(chan error, database.PubsubBufferSize) - messages := make(chan string, database.PubsubBufferSize) + errors := make(chan error, pubsub.BufferSize) + messages := make(chan string, pubsub.BufferSize) readOne := func() (m string, e error) { t.Helper() select { @@ -244,7 +244,7 @@ func TestPubsub_Disconnect(t *testing.T) { return m, e } - cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { + cancelSub, err := ps.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { messages <- string(msg) errors <- err }) @@ -252,7 +252,7 @@ func TestPubsub_Disconnect(t *testing.T) { defer cancelSub() for i := 0; i < 100; i++ { - err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i))) + err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) require.NoError(t, err) } // make sure we're getting at least one message. @@ -270,7 +270,7 @@ func TestPubsub_Disconnect(t *testing.T) { default: // ok } - err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) j++ if err != nil { break @@ -292,7 +292,7 @@ func TestPubsub_Disconnect(t *testing.T) { default: // ok } - err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) if err == nil { break } @@ -303,7 +303,7 @@ func TestPubsub_Disconnect(t *testing.T) { k := j // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB // reconnect - require.Less(t, k, database.PubsubBufferSize, "exceeded buffer") + require.Less(t, k, pubsub.BufferSize, "exceeded buffer") // We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As // soon as we see k or higher we know we're getting messages after the restart. @@ -315,7 +315,7 @@ func TestPubsub_Disconnect(t *testing.T) { default: // ok } - _ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + _ = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) j++ time.Sleep(testutil.IntervalFast) } @@ -324,7 +324,7 @@ func TestPubsub_Disconnect(t *testing.T) { gotDroppedErr := false for { m, err := readOne() - if xerrors.Is(err, database.ErrDroppedMessages) { + if xerrors.Is(err, pubsub.ErrDroppedMessages) { gotDroppedErr = true continue } @@ -334,7 +334,7 @@ func TestPubsub_Disconnect(t *testing.T) { if l >= k { // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than // DB reconnect - require.Less(t, l, database.PubsubBufferSize, "exceeded buffer") + require.Less(t, l, pubsub.BufferSize, "exceeded buffer") break } } diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index f0de9939bf06a..a33c4a048a6d3 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/schedule" @@ -56,7 +57,7 @@ type Server struct { GitAuthConfigs []*gitauth.Config Tags json.RawMessage Database database.Store - Pubsub database.Pubsub + Pubsub pubsub.Pubsub Telemetry telemetry.Reporter Tracer trace.Tracer QuotaCommitter *atomic.Pointer[proto.QuotaCommitter] diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 6b42556f0aee9..6b881210b3f6a 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/database/dbgen" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/schedule" @@ -51,14 +52,14 @@ func TestAcquireJob(t *testing.T) { t.Run("Debounce", func(t *testing.T) { t.Parallel() db := dbfake.New() - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() srv := &provisionerdserver.Server{ ID: uuid.New(), Logger: slogtest.Make(t, nil), AccessURL: &url.URL{}, Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, Database: db, - Pubsub: pubsub, + Pubsub: ps, Telemetry: telemetry.NewNoop(), AcquireJobDebounce: time.Hour, Auditor: mockAuditor(), @@ -1256,7 +1257,7 @@ func TestInsertWorkspaceResource(t *testing.T) { func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server { t.Helper() db := dbfake.New() - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() return &provisionerdserver.Server{ ID: uuid.New(), @@ -1265,7 +1266,7 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server { AccessURL: &url.URL{}, Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, Database: db, - Pubsub: pubsub, + Pubsub: ps, Telemetry: telemetry.NewNoop(), Auditor: mockAuditor(), TemplateScheduleStore: testTemplateScheduleStore(), diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 99fb647385f5e..3926d353d1017 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/db2sdk" "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisionersdk" @@ -268,7 +269,7 @@ type logFollower struct { ctx context.Context logger slog.Logger db database.Store - pubsub database.Pubsub + pubsub pubsub.Pubsub r *http.Request rw http.ResponseWriter conn *websocket.Conn @@ -281,14 +282,14 @@ type logFollower struct { } func newLogFollower( - ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, + ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub, rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64, ) *logFollower { return &logFollower{ ctx: ctx, logger: logger, db: db, - pubsub: pubsub, + pubsub: ps, r: r, rw: rw, jobID: job.ID, diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index ee34e451058b0..acbf303efc957 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbmock" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/testutil" @@ -138,7 +139,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) { logger := slogtest.Make(t, nil) ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() now := database.Now() job := database.ProvisionerJob{ ID: uuid.New(), @@ -157,7 +158,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) { // we need an HTTP server to get a websocket srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10) + uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 10) uut.follow() })) defer srv.Close() @@ -200,7 +201,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) { logger := slogtest.Make(t, nil) ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() now := database.Now() job := database.ProvisionerJob{ ID: uuid.New(), @@ -217,7 +218,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) { // we need an HTTP server to get a websocket srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0) + uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0) uut.follow() })) defer srv.Close() @@ -276,7 +277,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) { logger := slogtest.Make(t, nil) ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) - pubsub := database.NewPubsubInMemory() + ps := pubsub.NewInMemory() now := database.Now() job := database.ProvisionerJob{ ID: uuid.New(), @@ -293,7 +294,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) { // we need an HTTP server to get a websocket srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0) + uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0) uut.follow() })) defer srv.Close() @@ -342,7 +343,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) { } msg, err = json.Marshal(&n) require.NoError(t, err) - err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg) + err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg) require.NoError(t, err) mt, msg, err = client.Read(ctx) @@ -360,7 +361,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) { n.CreatedAfter = 0 msg, err = json.Marshal(&n) require.NoError(t, err) - err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg) + err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg) require.NoError(t, err) // server should now close diff --git a/enterprise/replicasync/replicasync.go b/enterprise/replicasync/replicasync.go index 4b31b912ea673..a2bcb8837288e 100644 --- a/enterprise/replicasync/replicasync.go +++ b/enterprise/replicasync/replicasync.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/pubsub" ) var PubsubEvent = "replica" @@ -36,7 +37,7 @@ type Options struct { // New registers the replica with the database and periodically updates to ensure // it's healthy. It contacts all other alive replicas to ensure they are reachable. -func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) { +func New(ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub, options *Options) (*Manager, error) { if options == nil { options = &Options{} } @@ -77,7 +78,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data if err != nil { return nil, xerrors.Errorf("insert replica: %w", err) } - err = pubsub.Publish(PubsubEvent, []byte(options.ID.String())) + err = ps.Publish(PubsubEvent, []byte(options.ID.String())) if err != nil { return nil, xerrors.Errorf("publish new replica: %w", err) } @@ -86,7 +87,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data id: options.ID, options: options, db: db, - pubsub: pubsub, + pubsub: ps, self: replica, logger: logger, closed: make(chan struct{}), @@ -110,7 +111,7 @@ type Manager struct { id uuid.UUID options *Options db database.Store - pubsub database.Pubsub + pubsub pubsub.Pubsub logger slog.Logger closeWait sync.WaitGroup diff --git a/enterprise/replicasync/replicasync_test.go b/enterprise/replicasync/replicasync_test.go index f2c0eebd8cd5c..741be64fa12cc 100644 --- a/enterprise/replicasync/replicasync_test.go +++ b/enterprise/replicasync/replicasync_test.go @@ -18,6 +18,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/enterprise/replicasync" "github.com/coder/coder/testutil" ) @@ -212,7 +213,7 @@ func TestReplica(t *testing.T) { // this many PostgreSQL connections takes some // configuration tweaking. db := dbfake.New() - pubsub := database.NewPubsubInMemory() + pubsub := pubsub.NewInMemory() logger := slogtest.Make(t, nil) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index c25a9c2f773f3..b0d9cfa64032f 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -16,13 +16,13 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/pubsub" agpl "github.com/coder/coder/tailnet" ) // NewCoordinator creates a new high availability coordinator // that uses PostgreSQL pubsub to exchange handshakes. -func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) { +func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, error) { ctx, cancelFunc := context.WithCancel(context.Background()) nameCache, err := lru.New[uuid.UUID, string](512) @@ -33,7 +33,7 @@ func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinato coord := &haCoordinator{ id: uuid.New(), log: logger, - pubsub: pubsub, + pubsub: ps, closeFunc: cancelFunc, close: make(chan struct{}), nodes: map[uuid.UUID]*agpl.Node{}, @@ -53,7 +53,7 @@ type haCoordinator struct { id uuid.UUID log slog.Logger mutex sync.RWMutex - pubsub database.Pubsub + pubsub pubsub.Pubsub close chan struct{} closeFunc context.CancelFunc diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go index cf85af4a5a565..bcc3ddca34d05 100644 --- a/enterprise/tailnet/coordinator_test.go +++ b/enterprise/tailnet/coordinator_test.go @@ -10,8 +10,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/database/pubsub" "github.com/coder/coder/enterprise/tailnet" agpl "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" @@ -21,7 +21,7 @@ func TestCoordinatorSingle(t *testing.T) { t.Parallel() t.Run("ClientWithoutAgent", func(t *testing.T) { t.Parallel() - coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory()) require.NoError(t, err) defer coordinator.Close() @@ -49,7 +49,7 @@ func TestCoordinatorSingle(t *testing.T) { t.Run("AgentWithoutClients", func(t *testing.T) { t.Parallel() - coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory()) require.NoError(t, err) defer coordinator.Close() @@ -77,7 +77,7 @@ func TestCoordinatorSingle(t *testing.T) { t.Run("AgentWithClient", func(t *testing.T) { t.Parallel() - coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory()) require.NoError(t, err) defer coordinator.Close()