diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 107f4cb7a8b85..c9e8ef12cba11 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -6,10 +6,11 @@ import ( "database/sql" "testing" - "github.com/coder/coder/database/postgres" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "github.com/coder/coder/database/postgres" + _ "github.com/lib/pq" ) diff --git a/database/pubsub.go b/database/pubsub.go new file mode 100644 index 0000000000000..656e923246040 --- /dev/null +++ b/database/pubsub.go @@ -0,0 +1,155 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "sync" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "golang.org/x/xerrors" +) + +// Listener represents a pubsub handler. +type Listener func(ctx context.Context, message []byte) + +// Pubsub is a generic interface for broadcasting and receiving messages. +// Implementors should assume high-availability with the backing implementation. +type Pubsub interface { + Subscribe(event string, listener Listener) (cancel func(), err error) + Publish(event string, message []byte) error + Close() error +} + +// Pubsub implementation using PostgreSQL. +type pgPubsub struct { + pgListener *pq.Listener + db *sql.DB + mut sync.Mutex + listeners map[string]map[string]Listener +} + +// Subscribe calls the listener when an event matching the name is received. +func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { + p.mut.Lock() + defer p.mut.Unlock() + + err = p.pgListener.Listen(event) + if errors.Is(err, pq.ErrChannelAlreadyOpen) { + // It's ok if it's already open! + err = nil + } + if err != nil { + return nil, xerrors.Errorf("listen: %w", err) + } + + var listeners map[string]Listener + var ok bool + if listeners, ok = p.listeners[event]; !ok { + listeners = map[string]Listener{} + p.listeners[event] = listeners + } + var id string + for { + id = uuid.New().String() + if _, ok = listeners[id]; !ok { + break + } + } + listeners[id] = listener + return func() { + p.mut.Lock() + defer p.mut.Unlock() + listeners := p.listeners[event] + delete(listeners, id) + + if len(listeners) == 0 { + _ = p.pgListener.Unlisten(event) + } + }, nil +} + +func (p *pgPubsub) Publish(event string, message []byte) error { + _, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) + if err != nil { + return xerrors.Errorf("exec: %w", err) + } + return nil +} + +// Close closes the pubsub instance. +func (p *pgPubsub) Close() error { + return p.pgListener.Close() +} + +// listen begins receiving messages on the pq listener. +func (p *pgPubsub) listen(ctx context.Context) { + var ( + notif *pq.Notification + ok bool + ) + defer p.pgListener.Close() + for { + select { + case <-ctx.Done(): + return + case notif, ok = <-p.pgListener.Notify: + if !ok { + return + } + } + // A nil notification can be dispatched on reconnect. + if notif == nil { + continue + } + p.listenReceive(ctx, notif) + } +} + +func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) { + p.mut.Lock() + defer p.mut.Unlock() + listeners, ok := p.listeners[notif.Channel] + if !ok { + return + } + extra := []byte(notif.Extra) + for _, listener := range listeners { + go listener(ctx, extra) + } +} + +// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection. +func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) { + // Creates a new listener using pq. + errCh := make(chan error) + listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) { + // This callback gets events whenever the connection state changes. + // Don't send if the errChannel has already been closed. + select { + case <-errCh: + return + default: + errCh <- err + close(errCh) + } + }) + select { + case err := <-errCh: + if err != nil { + return nil, xerrors.Errorf("create pq listener: %w", err) + } + case <-ctx.Done(): + return nil, ctx.Err() + } + pg := &pgPubsub{ + db: db, + pgListener: listener, + listeners: make(map[string]map[string]Listener), + } + go pg.listen(ctx) + + return pg, nil +} diff --git a/database/pubsub_memory.go b/database/pubsub_memory.go new file mode 100644 index 0000000000000..92244f8bbc6d6 --- /dev/null +++ b/database/pubsub_memory.go @@ -0,0 +1,63 @@ +package database + +import ( + "context" + "sync" + + "github.com/google/uuid" +) + +// memoryPubsub is an in-memory Pubsub implementation. +type memoryPubsub struct { + mut sync.RWMutex + listeners map[string]map[uuid.UUID]Listener +} + +func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { + m.mut.Lock() + defer m.mut.Unlock() + + var listeners map[uuid.UUID]Listener + var ok bool + if listeners, ok = m.listeners[event]; !ok { + listeners = map[uuid.UUID]Listener{} + m.listeners[event] = listeners + } + var id uuid.UUID + for { + id = uuid.New() + if _, ok = listeners[id]; !ok { + break + } + } + listeners[id] = listener + return func() { + m.mut.Lock() + defer m.mut.Unlock() + listeners := m.listeners[event] + delete(listeners, id) + }, nil +} + +func (m *memoryPubsub) Publish(event string, message []byte) error { + m.mut.RLock() + defer m.mut.RUnlock() + listeners, ok := m.listeners[event] + if !ok { + return nil + } + for _, listener := range listeners { + listener(context.Background(), message) + } + return nil +} + +func (m *memoryPubsub) Close() error { + return nil +} + +func NewPubsubInMemory() Pubsub { + return &memoryPubsub{ + listeners: make(map[string]map[uuid.UUID]Listener), + } +} diff --git a/database/pubsub_memory_test.go b/database/pubsub_memory_test.go new file mode 100644 index 0000000000000..7293e255d7e01 --- /dev/null +++ b/database/pubsub_memory_test.go @@ -0,0 +1,32 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/coder/coder/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPubsubMemory(t *testing.T) { + t.Parallel() + + t.Run("Memory", func(t *testing.T) { + pubsub := database.NewPubsubInMemory() + event := "test" + data := "testing" + ch := make(chan []byte) + cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + ch <- message + }) + require.NoError(t, err) + defer cancelFunc() + go func() { + err = pubsub.Publish(event, []byte(data)) + require.NoError(t, err) + }() + message := <-ch + assert.Equal(t, string(message), data) + }) +} diff --git a/database/pubsub_test.go b/database/pubsub_test.go new file mode 100644 index 0000000000000..c9c43026a5225 --- /dev/null +++ b/database/pubsub_test.go @@ -0,0 +1,47 @@ +//go:build linux + +package database_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/coder/coder/database" + "github.com/coder/coder/database/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPubsub(t *testing.T) { + t.Parallel() + + t.Run("Postgres", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + connectionURL, close, err := postgres.Open() + require.NoError(t, err) + defer close() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := database.NewPubsub(ctx, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + event := "test" + data := "testing" + ch := make(chan []byte) + cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + ch <- message + }) + require.NoError(t, err) + defer cancelFunc() + go func() { + err = pubsub.Publish(event, []byte(data)) + require.NoError(t, err) + }() + message := <-ch + assert.Equal(t, string(message), data) + }) +}