Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion database/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"database/sql"
"testing"

"github.com/coder/coder/database/postgres"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/coder/coder/database/postgres"

_ "github.com/lib/pq"
)

Expand Down
155 changes: 155 additions & 0 deletions database/pubsub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package database

import (
"context"
"database/sql"
"errors"
"sync"
"time"

"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"
)

// Listener represents a pubsub handler.
type Listener func(ctx context.Context, message []byte)

// Pubsub is a generic interface for broadcasting and receiving messages.
// Implementors should assume high-availability with the backing implementation.
type Pubsub interface {
Subscribe(event string, listener Listener) (cancel func(), err error)
Publish(event string, message []byte) error
Close() error
}

// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[string]Listener
}

// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()

err = p.pgListener.Listen(event)
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
// It's ok if it's already open!
err = nil
}
if err != nil {
return nil, xerrors.Errorf("listen: %w", err)
}

var listeners map[string]Listener
var ok bool
if listeners, ok = p.listeners[event]; !ok {
listeners = map[string]Listener{}
p.listeners[event] = listeners
}
var id string
for {
id = uuid.New().String()
if _, ok = listeners[id]; !ok {
break
}
}
listeners[id] = listener
return func() {
p.mut.Lock()
defer p.mut.Unlock()
listeners := p.listeners[event]
delete(listeners, id)

if len(listeners) == 0 {
_ = p.pgListener.Unlisten(event)
}
}, nil
}

func (p *pgPubsub) Publish(event string, message []byte) error {
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
return xerrors.Errorf("exec: %w", err)
}
return nil
}

// Close closes the pubsub instance.
func (p *pgPubsub) Close() error {
return p.pgListener.Close()
}

// listen begins receiving messages on the pq listener.
func (p *pgPubsub) listen(ctx context.Context) {
var (
notif *pq.Notification
ok bool
)
defer p.pgListener.Close()
for {
select {
case <-ctx.Done():
return
case notif, ok = <-p.pgListener.Notify:
if !ok {
return
}
}
// A nil notification can be dispatched on reconnect.
if notif == nil {
continue
}
p.listenReceive(ctx, notif)
}
}

func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
p.mut.Lock()
defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel]
if !ok {
return
}
extra := []byte(notif.Extra)
for _, listener := range listeners {
go listener(ctx, extra)
}
}

// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
})
select {
case err := <-errCh:
if err != nil {
return nil, xerrors.Errorf("create pq listener: %w", err)
}
case <-ctx.Done():
return nil, ctx.Err()
}
pg := &pgPubsub{
db: db,
pgListener: listener,
listeners: make(map[string]map[string]Listener),
}
go pg.listen(ctx)

return pg, nil
}
63 changes: 63 additions & 0 deletions database/pubsub_memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package database

import (
"context"
"sync"

"github.com/google/uuid"
)

// memoryPubsub is an in-memory Pubsub implementation.
type memoryPubsub struct {
mut sync.RWMutex
listeners map[string]map[uuid.UUID]Listener
}

func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
m.mut.Lock()
defer m.mut.Unlock()

var listeners map[uuid.UUID]Listener
var ok bool
if listeners, ok = m.listeners[event]; !ok {
listeners = map[uuid.UUID]Listener{}
m.listeners[event] = listeners
}
var id uuid.UUID
for {
id = uuid.New()
if _, ok = listeners[id]; !ok {
break
}
}
listeners[id] = listener
return func() {
m.mut.Lock()
defer m.mut.Unlock()
listeners := m.listeners[event]
delete(listeners, id)
}, nil
}

func (m *memoryPubsub) Publish(event string, message []byte) error {
m.mut.RLock()
defer m.mut.RUnlock()
listeners, ok := m.listeners[event]
if !ok {
return nil
}
for _, listener := range listeners {
listener(context.Background(), message)
}
return nil
}

func (m *memoryPubsub) Close() error {
return nil
}

func NewPubsubInMemory() Pubsub {
return &memoryPubsub{
listeners: make(map[string]map[uuid.UUID]Listener),
}
}
32 changes: 32 additions & 0 deletions database/pubsub_memory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package database_test

import (
"context"
"testing"

"github.com/coder/coder/database"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestPubsubMemory(t *testing.T) {
t.Parallel()

t.Run("Memory", func(t *testing.T) {
pubsub := database.NewPubsubInMemory()
event := "test"
data := "testing"
ch := make(chan []byte)
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
ch <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-ch
assert.Equal(t, string(message), data)
})
}
47 changes: 47 additions & 0 deletions database/pubsub_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//go:build linux

package database_test

import (
"context"
"database/sql"
"testing"

"github.com/coder/coder/database"
"github.com/coder/coder/database/postgres"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestPubsub(t *testing.T) {
t.Parallel()

t.Run("Postgres", func(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()

connectionURL, close, err := postgres.Open()
require.NoError(t, err)
defer close()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
data := "testing"
ch := make(chan []byte)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
ch <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-ch
assert.Equal(t, string(message), data)
})
}