Skip to content

fix(enterprise/coderd): check provisionerd API version on connection #12191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 16, 2024
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
)

func TestAPIVersionValidate(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion coderd/healthcheck/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (

"golang.org/x/mod/semver"

"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk"
Expand Down
3 changes: 3 additions & 0 deletions codersdk/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner"
"github.com/coder/coder/v2/provisionersdk"
)

type LogSource string
Expand Down Expand Up @@ -201,6 +202,8 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
query := serverURL.Query()
query.Add("id", req.ID.String())
query.Add("name", req.Name)
query.Add("version", provisionersdk.VersionCurrent.String())

for _, provisioner := range req.Provisioners {
query.Add("provisioner", string(provisioner))
}
Expand Down
2 changes: 1 addition & 1 deletion docs/templates/resource-ordering.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The resource with the lower `order` is presented before the one with greater
value. A missing `order` property defaults to 0. If two resources have the same
`order` property, the resources will be ordered by property `name` (or `key`).

## Using `order` property
## Using "order" property

### Coder parameters

Expand Down
10 changes: 10 additions & 0 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
apiVersion = qv
}

if err := provisionersdk.VersionCurrent.Validate(apiVersion); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Incompatible or unparsable version",
Validations: []codersdk.ValidationError{
{Field: "version", Detail: err.Error()},
},
})
return
}

// Create the daemon in the database.
now := dbtime.Now()
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
Expand Down
105 changes: 105 additions & 0 deletions enterprise/coderd/provisionerdaemons_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package coderd_test
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"testing"

Expand All @@ -12,6 +14,7 @@ import (

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
Expand Down Expand Up @@ -63,6 +66,108 @@ func TestProvisionerDaemonServe(t *testing.T) {
}
})

t.Run("NoVersion", func(t *testing.T) {
t.Parallel()
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
// endpoint. We do not pass the required machinery to start a websocket connection, so we expect a
// WebSocket protocol violation. This just means the pre-flight checks have passed though.

// Sending a HTTP request triggers an error log, which would otherwise fail the test.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
client, user := coderdenttest.New(t, &coderdenttest.Options{
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
Options: &coderdtest.Options{
Logger: &logger,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

// Formulate the correct URL for provisionerd server.
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
require.NoError(t, err)
q := srvURL.Query()
// Set required query parameters.
q.Add("provisioner", "echo")
// Note: Explicitly not setting API version.
q.Add("version", "")
srvURL.RawQuery = q.Encode()

// Set PSK header for auth.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
require.NoError(t, err)
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")

// Do the request!
resp, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// The below means that provisionerd tried to serve us!
require.Contains(t, string(b), "Internal error accepting websocket connection.")

daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
if assert.Len(t, daemons, 1) {
assert.Equal(t, "1.0", daemons[0].APIVersion) // The whole point of this test is here.
}
})

t.Run("OldVersion", func(t *testing.T) {
t.Parallel()
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
// endpoint. We do not pass the required machinery to start a websocket connection, but we pass a
// version header that should cause provisionerd to refuse to serve us, so no websocket for you!

// Sending a HTTP request triggers an error log, which would otherwise fail the test.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
client, user := coderdenttest.New(t, &coderdenttest.Options{
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
Options: &coderdtest.Options{
Logger: &logger,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

// Formulate the correct URL for provisionerd server.
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
require.NoError(t, err)
q := srvURL.Query()
// Set required query parameters.
q.Add("provisioner", "echo")

// Set a different (newer) version than the current.
v := apiversion.New(provisionersdk.CurrentMajor+1, provisionersdk.CurrentMinor+1)
q.Add("version", v.String())
srvURL.RawQuery = q.Encode()

// Set PSK header for auth.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
require.NoError(t, err)
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")

// Do the request!
resp, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// The below means that provisionerd tried to serve us, checked our api version, and said nope.
require.Contains(t, string(b), "server is at version 1.0, behind requested major version 2.1")
})

t.Run("NoLicense", func(t *testing.T) {
t.Parallel()
client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
Expand Down
2 changes: 1 addition & 1 deletion enterprise/coderd/workspaceproxycoordinate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"github.com/google/uuid"
"nhooyr.io/websocket"

"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet/proto"
)
Expand Down
2 changes: 1 addition & 1 deletion enterprise/tailnet/workspaceproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
)
Expand Down
2 changes: 1 addition & 1 deletion provisionersdk/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (

"cdr.dev/slog"

"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/provisionersdk/proto"
)

Expand Down
2 changes: 1 addition & 1 deletion tailnet/proto/version.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package proto

import (
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion tailnet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/tailnet/proto"

"golang.org/x/xerrors"
Expand Down