Skip to content

Commit 2618ce4

Browse files
authored
chore: Add pubsub (#7)
* chore: Add pubsub Exposes new in-memory and PostgreSQL pubsubs. This will be used for negotiating WebRTC connections. * Don't run PostgreSQL tests on non-Linux
1 parent 8accb81 commit 2618ce4

File tree

5 files changed

+299
-1
lines changed

5 files changed

+299
-1
lines changed

database/postgres/postgres_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ import (
66
"database/sql"
77
"testing"
88

9-
"github.com/coder/coder/database/postgres"
109
"github.com/stretchr/testify/require"
1110
"go.uber.org/goleak"
1211

12+
"github.com/coder/coder/database/postgres"
13+
1314
_ "github.com/lib/pq"
1415
)
1516

database/pubsub.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package database
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"sync"
8+
"time"
9+
10+
"github.com/google/uuid"
11+
"github.com/lib/pq"
12+
"golang.org/x/xerrors"
13+
)
14+
15+
// Listener represents a pubsub handler.
16+
type Listener func(ctx context.Context, message []byte)
17+
18+
// Pubsub is a generic interface for broadcasting and receiving messages.
19+
// Implementors should assume high-availability with the backing implementation.
20+
type Pubsub interface {
21+
Subscribe(event string, listener Listener) (cancel func(), err error)
22+
Publish(event string, message []byte) error
23+
Close() error
24+
}
25+
26+
// Pubsub implementation using PostgreSQL.
27+
type pgPubsub struct {
28+
pgListener *pq.Listener
29+
db *sql.DB
30+
mut sync.Mutex
31+
listeners map[string]map[string]Listener
32+
}
33+
34+
// Subscribe calls the listener when an event matching the name is received.
35+
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
36+
p.mut.Lock()
37+
defer p.mut.Unlock()
38+
39+
err = p.pgListener.Listen(event)
40+
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
41+
// It's ok if it's already open!
42+
err = nil
43+
}
44+
if err != nil {
45+
return nil, xerrors.Errorf("listen: %w", err)
46+
}
47+
48+
var listeners map[string]Listener
49+
var ok bool
50+
if listeners, ok = p.listeners[event]; !ok {
51+
listeners = map[string]Listener{}
52+
p.listeners[event] = listeners
53+
}
54+
var id string
55+
for {
56+
id = uuid.New().String()
57+
if _, ok = listeners[id]; !ok {
58+
break
59+
}
60+
}
61+
listeners[id] = listener
62+
return func() {
63+
p.mut.Lock()
64+
defer p.mut.Unlock()
65+
listeners := p.listeners[event]
66+
delete(listeners, id)
67+
68+
if len(listeners) == 0 {
69+
_ = p.pgListener.Unlisten(event)
70+
}
71+
}, nil
72+
}
73+
74+
func (p *pgPubsub) Publish(event string, message []byte) error {
75+
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
76+
if err != nil {
77+
return xerrors.Errorf("exec: %w", err)
78+
}
79+
return nil
80+
}
81+
82+
// Close closes the pubsub instance.
83+
func (p *pgPubsub) Close() error {
84+
return p.pgListener.Close()
85+
}
86+
87+
// listen begins receiving messages on the pq listener.
88+
func (p *pgPubsub) listen(ctx context.Context) {
89+
var (
90+
notif *pq.Notification
91+
ok bool
92+
)
93+
defer p.pgListener.Close()
94+
for {
95+
select {
96+
case <-ctx.Done():
97+
return
98+
case notif, ok = <-p.pgListener.Notify:
99+
if !ok {
100+
return
101+
}
102+
}
103+
// A nil notification can be dispatched on reconnect.
104+
if notif == nil {
105+
continue
106+
}
107+
p.listenReceive(ctx, notif)
108+
}
109+
}
110+
111+
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
112+
p.mut.Lock()
113+
defer p.mut.Unlock()
114+
listeners, ok := p.listeners[notif.Channel]
115+
if !ok {
116+
return
117+
}
118+
extra := []byte(notif.Extra)
119+
for _, listener := range listeners {
120+
go listener(ctx, extra)
121+
}
122+
}
123+
124+
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
125+
func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) {
126+
// Creates a new listener using pq.
127+
errCh := make(chan error)
128+
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
129+
// This callback gets events whenever the connection state changes.
130+
// Don't send if the errChannel has already been closed.
131+
select {
132+
case <-errCh:
133+
return
134+
default:
135+
errCh <- err
136+
close(errCh)
137+
}
138+
})
139+
select {
140+
case err := <-errCh:
141+
if err != nil {
142+
return nil, xerrors.Errorf("create pq listener: %w", err)
143+
}
144+
case <-ctx.Done():
145+
return nil, ctx.Err()
146+
}
147+
pg := &pgPubsub{
148+
db: db,
149+
pgListener: listener,
150+
listeners: make(map[string]map[string]Listener),
151+
}
152+
go pg.listen(ctx)
153+
154+
return pg, nil
155+
}

database/pubsub_memory.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package database
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"github.com/google/uuid"
8+
)
9+
10+
// memoryPubsub is an in-memory Pubsub implementation.
11+
type memoryPubsub struct {
12+
mut sync.RWMutex
13+
listeners map[string]map[uuid.UUID]Listener
14+
}
15+
16+
func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
17+
m.mut.Lock()
18+
defer m.mut.Unlock()
19+
20+
var listeners map[uuid.UUID]Listener
21+
var ok bool
22+
if listeners, ok = m.listeners[event]; !ok {
23+
listeners = map[uuid.UUID]Listener{}
24+
m.listeners[event] = listeners
25+
}
26+
var id uuid.UUID
27+
for {
28+
id = uuid.New()
29+
if _, ok = listeners[id]; !ok {
30+
break
31+
}
32+
}
33+
listeners[id] = listener
34+
return func() {
35+
m.mut.Lock()
36+
defer m.mut.Unlock()
37+
listeners := m.listeners[event]
38+
delete(listeners, id)
39+
}, nil
40+
}
41+
42+
func (m *memoryPubsub) Publish(event string, message []byte) error {
43+
m.mut.RLock()
44+
defer m.mut.RUnlock()
45+
listeners, ok := m.listeners[event]
46+
if !ok {
47+
return nil
48+
}
49+
for _, listener := range listeners {
50+
listener(context.Background(), message)
51+
}
52+
return nil
53+
}
54+
55+
func (m *memoryPubsub) Close() error {
56+
return nil
57+
}
58+
59+
func NewPubsubInMemory() Pubsub {
60+
return &memoryPubsub{
61+
listeners: make(map[string]map[uuid.UUID]Listener),
62+
}
63+
}

database/pubsub_memory_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package database_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/coder/coder/database"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestPubsubMemory(t *testing.T) {
13+
t.Parallel()
14+
15+
t.Run("Memory", func(t *testing.T) {
16+
pubsub := database.NewPubsubInMemory()
17+
event := "test"
18+
data := "testing"
19+
ch := make(chan []byte)
20+
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
21+
ch <- message
22+
})
23+
require.NoError(t, err)
24+
defer cancelFunc()
25+
go func() {
26+
err = pubsub.Publish(event, []byte(data))
27+
require.NoError(t, err)
28+
}()
29+
message := <-ch
30+
assert.Equal(t, string(message), data)
31+
})
32+
}

database/pubsub_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//go:build linux
2+
3+
package database_test
4+
5+
import (
6+
"context"
7+
"database/sql"
8+
"testing"
9+
10+
"github.com/coder/coder/database"
11+
"github.com/coder/coder/database/postgres"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func TestPubsub(t *testing.T) {
17+
t.Parallel()
18+
19+
t.Run("Postgres", func(t *testing.T) {
20+
ctx, cancelFunc := context.WithCancel(context.Background())
21+
defer cancelFunc()
22+
23+
connectionURL, close, err := postgres.Open()
24+
require.NoError(t, err)
25+
defer close()
26+
db, err := sql.Open("postgres", connectionURL)
27+
require.NoError(t, err)
28+
defer db.Close()
29+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
30+
require.NoError(t, err)
31+
defer pubsub.Close()
32+
event := "test"
33+
data := "testing"
34+
ch := make(chan []byte)
35+
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
36+
ch <- message
37+
})
38+
require.NoError(t, err)
39+
defer cancelFunc()
40+
go func() {
41+
err = pubsub.Publish(event, []byte(data))
42+
require.NoError(t, err)
43+
}()
44+
message := <-ch
45+
assert.Equal(t, string(message), data)
46+
})
47+
}

0 commit comments

Comments
 (0)