diff --git a/vpn/speaker.go b/vpn/speaker.go index e4305ad4ae4d0..06236cbeeb6d6 100644 --- a/vpn/speaker.go +++ b/vpn/speaker.go @@ -11,7 +11,6 @@ import ( "google.golang.org/protobuf/proto" "cdr.dev/slog" - "github.com/coder/coder/v2/apiversion" ) type SpeakerRole string @@ -258,7 +257,7 @@ func handshake( // read and write simultaneously to avoid deadlocking if the conn is not buffered errCh := make(chan error, 2) go func() { - ours := headerString(CurrentVersion, me) + ours := headerString(me, CurrentSupportedVersions) _, err := conn.Write([]byte(ours)) logger.Debug(ctx, "wrote out header") if err != nil { @@ -316,34 +315,43 @@ func handshake( } } logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader)) - err := validateHeader(theirHeader, them) + gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions) if err != nil { return xerrors.Errorf("validate header (%s): %w", theirHeader, err) } + logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion)) + // TODO: actually use the common version to perform different behavior once + // we have multiple versions return nil } const headerPreamble = "codervpn" -func headerString(version *apiversion.APIVersion, role SpeakerRole) string { - return fmt.Sprintf("%s %s %s\n", headerPreamble, version.String(), role) +func headerString(role SpeakerRole, versions RPCVersionList) string { + return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String()) } -func validateHeader(header string, expectedRole SpeakerRole) error { +func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) { parts := strings.Split(header, " ") if len(parts) != 3 { - return xerrors.New("wrong number of parts") + return RPCVersion{}, xerrors.New("wrong number of parts") } if parts[0] != headerPreamble { - return xerrors.New("invalid preamble") + return RPCVersion{}, xerrors.New("invalid preamble") } - if err := CurrentVersion.Validate(parts[1]); err != nil { - return xerrors.Errorf("version: %w", err) + if parts[1] != string(expectedRole) { + return RPCVersion{}, xerrors.New("unexpected role") } - if parts[2] != string(expectedRole) { - return xerrors.New("unexpected role") + otherVersions, err := ParseRPCVersionList(parts[2]) + if err != nil { + return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err) } - return nil + compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions) + if !ok { + return RPCVersion{}, + xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), otherVersions.String()) + } + return compatibleVersion, nil } type request[S rpcMessage, R rpcMessage] struct { diff --git a/vpn/speaker_internal_test.go b/vpn/speaker_internal_test.go index b1f38a91724fd..03a6ed0927c35 100644 --- a/vpn/speaker_internal_test.go +++ b/vpn/speaker_internal_test.go @@ -47,14 +47,14 @@ func TestSpeaker_RawPeer(t *testing.T) { errCh <- err }() - expectedHandshake := "codervpn 1.0 tunnel\n" + expectedHandshake := "codervpn tunnel 1.0\n" b := make([]byte, 256) n, err := mp.Read(b) require.NoError(t, err) require.Equal(t, expectedHandshake, string(b[:n])) - _, err = mp.Write([]byte("codervpn 1.0 manager\n")) + _, err = mp.Write([]byte("codervpn manager 1.3,2.1\n")) require.NoError(t, err) err = testutil.RequireRecvCtx(ctx, t, errCh) @@ -155,7 +155,7 @@ func TestSpeaker_OversizeHandshake(t *testing.T) { errCh <- err }() - expectedHandshake := "codervpn 1.0 tunnel\n" + expectedHandshake := "codervpn tunnel 1.0\n" b := make([]byte, 256) n, err := mp.Read(b) @@ -177,10 +177,10 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) { for _, tc := range []struct { name, handshake string }{ - {name: "preamble", handshake: "ssh 1.0 manager\n"}, + {name: "preamble", handshake: "ssh manager 1.0\n"}, {name: "2components", handshake: "ssh manager\n"}, - {name: "newversion", handshake: "codervpn 1.1 manager\n"}, - {name: "oldversion", handshake: "codervpn 0.1 manager\n"}, + {name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"}, + {name: "0version", handshake: "codervpn 0.1 manager\n"}, {name: "unknown_role", handshake: "codervpn 1.0 supervisor\n"}, {name: "unexpected_role", handshake: "codervpn 1.0 tunnel\n"}, } { @@ -208,7 +208,7 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) { _, err = mp.Write([]byte(tc.handshake)) require.NoError(t, err) - expectedHandshake := "codervpn 1.0 tunnel\n" + expectedHandshake := "codervpn tunnel 1.0\n" b := make([]byte, 256) n, err := mp.Read(b) require.NoError(t, err) @@ -246,14 +246,14 @@ func TestSpeaker_CorruptMessage(t *testing.T) { errCh <- err }() - expectedHandshake := "codervpn 1.0 tunnel\n" + expectedHandshake := "codervpn tunnel 1.0\n" b := make([]byte, 256) n, err := mp.Read(b) require.NoError(t, err) require.Equal(t, expectedHandshake, string(b[:n])) - _, err = mp.Write([]byte("codervpn 1.0 manager\n")) + _, err = mp.Write([]byte("codervpn manager 1.0\n")) require.NoError(t, err) err = testutil.RequireRecvCtx(ctx, t, errCh) diff --git a/vpn/version.go b/vpn/version.go index d869ad38ce07d..1962dc36d4501 100644 --- a/vpn/version.go +++ b/vpn/version.go @@ -1,10 +1,141 @@ package vpn -import "github.com/coder/coder/v2/apiversion" +import ( + "fmt" + "strconv" + "strings" -const ( - CurrentMajor = 1 - CurrentMinor = 0 + "golang.org/x/xerrors" ) -var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor) +// CurrentSupportedVersions is the list of versions supported by this +// implementation of the VPN RPC protocol. +var CurrentSupportedVersions = RPCVersionList{ + Versions: []RPCVersion{ + {Major: 1, Minor: 0}, + }, +} + +// RPCVersion represents a single version of the RPC protocol. Any given version +// is expected to be backwards compatible with all previous minor versions on +// the same major version. +// +// e.g. RPCVersion{2, 3} is backwards compatible with RPCVersion{2, 2} but is +// not backwards compatible with RPCVersion{1, 2}. +type RPCVersion struct { + Major uint64 `json:"major"` + Minor uint64 `json:"minor"` +} + +// ParseRPCVersion parses a version string in the format "major.minor" into a +// RPCVersion. +func ParseRPCVersion(str string) (RPCVersion, error) { + split := strings.Split(str, ".") + if len(split) != 2 { + return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str) + } + major, err := strconv.ParseUint(split[0], 10, 64) + if err != nil { + return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str) + } + if major == 0 { + return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str) + } + minor, err := strconv.ParseUint(split[1], 10, 64) + if err != nil { + return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str) + } + return RPCVersion{Major: major, Minor: minor}, nil +} + +func (v RPCVersion) String() string { + return fmt.Sprintf("%d.%d", v.Major, v.Minor) +} + +// IsCompatibleWith returns the lowest version that is compatible with both +// versions. If the versions are not compatible, the second return value will be +// false. +func (v RPCVersion) IsCompatibleWith(other RPCVersion) (RPCVersion, bool) { + if v.Major != other.Major { + return RPCVersion{}, false + } + // The lowest minor version from the two versions should be returned. + if v.Minor < other.Minor { + return v, true + } + return other, true +} + +// RPCVersionList represents a list of RPC versions supported by a RPC peer. An +type RPCVersionList struct { + Versions []RPCVersion `json:"versions"` +} + +// ParseRPCVersionList parses a version string in the format +// "major.minor,major.minor" into a RPCVersionList. +func ParseRPCVersionList(str string) (RPCVersionList, error) { + split := strings.Split(str, ",") + versions := make([]RPCVersion, len(split)) + for i, v := range split { + version, err := ParseRPCVersion(v) + if err != nil { + return RPCVersionList{}, xerrors.Errorf("invalid version list: %s", str) + } + versions[i] = version + } + vl := RPCVersionList{Versions: versions} + err := vl.Validate() + if err != nil { + return RPCVersionList{}, xerrors.Errorf("invalid parsed version list %q: %w", str, err) + } + return vl, nil +} + +func (vl RPCVersionList) String() string { + versionStrings := make([]string, len(vl.Versions)) + for i, v := range vl.Versions { + versionStrings[i] = v.String() + } + return strings.Join(versionStrings, ",") +} + +// Validate returns an error if the version list is not sorted or contains +// duplicate major versions. +func (vl RPCVersionList) Validate() error { + if len(vl.Versions) == 0 { + return xerrors.New("no versions") + } + for i := 0; i < len(vl.Versions); i++ { + if vl.Versions[i].Major == 0 { + return xerrors.Errorf("invalid version: %s", vl.Versions[i].String()) + } + if i > 0 && vl.Versions[i-1].Major == vl.Versions[i].Major { + return xerrors.Errorf("duplicate major version: %d", vl.Versions[i].Major) + } + if i > 0 && vl.Versions[i-1].Major > vl.Versions[i].Major { + return xerrors.Errorf("versions are not sorted") + } + } + return nil +} + +// IsCompatibleWith returns the lowest version that is compatible with both +// version lists. If the versions are not compatible, the second return value +// will be false. +func (vl RPCVersionList) IsCompatibleWith(other RPCVersionList) (RPCVersion, bool) { + bestVersion := RPCVersion{} + for _, v1 := range vl.Versions { + for _, v2 := range other.Versions { + if v1.Major == v2.Major && v1.Major > bestVersion.Major { + v, ok := v1.IsCompatibleWith(v2) + if ok { + bestVersion = v + } + } + } + } + if bestVersion.Major == 0 { + return bestVersion, false + } + return bestVersion, true +} diff --git a/vpn/version_test.go b/vpn/version_test.go new file mode 100644 index 0000000000000..cff333f10b507 --- /dev/null +++ b/vpn/version_test.go @@ -0,0 +1,262 @@ +package vpn_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/vpn" +) + +func TestRPCVersionLatest(t *testing.T) { + t.Parallel() + require.NoError(t, vpn.CurrentSupportedVersions.Validate()) +} + +func TestRPCVersionParseString(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input string + want vpn.RPCVersion + }{ + { + name: "valid version", + input: "1.0", + want: vpn.RPCVersion{Major: 1, Minor: 0}, + }, + { + name: "valid version with larger numbers", + input: "12.34", + want: vpn.RPCVersion{Major: 12, Minor: 34}, + }, + { + name: "empty string", + input: "", + }, + { + name: "one part", + input: "1", + }, + { + name: "three parts", + input: "1.0.0", + }, + { + name: "major version is 0", + input: "0.1", + }, + { + name: "invalid major version", + input: "a.1", + }, + { + name: "invalid minor version", + input: "1.a", + }, + } + + // nolint:paralleltest + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := vpn.ParseRPCVersion(tc.input) + if tc.want.Major == 0 { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.want, got) + + require.Equal(t, tc.input, got.String()) + } + }) + } +} + +func TestRPCVersionIsCompatibleWith(t *testing.T) { + t.Parallel() + cases := []struct { + name string + v1 vpn.RPCVersion + v2 vpn.RPCVersion + want vpn.RPCVersion + wantBool bool + }{ + { + name: "same version", + v1: vpn.RPCVersion{Major: 1, Minor: 0}, + v2: vpn.RPCVersion{Major: 1, Minor: 0}, + want: vpn.RPCVersion{Major: 1, Minor: 0}, + }, + { + name: "compatible minor versions", + v1: vpn.RPCVersion{Major: 1, Minor: 2}, + v2: vpn.RPCVersion{Major: 1, Minor: 3}, + want: vpn.RPCVersion{Major: 1, Minor: 2}, + }, + { + name: "incompatible major versions", + v1: vpn.RPCVersion{Major: 1, Minor: 0}, + v2: vpn.RPCVersion{Major: 2, Minor: 0}, + want: vpn.RPCVersion{}, + }, + } + + // nolint:paralleltest + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, ok := tc.v1.IsCompatibleWith(tc.v2) + if tc.want.Major == 0 { + require.False(t, ok) + return + } + require.True(t, ok) + require.Equal(t, got, tc.want) + }) + } +} + +func TestRPCVersionListParseString(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + input string + want vpn.RPCVersionList + errContains string + }{ + { + name: "single version", + input: "1.0", + want: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 1, Minor: 0}, + }, + }, + }, + { + name: "multiple versions", + input: "1.1,2.3,3.2", + want: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 1, Minor: 1}, + {Major: 2, Minor: 3}, + {Major: 3, Minor: 2}, + }, + }, + }, + { + name: "invalid version", + input: "1.0,invalid", + errContains: "invalid version list", + }, + { + name: "empty string", + input: "", + errContains: "invalid version list", + }, + { + name: "duplicate versions", + input: "1.0,1.0", + errContains: "duplicate major version", + }, + { + name: "duplicate major versions", + input: "1.0,1.2", + errContains: "duplicate major version", + }, + { + name: "out of order versions", + input: "2.0,1.0", + errContains: "versions are not sorted", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := vpn.ParseRPCVersionList(tc.input) + if tc.errContains != "" { + require.ErrorContains(t, err, tc.errContains) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + require.Equal(t, tc.input, got.String()) + }) + } +} + +func TestRPCVersionListValidate(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + list vpn.RPCVersionList + errContains string + }{ + { + name: "valid list", + list: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 1, Minor: 1}, + {Major: 2, Minor: 3}, + {Major: 3, Minor: 2}, + }, + }, + }, + { + name: "empty list", + list: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{}, + }, + errContains: "no versions", + }, + { + name: "duplicate versions", + list: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 1, Minor: 0}, + {Major: 1, Minor: 0}, + }, + }, + errContains: "duplicate major version", + }, + { + name: "duplicate major versions", + list: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 1, Minor: 0}, + {Major: 1, Minor: 2}, + }, + }, + errContains: "duplicate major version", + }, + { + name: "out of order versions", + list: vpn.RPCVersionList{ + Versions: []vpn.RPCVersion{ + {Major: 2, Minor: 0}, + {Major: 1, Minor: 0}, + }, + }, + errContains: "versions are not sorted", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.list.Validate() + if tc.errContains != "" { + require.ErrorContains(t, err, tc.errContains) + } else { + require.NoError(t, err) + } + }) + } +}