Skip to content

Commit 14a6030

Browse files
authored
chore: rework RPC version negotiation (coder#15687)
Changes the RPC header format from `codervpn <version> <role>` to `codervpn <role> <version1,version2,...>`. The versions list is a list of the maximum supported minor version for each major version, sorted by major versions. E.g. `1.0,2.3,3.1` means `1.0, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1` are supported. When we eventually support multiple versions, the peer's version list will be compared against the current supported versions list to determine the maximum major and minor version supported by both peers. Closes coder#15601
1 parent 887ea14 commit 14a6030

File tree

4 files changed

+428
-27
lines changed

4 files changed

+428
-27
lines changed

vpn/speaker.go

+21-13
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"google.golang.org/protobuf/proto"
1212

1313
"cdr.dev/slog"
14-
"github.com/coder/coder/v2/apiversion"
1514
)
1615

1716
type SpeakerRole string
@@ -258,7 +257,7 @@ func handshake(
258257
// read and write simultaneously to avoid deadlocking if the conn is not buffered
259258
errCh := make(chan error, 2)
260259
go func() {
261-
ours := headerString(CurrentVersion, me)
260+
ours := headerString(me, CurrentSupportedVersions)
262261
_, err := conn.Write([]byte(ours))
263262
logger.Debug(ctx, "wrote out header")
264263
if err != nil {
@@ -316,34 +315,43 @@ func handshake(
316315
}
317316
}
318317
logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader))
319-
err := validateHeader(theirHeader, them)
318+
gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions)
320319
if err != nil {
321320
return xerrors.Errorf("validate header (%s): %w", theirHeader, err)
322321
}
322+
logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion))
323+
// TODO: actually use the common version to perform different behavior once
324+
// we have multiple versions
323325
return nil
324326
}
325327

326328
const headerPreamble = "codervpn"
327329

328-
func headerString(version *apiversion.APIVersion, role SpeakerRole) string {
329-
return fmt.Sprintf("%s %s %s\n", headerPreamble, version.String(), role)
330+
func headerString(role SpeakerRole, versions RPCVersionList) string {
331+
return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String())
330332
}
331333

332-
func validateHeader(header string, expectedRole SpeakerRole) error {
334+
func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) {
333335
parts := strings.Split(header, " ")
334336
if len(parts) != 3 {
335-
return xerrors.New("wrong number of parts")
337+
return RPCVersion{}, xerrors.New("wrong number of parts")
336338
}
337339
if parts[0] != headerPreamble {
338-
return xerrors.New("invalid preamble")
340+
return RPCVersion{}, xerrors.New("invalid preamble")
339341
}
340-
if err := CurrentVersion.Validate(parts[1]); err != nil {
341-
return xerrors.Errorf("version: %w", err)
342+
if parts[1] != string(expectedRole) {
343+
return RPCVersion{}, xerrors.New("unexpected role")
342344
}
343-
if parts[2] != string(expectedRole) {
344-
return xerrors.New("unexpected role")
345+
otherVersions, err := ParseRPCVersionList(parts[2])
346+
if err != nil {
347+
return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err)
345348
}
346-
return nil
349+
compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions)
350+
if !ok {
351+
return RPCVersion{},
352+
xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), otherVersions.String())
353+
}
354+
return compatibleVersion, nil
347355
}
348356

349357
type request[S rpcMessage, R rpcMessage] struct {

vpn/speaker_internal_test.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ func TestSpeaker_RawPeer(t *testing.T) {
4747
errCh <- err
4848
}()
4949

50-
expectedHandshake := "codervpn 1.0 tunnel\n"
50+
expectedHandshake := "codervpn tunnel 1.0\n"
5151

5252
b := make([]byte, 256)
5353
n, err := mp.Read(b)
5454
require.NoError(t, err)
5555
require.Equal(t, expectedHandshake, string(b[:n]))
5656

57-
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
57+
_, err = mp.Write([]byte("codervpn manager 1.3,2.1\n"))
5858
require.NoError(t, err)
5959

6060
err = testutil.RequireRecvCtx(ctx, t, errCh)
@@ -155,7 +155,7 @@ func TestSpeaker_OversizeHandshake(t *testing.T) {
155155
errCh <- err
156156
}()
157157

158-
expectedHandshake := "codervpn 1.0 tunnel\n"
158+
expectedHandshake := "codervpn tunnel 1.0\n"
159159

160160
b := make([]byte, 256)
161161
n, err := mp.Read(b)
@@ -177,10 +177,10 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
177177
for _, tc := range []struct {
178178
name, handshake string
179179
}{
180-
{name: "preamble", handshake: "ssh 1.0 manager\n"},
180+
{name: "preamble", handshake: "ssh manager 1.0\n"},
181181
{name: "2components", handshake: "ssh manager\n"},
182-
{name: "newversion", handshake: "codervpn 1.1 manager\n"},
183-
{name: "oldversion", handshake: "codervpn 0.1 manager\n"},
182+
{name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"},
183+
{name: "0version", handshake: "codervpn 0.1 manager\n"},
184184
{name: "unknown_role", handshake: "codervpn 1.0 supervisor\n"},
185185
{name: "unexpected_role", handshake: "codervpn 1.0 tunnel\n"},
186186
} {
@@ -208,7 +208,7 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
208208
_, err = mp.Write([]byte(tc.handshake))
209209
require.NoError(t, err)
210210

211-
expectedHandshake := "codervpn 1.0 tunnel\n"
211+
expectedHandshake := "codervpn tunnel 1.0\n"
212212
b := make([]byte, 256)
213213
n, err := mp.Read(b)
214214
require.NoError(t, err)
@@ -246,14 +246,14 @@ func TestSpeaker_CorruptMessage(t *testing.T) {
246246
errCh <- err
247247
}()
248248

249-
expectedHandshake := "codervpn 1.0 tunnel\n"
249+
expectedHandshake := "codervpn tunnel 1.0\n"
250250

251251
b := make([]byte, 256)
252252
n, err := mp.Read(b)
253253
require.NoError(t, err)
254254
require.Equal(t, expectedHandshake, string(b[:n]))
255255

256-
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
256+
_, err = mp.Write([]byte("codervpn manager 1.0\n"))
257257
require.NoError(t, err)
258258

259259
err = testutil.RequireRecvCtx(ctx, t, errCh)

vpn/version.go

+136-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,141 @@
11
package vpn
22

3-
import "github.com/coder/coder/v2/apiversion"
3+
import (
4+
"fmt"
5+
"strconv"
6+
"strings"
47

5-
const (
6-
CurrentMajor = 1
7-
CurrentMinor = 0
8+
"golang.org/x/xerrors"
89
)
910

10-
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
11+
// CurrentSupportedVersions is the list of versions supported by this
12+
// implementation of the VPN RPC protocol.
13+
var CurrentSupportedVersions = RPCVersionList{
14+
Versions: []RPCVersion{
15+
{Major: 1, Minor: 0},
16+
},
17+
}
18+
19+
// RPCVersion represents a single version of the RPC protocol. Any given version
20+
// is expected to be backwards compatible with all previous minor versions on
21+
// the same major version.
22+
//
23+
// e.g. RPCVersion{2, 3} is backwards compatible with RPCVersion{2, 2} but is
24+
// not backwards compatible with RPCVersion{1, 2}.
25+
type RPCVersion struct {
26+
Major uint64 `json:"major"`
27+
Minor uint64 `json:"minor"`
28+
}
29+
30+
// ParseRPCVersion parses a version string in the format "major.minor" into a
31+
// RPCVersion.
32+
func ParseRPCVersion(str string) (RPCVersion, error) {
33+
split := strings.Split(str, ".")
34+
if len(split) != 2 {
35+
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
36+
}
37+
major, err := strconv.ParseUint(split[0], 10, 64)
38+
if err != nil {
39+
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
40+
}
41+
if major == 0 {
42+
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
43+
}
44+
minor, err := strconv.ParseUint(split[1], 10, 64)
45+
if err != nil {
46+
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
47+
}
48+
return RPCVersion{Major: major, Minor: minor}, nil
49+
}
50+
51+
func (v RPCVersion) String() string {
52+
return fmt.Sprintf("%d.%d", v.Major, v.Minor)
53+
}
54+
55+
// IsCompatibleWith returns the lowest version that is compatible with both
56+
// versions. If the versions are not compatible, the second return value will be
57+
// false.
58+
func (v RPCVersion) IsCompatibleWith(other RPCVersion) (RPCVersion, bool) {
59+
if v.Major != other.Major {
60+
return RPCVersion{}, false
61+
}
62+
// The lowest minor version from the two versions should be returned.
63+
if v.Minor < other.Minor {
64+
return v, true
65+
}
66+
return other, true
67+
}
68+
69+
// RPCVersionList represents a list of RPC versions supported by a RPC peer. An
70+
type RPCVersionList struct {
71+
Versions []RPCVersion `json:"versions"`
72+
}
73+
74+
// ParseRPCVersionList parses a version string in the format
75+
// "major.minor,major.minor" into a RPCVersionList.
76+
func ParseRPCVersionList(str string) (RPCVersionList, error) {
77+
split := strings.Split(str, ",")
78+
versions := make([]RPCVersion, len(split))
79+
for i, v := range split {
80+
version, err := ParseRPCVersion(v)
81+
if err != nil {
82+
return RPCVersionList{}, xerrors.Errorf("invalid version list: %s", str)
83+
}
84+
versions[i] = version
85+
}
86+
vl := RPCVersionList{Versions: versions}
87+
err := vl.Validate()
88+
if err != nil {
89+
return RPCVersionList{}, xerrors.Errorf("invalid parsed version list %q: %w", str, err)
90+
}
91+
return vl, nil
92+
}
93+
94+
func (vl RPCVersionList) String() string {
95+
versionStrings := make([]string, len(vl.Versions))
96+
for i, v := range vl.Versions {
97+
versionStrings[i] = v.String()
98+
}
99+
return strings.Join(versionStrings, ",")
100+
}
101+
102+
// Validate returns an error if the version list is not sorted or contains
103+
// duplicate major versions.
104+
func (vl RPCVersionList) Validate() error {
105+
if len(vl.Versions) == 0 {
106+
return xerrors.New("no versions")
107+
}
108+
for i := 0; i < len(vl.Versions); i++ {
109+
if vl.Versions[i].Major == 0 {
110+
return xerrors.Errorf("invalid version: %s", vl.Versions[i].String())
111+
}
112+
if i > 0 && vl.Versions[i-1].Major == vl.Versions[i].Major {
113+
return xerrors.Errorf("duplicate major version: %d", vl.Versions[i].Major)
114+
}
115+
if i > 0 && vl.Versions[i-1].Major > vl.Versions[i].Major {
116+
return xerrors.Errorf("versions are not sorted")
117+
}
118+
}
119+
return nil
120+
}
121+
122+
// IsCompatibleWith returns the lowest version that is compatible with both
123+
// version lists. If the versions are not compatible, the second return value
124+
// will be false.
125+
func (vl RPCVersionList) IsCompatibleWith(other RPCVersionList) (RPCVersion, bool) {
126+
bestVersion := RPCVersion{}
127+
for _, v1 := range vl.Versions {
128+
for _, v2 := range other.Versions {
129+
if v1.Major == v2.Major && v1.Major > bestVersion.Major {
130+
v, ok := v1.IsCompatibleWith(v2)
131+
if ok {
132+
bestVersion = v
133+
}
134+
}
135+
}
136+
}
137+
if bestVersion.Major == 0 {
138+
return bestVersion, false
139+
}
140+
return bestVersion, true
141+
}

0 commit comments

Comments
 (0)