diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authtest.go index 10b96c660f1f0..42da8bbe5c571 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authtest.go @@ -8,13 +8,12 @@ import ( "strings" "testing" - "github.com/coder/coder/coderd" - "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" @@ -33,8 +32,8 @@ type AuthTester struct { t *testing.T api *coderd.API authorizer *recordingAuthorizer - client *codersdk.Client + Client *codersdk.Client Workspace codersdk.Workspace Organization codersdk.Organization Admin codersdk.CreateFirstUserResponse @@ -117,14 +116,11 @@ func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTes }) require.NoError(t, err, "create template param") - // Always fail auth from this point forward - authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) - return &AuthTester{ t: t, api: api, authorizer: authorizer, - client: client, + Client: client, Workspace: workspace, Organization: organization, Admin: admin, @@ -386,6 +382,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { } func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { + // Always fail auth from this point forward + a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil) + for k, v := range assertRoute { noTrailSlash := strings.TrimRight(k, "/") if _, ok := assertRoute[noTrailSlash]; ok && noTrailSlash != k { @@ -450,7 +449,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck route = strings.ReplaceAll(route, "{scope}", string(a.TemplateParam.Scope)) route = strings.ReplaceAll(route, "{id}", a.TemplateParam.ScopeID.String()) - resp, err := a.client.Request(ctx, method, route, nil) + resp, err := a.Client.Request(ctx, method, route, nil) require.NoError(t, err, "do req") body, _ := io.ReadAll(resp.Body) t.Logf("Response Body: %q", string(body)) diff --git a/enterprise/cli/licenses.go b/enterprise/cli/licenses.go index c548b74f31893..72c7c3d67e819 100644 --- a/enterprise/cli/licenses.go +++ b/enterprise/cli/licenses.go @@ -26,6 +26,7 @@ func licenses() *cobra.Command { } cmd.AddCommand( licenseAdd(), + licensesList(), ) return cmd } @@ -112,3 +113,32 @@ func validJWT(s string) error { } return xerrors.New("Invalid license") } + +func licensesList() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List licenses (including expired)", + Aliases: []string{"ls"}, + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + client, err := agpl.CreateClient(cmd) + if err != nil { + return err + } + + licenses, err := client.Licenses(cmd.Context()) + if err != nil { + return err + } + // Ensure that we print "[]" instead of "null" when there are no licenses. + if licenses == nil { + licenses = make([]codersdk.License, 0) + } + + enc := json.NewEncoder(cmd.OutOrStdout()) + enc.SetIndent("", " ") + return enc.Encode(licenses) + }, + } + return cmd +} diff --git a/enterprise/cli/licenses_test.go b/enterprise/cli/licenses_test.go index f56be0b79ec45..79cf1256997f6 100644 --- a/enterprise/cli/licenses_test.go +++ b/enterprise/cli/licenses_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/go-chi/chi/v5" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,7 +29,7 @@ import ( const fakeLicenseJWT = "test.jwt.sig" -func TestLicensesAddSuccess(t *testing.T) { +func TestLicensesAddFake(t *testing.T) { t.Parallel() // We can't check a real license into the git repo, and can't patch out the keys from here, // so instead we have to fake the HTTP interaction. @@ -117,9 +118,9 @@ func TestLicensesAddSuccess(t *testing.T) { }) } -func TestLicensesAddFail(t *testing.T) { +func TestLicensesAddReal(t *testing.T) { t.Parallel() - t.Run("LFlag", func(t *testing.T) { + t.Run("Fails", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) coderdtest.CreateFirstUser(t, client) @@ -141,9 +142,58 @@ func TestLicensesAddFail(t *testing.T) { }) } +func TestLicensesListFake(t *testing.T) { + t.Parallel() + // We can't check a real license into the git repo, and can't patch out the keys from here, + // so instead we have to fake the HTTP interaction. + t.Run("Mainline", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + cmd := setupFakeLicenseServerTest(t, "licenses", "list") + stdout := new(bytes.Buffer) + cmd.SetOut(stdout) + errC := make(chan error) + go func() { + errC <- cmd.ExecuteContext(ctx) + }() + require.NoError(t, <-errC) + var licenses []codersdk.License + err := json.Unmarshal(stdout.Bytes(), &licenses) + require.NoError(t, err) + require.Len(t, licenses, 2) + assert.Equal(t, int32(1), licenses[0].ID) + assert.Equal(t, "claim1", licenses[0].Claims["h1"]) + assert.Equal(t, int32(5), licenses[1].ID) + assert.Equal(t, "claim2", licenses[1].Claims["h2"]) + }) +} + +func TestLicensesListReal(t *testing.T) { + t.Parallel() + t.Run("Empty", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) + coderdtest.CreateFirstUser(t, client) + cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), + "licenses", "list") + stdout := new(bytes.Buffer) + cmd.SetOut(stdout) + clitest.SetupConfig(t, client, root) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + errC := make(chan error) + go func() { + errC <- cmd.ExecuteContext(ctx) + }() + require.NoError(t, <-errC) + assert.Equal(t, "[]\n", stdout.String()) + }) +} + func setupFakeLicenseServerTest(t *testing.T, args ...string) *cobra.Command { t.Helper() - s := httptest.NewServer(&fakeAddLicenseServer{t}) + s := httptest.NewServer(newFakeLicenseAPI(t)) t.Cleanup(s.Close) cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), args...) err := root.URL().Write(s.URL) @@ -160,16 +210,28 @@ func attachPty(t *testing.T, cmd *cobra.Command) *ptytest.PTY { return pty } -type fakeAddLicenseServer struct { +func newFakeLicenseAPI(t *testing.T) http.Handler { + r := chi.NewRouter() + a := &fakeLicenseAPI{t: t, r: r} + r.NotFound(a.notFound) + r.Post("/api/v2/licenses", a.postLicense) + r.Get("/api/v2/licenses", a.licenses) + r.Get("/api/v2/buildinfo", a.noop) + return r +} + +type fakeLicenseAPI struct { t *testing.T + r chi.Router } -func (s *fakeAddLicenseServer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/v2/buildinfo" { - return - } - assert.Equal(s.t, http.MethodPost, r.Method) - assert.Equal(s.t, "/api/v2/licenses", r.URL.Path) +func (s *fakeLicenseAPI) notFound(_ http.ResponseWriter, r *http.Request) { + s.t.Errorf("unexpected HTTP call: %s", r.URL.Path) +} + +func (*fakeLicenseAPI) noop(_ http.ResponseWriter, _ *http.Request) {} + +func (s *fakeLicenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) { var req codersdk.AddLicenseRequest err := json.NewDecoder(r.Body).Decode(&req) require.NoError(s.t, err) @@ -190,3 +252,33 @@ func (s *fakeAddLicenseServer) ServeHTTP(rw http.ResponseWriter, r *http.Request err = json.NewEncoder(rw).Encode(resp) assert.NoError(s.t, err) } + +func (s *fakeLicenseAPI) licenses(rw http.ResponseWriter, _ *http.Request) { + resp := []codersdk.License{ + { + ID: 1, + UploadedAt: time.Now(), + Claims: map[string]interface{}{ + "h1": "claim1", + "features": map[string]int64{ + "f1": 1, + "f2": 2, + }, + }, + }, + { + ID: 5, + UploadedAt: time.Now(), + Claims: map[string]interface{}{ + "h2": "claim2", + "features": map[string]int64{ + "f3": 3, + "f4": 4, + }, + }, + }, + } + rw.WriteHeader(http.StatusOK) + err := json.NewEncoder(rw).Encode(resp) + assert.NoError(s.t, err) +} diff --git a/enterprise/coderd/auth_internal_test.go b/enterprise/coderd/auth_internal_test.go new file mode 100644 index 0000000000000..04f2a71d5fc86 --- /dev/null +++ b/enterprise/coderd/auth_internal_test.go @@ -0,0 +1,74 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. +// these tests patch the map of license keys, so cannot be run in parallel +// nolint:paralleltest +func TestAuthorizeAllEndpoints(t *testing.T) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + oldKeys := keys + defer func() { + t.Log("restoring keys") + keys = oldKeys + }() + keys = map[string]ed25519.PublicKey{keyID: pubKey} + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: NewEnterprise}) + + // We need a license in the DB, so that when we call GET api/v2/licenses there is one in the + // list to check authz on. + claims := &Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@coder.test", + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), + }, + LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)), + AccountType: AccountTypeSalesforce, + AccountID: "testing", + Version: CurrentVersion, + Features: Features{ + UserLimit: 0, + AuditLog: 1, + }, + } + lic, err := makeLicense(claims, privKey, keyID) + require.NoError(t, err) + _, err = a.Client.AddLicense(ctx, codersdk.AddLicenseRequest{ + License: lic, + }) + require.NoError(t, err) + + skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) + assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionCreate, + AssertObject: rbac.ResourceLicense, + } + assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ + StatusCode: http.StatusOK, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceLicense, + } + a.Test(ctx, assertRoute, skipRoutes) +} diff --git a/enterprise/coderd/auth_test.go b/enterprise/coderd/auth_test.go deleted file mode 100644 index b390db7cf09f5..0000000000000 --- a/enterprise/coderd/auth_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package coderd_test - -import ( - "context" - "net/http" - "testing" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/enterprise/coderd" - "github.com/coder/coder/testutil" -) - -// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered. -func TestAuthorizeAllEndpoints(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise}) - skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) - assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{ - AssertAction: rbac.ActionCreate, - AssertObject: rbac.ResourceLicense, - } - // TODO: fix this test so that there are licenses to get. - assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{ - StatusCode: http.StatusOK, - NoAuthorize: true, - } - a.Test(ctx, assertRoute, skipRoutes) -} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index e9cdbd0cde62a..420d1c034b7cf 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -12,11 +12,12 @@ import ( "strings" "time" - "cdr.dev/slog" "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v4" "golang.org/x/xerrors" + "cdr.dev/slog" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" @@ -253,7 +254,7 @@ func decodeClaims(l database.License) (jwt.MapClaims, error) { if len(parts) != 3 { return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID) } - cb, err := base64.URLEncoding.DecodeString(parts[1]) + cb, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err) }