Skip to content

Commit 25ec905

Browse files
committed
feat: add CoderVPN protocol definition & implementaion
1 parent 326886d commit 25ec905

File tree

7 files changed

+3379
-0
lines changed

7 files changed

+3379
-0
lines changed

Makefile

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ gen: \
488488
agent/proto/agent.pb.go \
489489
provisionersdk/proto/provisioner.pb.go \
490490
provisionerd/proto/provisionerd.pb.go \
491+
vpn/vpn.proto \
491492
coderd/database/dump.sql \
492493
$(DB_GEN_FILES) \
493494
site/src/api/typesGenerated.ts \
@@ -517,6 +518,7 @@ gen/mark-fresh:
517518
agent/proto/agent.pb.go \
518519
provisionersdk/proto/provisioner.pb.go \
519520
provisionerd/proto/provisionerd.pb.go \
521+
vpn/vpn.proto \
520522
coderd/database/dump.sql \
521523
$(DB_GEN_FILES) \
522524
site/src/api/typesGenerated.ts \
@@ -600,6 +602,12 @@ provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
600602
--go-drpc_opt=paths=source_relative \
601603
./provisionerd/proto/provisionerd.proto
602604

605+
vpn/vpn.pb.go: vpn/vpn.proto
606+
protoc \
607+
--go_out=. \
608+
--go_opt=paths=source_relative \
609+
./vpn/vpn.proto
610+
603611
site/src/api/typesGenerated.ts: $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
604612
go run ./scripts/apitypings/ > $@
605613
./scripts/pnpm_install.sh

vpn/serdes.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package vpn
2+
3+
import (
4+
"context"
5+
"encoding/binary"
6+
"io"
7+
"sync"
8+
9+
"google.golang.org/protobuf/proto"
10+
11+
"cdr.dev/slog"
12+
)
13+
14+
// MaxLength is the largest possible CoderVPN Protocol message size. This is set
15+
// so that a misbehaving peer can't cause us to allocate a huge amount of memory.
16+
const MaxLength = 0x1000000 // 16MiB
17+
18+
// serdes SERializes and DESerializes protobuf messages to and from the conn.
19+
type serdes[S rpcMessage, R receivableRPCMessage[RR], RR any] struct {
20+
ctx context.Context
21+
logger slog.Logger
22+
conn io.ReadWriteCloser
23+
sendCh <-chan S
24+
recvCh chan<- R
25+
closeOnce sync.Once
26+
wg sync.WaitGroup
27+
}
28+
29+
func (s *serdes[_, R, RR]) recvLoop() {
30+
s.logger.Debug(s.ctx, "starting recvLoop")
31+
defer s.closeIdempotent()
32+
defer close(s.recvCh)
33+
for {
34+
var length uint32
35+
if err := binary.Read(s.conn, binary.BigEndian, &length); err != nil {
36+
s.logger.Debug(s.ctx, "failed to read length", slog.Error(err))
37+
return
38+
}
39+
if length > MaxLength {
40+
s.logger.Critical(s.ctx, "message length exceeds max",
41+
slog.F("length", length))
42+
}
43+
s.logger.Debug(s.ctx, "about to read message", slog.F("length", length))
44+
mb := make([]byte, length)
45+
if n, err := io.ReadFull(s.conn, mb); err != nil {
46+
s.logger.Debug(s.ctx, "failed to read message",
47+
slog.Error(err),
48+
slog.F("num_bytes_read", n))
49+
return
50+
}
51+
msg := R(new(RR))
52+
if err := proto.Unmarshal(mb, msg); err != nil {
53+
s.logger.Critical(s.ctx, "failed to unmarshal message", slog.Error(err))
54+
return
55+
}
56+
select {
57+
case s.recvCh <- msg:
58+
s.logger.Debug(s.ctx, "passed received message to speaker")
59+
case <-s.ctx.Done():
60+
s.logger.Debug(s.ctx, "recvLoop canceled", slog.Error(s.ctx.Err()))
61+
}
62+
}
63+
}
64+
65+
func (s *serdes[S, _, _]) sendLoop() {
66+
s.logger.Debug(s.ctx, "starting sendLoop")
67+
defer s.closeIdempotent()
68+
for {
69+
select {
70+
case <-s.ctx.Done():
71+
s.logger.Debug(s.ctx, "sendLoop canceled", slog.Error(s.ctx.Err()))
72+
return
73+
case msg, ok := <-s.sendCh:
74+
if !ok {
75+
s.logger.Debug(s.ctx, "sendCh closed")
76+
return
77+
}
78+
mb, err := proto.Marshal(msg)
79+
if err != nil {
80+
s.logger.Critical(s.ctx, "failed to marshal message", slog.Error(err))
81+
return
82+
}
83+
if err := binary.Write(s.conn, binary.BigEndian, uint32(len(mb))); err != nil {
84+
s.logger.Debug(s.ctx, "failed to write length", slog.Error(err))
85+
return
86+
}
87+
if _, err := s.conn.Write(mb); err != nil {
88+
s.logger.Debug(s.ctx, "failed to write message", slog.Error(err))
89+
return
90+
}
91+
}
92+
}
93+
}
94+
95+
func (s *serdes[_, _, _]) closeIdempotent() {
96+
s.closeOnce.Do(func() {
97+
if err := s.conn.Close(); err != nil {
98+
s.logger.Error(s.ctx, "failed to close connection", slog.Error(err))
99+
} else {
100+
s.logger.Info(s.ctx, "closed connection")
101+
}
102+
})
103+
}
104+
105+
func (s *serdes[_, _, _]) Close() error {
106+
s.closeIdempotent()
107+
s.wg.Wait()
108+
return nil
109+
}
110+
111+
func (s *serdes[_, _, _]) start() {
112+
s.wg.Add(2)
113+
go func() {
114+
defer s.wg.Done()
115+
s.recvLoop()
116+
}()
117+
go func() {
118+
defer s.wg.Done()
119+
s.sendLoop()
120+
}()
121+
}
122+
123+
func newSerdes[S rpcMessage, R receivableRPCMessage[RR], RR any](
124+
ctx context.Context, logger slog.Logger, conn io.ReadWriteCloser,
125+
sendCh <-chan S, recvCh chan<- R,
126+
) *serdes[S, R, RR] {
127+
return &serdes[S, R, RR]{
128+
ctx: ctx,
129+
logger: logger.Named("serdes"),
130+
conn: conn,
131+
sendCh: sendCh,
132+
recvCh: recvCh,
133+
}
134+
}

0 commit comments

Comments
 (0)