Skip to content

feat: implement provisioner auth middleware and proper org params #12330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ var (
rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
// When org scoped provisioner credentials are implemented,
// this can be reduced to read a specific org.
rbac.ResourceOrganization.Type: {rbac.ActionRead},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
Expand Down
29 changes: 29 additions & 0 deletions coderd/httpmw/actor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler {
})
}
}

// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted
// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated
// middlewares to ensure one of the two authentication methods is provided.
//
// If both are provided, an error is returned to avoid misuse.
func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasAPIKey := APIKeyOptional(r)
hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r)

if hasAPIKey && hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "API key and external provisioner authentication provided, but only one is allowed",
})
return
}
if !hasAPIKey && !hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: "API key or external provisioner authentication required, but none provided",
})
return
}

next.ServeHTTP(w, r)
})
}
}
33 changes: 24 additions & 9 deletions coderd/httpmw/organizationparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,41 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
}

var organization database.Organization
var err error
// Try by name or uuid.
id, err := uuid.Parse(arg)
if err == nil {
organization, err = db.GetOrganizationByID(ctx, id)
var dbErr error

// If the name is exactly "default", then we fetch the default
// organization. This is a special case to make it easier
// for single org deployments.
//
// arg == uuid.Nil.String() should be a temporary workaround for
// legacy provisioners that don't provide an organization ID.
// This prevents a breaking change.
// TODO: This change was added March 2024. Nil uuid returning the
// default org should be removed some number of months after
// that date.
if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() {
organization, dbErr = db.GetDefaultOrganization(ctx)
} else {
organization, err = db.GetOrganizationByName(ctx, arg)
// Try by name or uuid.
id, err := uuid.Parse(arg)
if err == nil {
organization, dbErr = db.GetOrganizationByID(ctx, id)
} else {
organization, dbErr = db.GetOrganizationByName(ctx, arg)
}
}
if httpapi.Is404Error(err) {
if httpapi.Is404Error(dbErr) {
httpapi.ResourceNotFound(rw)
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Organization %q not found.", arg),
Detail: "Provide either the organization id or name.",
})
return
}
if err != nil {
if dbErr != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Internal error fetching organization %q.", arg),
Detail: err.Error(),
Detail: dbErr.Error(),
})
return
}
Expand Down
19 changes: 19 additions & 0 deletions coderd/httpmw/organizationparam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) {
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by name")

// Try by 'default'
chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization)
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword")

// Try by legacy
// TODO: This can be removed when legacy nil uuids are no longer supported.
// This is a temporary measure to ensure as legacy provisioners use
// nil uuids as the org id and expect the default org.
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String())
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)")
})
}
86 changes: 86 additions & 0 deletions coderd/httpmw/provisionerdaemon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package httpmw

import (
"context"
"crypto/subtle"
"net/http"

"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)

type provisionerDaemonContextKey struct{}

func ProvisionerDaemonAuthenticated(r *http.Request) bool {
proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool)
return ok && proxy
}

type ExtractProvisionerAuthConfig struct {
DB database.Store
Optional bool
}

func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

handleOptional := func(code int, response codersdk.Response) {
if opts.Optional {
next.ServeHTTP(w, r)
return
}
httpapi.Write(ctx, w, code, response)
}

if psk == "" {
// No psk means external provisioner daemons are not allowed.
// So their auth is not valid.
handleOptional(http.StatusBadRequest, codersdk.Response{
Message: "External provisioner daemons not enabled",
})
return
}

token := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if token == "" {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token required",
})
return
}

if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token invalid",
})
return
}

// The PSK does not indicate a specific provisioner daemon. So just
// store a boolean so the caller can check if the request is from an
// authenticated provisioner daemon.
ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true)
// nolint:gocritic // Authenticating as a provisioner daemon.
ctx = dbauthz.AsProvisionerd(ctx)
subj, ok := dbauthz.ActorFromContext(ctx)
if !ok {
// This should never happen
httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor"))
}

// Use the same subject for the userAuthKey
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. why do we need this for provisioner daemons?

  2. more generally, why do we store the rbac.Subject under 2 different context keys (authContextKey{} and userAuthKey{})?

  3. what is the difference between an "actor" and a "subject"? Are they interchangeable ideas and we're just inconsistent in how we refer to them in the code?

Copy link
Member Author

@Emyrk Emyrk Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be merged into 1 context value.

It exists because dbauthz needed a context, and it did not feel right to use the httpmw functions for it. So it created it's own.

We should standardize on the dbauthz context, or place one in like /rbac.


  1. I was just being consistent with the other auths. If I omit this, then the HTTPAuthorizer will not work:
    func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
    . This is used in the provisioner handlers.
  2. Legacy reasons when dbauthz used to be an experiment, and not always enabled.
  3. Interchangeable. One affects http handlers, one affects dbauthz.

I just made an issue tracking it here: #12363

Actor: subj,
ActorName: "provisioner_daemon",
})

next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
7 changes: 7 additions & 0 deletions coderd/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
return
}

if req.Name == codersdk.DefaultOrganization {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization),
})
return
}

_, err := api.Database.GetOrganizationByName(ctx, req.Name)
if err == nil {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Expand Down
3 changes: 3 additions & 0 deletions codersdk/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"golang.org/x/xerrors"
)

// DefaultOrganization is used as a replacement for the default organization.
var DefaultOrganization = "default"

type ProvisionerStorageMethod string

const (
Expand Down
11 changes: 8 additions & 3 deletions codersdk/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct {
ID uuid.UUID `json:"id" format:"uuid"`
// Name is the human-readable unique identifier for the daemon.
Name string `json:"name" example:"my-cool-provisioner-daemon"`
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations
// and so the organization ID is optional.
// Organization is the organization for the URL. If no orgID is provided,
// then it is assumed to use the default organization.
Organization uuid.UUID `json:"organization" format:"uuid"`
// Provisioners is a list of provisioner types hosted by the provisioner daemon
Provisioners []ProvisionerType `json:"provisioners"`
Expand All @@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct {
// implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization))
orgParam := req.Organization.String()
if req.Organization == uuid.Nil {
orgParam = DefaultOrganization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get why we don't like uuid.Nil meaning the default on the API, but are fine with it on the SDK.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went back and forth on this. I think we are the only users of our sdk atm, but if I was to require the OrgID argument in the sdk to be set now, that would be a breaking change for the sdk.

I've been trying to maintain that single org deployments do not have to make any changes, and they continue to work as they do today.

On the sdk side, this is an omission of a field being set.

On the API, the nil uuid could be an explicit value being set, rather than an omission.

}

serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
Expand Down
9 changes: 9 additions & 0 deletions enterprise/coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
r.Use(
api.provisionerDaemonsEnabledMW,
apiKeyMiddlewareOptional,
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
DB: api.Database,
Optional: true,
}, api.ProvisionerDaemonPSK),
// Either a user auth or provisioner auth is required
// to move forward.
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
httpmw.ExtractOrganizationParam(api.Database),
)
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)
Expand Down
20 changes: 7 additions & 13 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package coderd

import (
"context"
"crypto/subtle"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
for _, daemon := range daemons {
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
}
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)

httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
}

type provisionerDaemonAuth struct {
Expand Down Expand Up @@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
}

// Check for PSK
if p.psk != "" {
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
if provAuth {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
return nil, false
}
Expand Down
11 changes: 6 additions & 5 deletions enterprise/coderd/provisionerdaemons_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {

t.Run("PSK_daily_cost", func(t *testing.T) {
t.Parallel()
const provPSK = `provisionersftw`
client, user := coderdenttest.New(t, &coderdenttest.Options{
UserWorkspaceQuota: 10,
LicenseOptions: &coderdenttest.LicenseOptions{
Expand All @@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
codersdk.FeatureTemplateRBAC: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
ProvisionerDaemonPSK: provPSK,
})
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
Expand Down Expand Up @@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
Tags: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
},
PreSharedKey: "provisionersftw",
PreSharedKey: provPSK,
})
}, &provisionerd.Options{
Logger: logger.Named("provisionerd"),
Expand Down Expand Up @@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())

daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
Expand Down Expand Up @@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())

daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
Expand Down Expand Up @@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())

daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
Expand Down