Skip to content

Commit 5f4beaa

Browse files
committed
feat: add wsproxy implementation for key fetching
1 parent 21b92ef commit 5f4beaa

File tree

3 files changed

+553
-0
lines changed

3 files changed

+553
-0
lines changed

enterprise/wsproxy/keycache.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package wsproxy
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"golang.org/x/xerrors"
9+
10+
"cdr.dev/slog"
11+
12+
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
13+
"github.com/coder/quartz"
14+
)
15+
16+
type CryptoKeyCache struct {
17+
client *wsproxysdk.Client
18+
logger slog.Logger
19+
Clock quartz.Clock
20+
21+
keysMu sync.RWMutex
22+
keys map[int32]wsproxysdk.CryptoKey
23+
latest wsproxysdk.CryptoKey
24+
}
25+
26+
func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
27+
cache := &CryptoKeyCache{
28+
client: client,
29+
logger: log,
30+
Clock: quartz.NewReal(),
31+
}
32+
33+
for _, opt := range opts {
34+
opt(cache)
35+
}
36+
37+
m, latest, err := cache.fetch(ctx)
38+
if err != nil {
39+
return nil, xerrors.Errorf("initial fetch: %w", err)
40+
}
41+
cache.keys, cache.latest = m, latest
42+
43+
go cache.refresh(ctx)
44+
45+
return cache, nil
46+
}
47+
48+
func (k *CryptoKeyCache) Latest(ctx context.Context) (wsproxysdk.CryptoKey, error) {
49+
k.keysMu.RLock()
50+
latest := k.latest
51+
k.keysMu.RUnlock()
52+
53+
now := k.Clock.Now().UTC()
54+
if latest.Active(now) {
55+
return latest, nil
56+
}
57+
58+
k.keysMu.Lock()
59+
defer k.keysMu.Unlock()
60+
61+
if k.latest.Active(now) {
62+
return k.latest, nil
63+
}
64+
65+
var err error
66+
k.keys, k.latest, err = k.fetch(ctx)
67+
if err != nil {
68+
return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
69+
}
70+
71+
if !k.latest.Active(now) {
72+
return wsproxysdk.CryptoKey{}, xerrors.Errorf("no active keys found")
73+
}
74+
75+
return k.latest, nil
76+
}
77+
78+
func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (wsproxysdk.CryptoKey, error) {
79+
now := k.Clock.Now().UTC()
80+
k.keysMu.RLock()
81+
key, ok := k.keys[sequence]
82+
k.keysMu.RUnlock()
83+
if ok {
84+
return validKey(key, now)
85+
}
86+
87+
k.keysMu.Lock()
88+
defer k.keysMu.Unlock()
89+
key, ok = k.keys[sequence]
90+
if ok {
91+
return validKey(key, now)
92+
}
93+
94+
var err error
95+
k.keys, k.latest, err = k.fetch(ctx)
96+
if err != nil {
97+
return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
98+
}
99+
100+
key, ok = k.keys[sequence]
101+
if !ok {
102+
return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d not found", sequence)
103+
}
104+
105+
return validKey(key, now)
106+
}
107+
108+
func (k *CryptoKeyCache) refresh(ctx context.Context) {
109+
k.Clock.TickerFunc(ctx, time.Minute*10, func() error {
110+
kmap, latest, err := k.fetch(ctx)
111+
if err != nil {
112+
k.logger.Error(ctx, "failed to fetch crypto keys", slog.Error(err))
113+
return nil
114+
}
115+
116+
k.keysMu.Lock()
117+
defer k.keysMu.Unlock()
118+
k.keys = kmap
119+
k.latest = latest
120+
return nil
121+
})
122+
}
123+
124+
func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey, error) {
125+
keys, err := k.client.CryptoKeys(ctx)
126+
if err != nil {
127+
return nil, wsproxysdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err)
128+
}
129+
130+
kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now().UTC())
131+
return kmap, latest, nil
132+
}
133+
134+
func toKeyMap(keys []wsproxysdk.CryptoKey, now time.Time) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey) {
135+
m := make(map[int32]wsproxysdk.CryptoKey)
136+
var latest wsproxysdk.CryptoKey
137+
for _, key := range keys {
138+
m[key.Sequence] = key
139+
if key.Sequence > latest.Sequence && key.Active(now) {
140+
latest = key
141+
}
142+
}
143+
return m, latest
144+
}
145+
146+
func validKey(key wsproxysdk.CryptoKey, now time.Time) (wsproxysdk.CryptoKey, error) {
147+
if key.Invalid(now) {
148+
return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d is invalid", key.Sequence)
149+
}
150+
151+
return key, nil
152+
}

0 commit comments

Comments
 (0)