Skip to content

Commit db0ba85

Browse files
authored
chore: Refactor Enterprise code to layer on top of AGPL (coder#4034)
* chore: Refactor Enterprise code to layer on top of AGPL This is an experiment to invert the import order of the Enterprise code to layer on top of AGPL. * Fix Garrett's comments * Add pointer.Handle to atomically obtain references This uses a context to ensure the same value persists through multiple executions to `Load()`. * Remove entitlements API from AGPL coderd * Remove AGPL Coder entitlements endpoint test * Fix warnings output * Add command-line flag to toggle audit logging * Fix hasLicense being set * Remove features interface * Fix audit logging default * Add bash as a dependency * Add comment * Add tests for resync and pubsub, and add back previous exp backoff retry * Separate authz code again * Add pointer loading example from comment * Fix duplicate test, remove pointer.Handle * Fix expired license * Add entitlements struct * Fix context passing
1 parent 714c366 commit db0ba85

39 files changed

+1341
-2008
lines changed

cli/gitssh_test.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ import (
2020
"github.com/stretchr/testify/require"
2121
gossh "golang.org/x/crypto/ssh"
2222

23-
"cdr.dev/slog"
24-
2523
"github.com/coder/coder/cli/clitest"
2624
"github.com/coder/coder/coderd/coderdtest"
2725
"github.com/coder/coder/codersdk"
@@ -83,18 +81,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
8381
errC <- cmd.ExecuteContext(ctx)
8482
}()
8583
t.Cleanup(func() { require.NoError(t, <-errC) })
86-
8784
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
88-
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
89-
require.NoError(t, err)
90-
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
91-
require.NoError(t, err)
92-
defer dialer.Close()
93-
require.Eventually(t, func() bool {
94-
_, err = dialer.Ping()
95-
return err == nil
96-
}, testutil.WaitMedium, testutil.IntervalFast)
97-
9885
return agentClient, agentToken, pubkey
9986
}
10087

cli/root.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ func Core() []*cobra.Command {
9191
users(),
9292
versionCmd(),
9393
workspaceAgent(),
94-
features(),
9594
}
9695
}
9796

9897
func AGPL() []*cobra.Command {
99-
all := append(Core(), Server(coderd.New))
98+
all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, error) {
99+
return coderd.New(o), nil
100+
}))
100101
return all
101102
}
102103

@@ -548,13 +549,11 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
548549
defer cancel()
549550

550551
entitlements, err := client.Entitlements(ctx)
551-
if err != nil {
552-
return xerrors.Errorf("get entitlements to show warnings: %w", err)
553-
}
554-
for _, w := range entitlements.Warnings {
555-
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
552+
if err == nil {
553+
for _, w := range entitlements.Warnings {
554+
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
555+
}
556556
}
557-
558557
return nil
559558
}
560559

cli/server.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ import (
6868
)
6969

7070
// nolint:gocyclo
71-
func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
71+
func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
7272
var (
7373
accessURL string
7474
address string
@@ -489,7 +489,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
489489
), promAddress, "prometheus")()
490490
}
491491

492-
coderAPI := newAPI(options)
492+
coderAPI, err := newAPI(ctx, options)
493+
if err != nil {
494+
return err
495+
}
493496
defer coderAPI.Close()
494497

495498
client := codersdk.New(localURL)
@@ -536,7 +539,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
536539
// These errors are typically noise like "TLS: EOF". Vault does similar:
537540
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
538541
ErrorLog: log.New(io.Discard, "", 0),
539-
Handler: coderAPI.Handler,
542+
Handler: coderAPI.RootHandler,
540543
BaseContext: func(_ net.Listener) context.Context {
541544
return shutdownConnsCtx
542545
},

coderd/audit/request.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ import (
1212

1313
"cdr.dev/slog"
1414
"github.com/coder/coder/coderd/database"
15-
"github.com/coder/coder/coderd/features"
1615
"github.com/coder/coder/coderd/httpmw"
1716
"github.com/coder/coder/coderd/tracing"
1817
)
1918

2019
type RequestParams struct {
21-
Features features.Service
22-
Log slog.Logger
20+
Audit Auditor
21+
Log slog.Logger
2322

2423
Request *http.Request
2524
Action database.AuditAction
@@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
102101
params: p,
103102
}
104103

105-
feats := struct {
106-
Audit Auditor
107-
}{}
108-
err := p.Features.Get(&feats)
109-
if err != nil {
110-
p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err))
111-
return req, func() {}
112-
}
113-
114104
return req, func() {
115105
ctx := context.Background()
116106
logCtx := p.Request.Context()
@@ -120,15 +110,15 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
120110
return
121111
}
122112

123-
diff := Diff(feats.Audit, req.Old, req.New)
113+
diff := Diff(p.Audit, req.Old, req.New)
124114
diffRaw, _ := json.Marshal(diff)
125115

126116
ip, err := parseIP(p.Request.RemoteAddr)
127117
if err != nil {
128118
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
129119
}
130120

131-
err = feats.Audit.Export(ctx, database.AuditLog{
121+
err = p.Audit.Export(ctx, database.AuditLog{
132122
ID: uuid.New(),
133123
Time: database.Now(),
134124
UserID: httpmw.APIKey(p.Request).UserID,

coderd/authorize.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ type HTTPAuthorizer struct {
4343
// return
4444
// }
4545
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
46-
return api.httpAuth.Authorize(r, action, object)
46+
return api.HTTPAuth.Authorize(r, action, object)
4747
}
4848

4949
// Authorize will return false if the user is not authorized to do the action.

coderd/coderd.go

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/url"
88
"path/filepath"
99
"sync"
10+
"sync/atomic"
1011
"time"
1112

1213
"github.com/andybalholm/brotli"
@@ -24,9 +25,9 @@ import (
2425

2526
"cdr.dev/slog"
2627
"github.com/coder/coder/buildinfo"
28+
"github.com/coder/coder/coderd/audit"
2729
"github.com/coder/coder/coderd/awsidentity"
2830
"github.com/coder/coder/coderd/database"
29-
"github.com/coder/coder/coderd/features"
3031
"github.com/coder/coder/coderd/gitsshkey"
3132
"github.com/coder/coder/coderd/httpapi"
3233
"github.com/coder/coder/coderd/httpmw"
@@ -50,6 +51,7 @@ type Options struct {
5051
// CacheDir is used for caching files served by the API.
5152
CacheDir string
5253

54+
Auditor audit.Auditor
5355
AgentConnectionUpdateFrequency time.Duration
5456
AgentInactiveDisconnectTimeout time.Duration
5557
// APIRateLimit is the minutely throughput rate limit per user or ip.
@@ -68,8 +70,6 @@ type Options struct {
6870
Telemetry telemetry.Reporter
6971
TracerProvider trace.TracerProvider
7072
AutoImportTemplates []AutoImportTemplate
71-
LicenseHandler http.Handler
72-
FeaturesService features.Service
7373

7474
TailnetCoordinator *tailnet.Coordinator
7575
DERPMap *tailcfg.DERPMap
@@ -80,6 +80,9 @@ type Options struct {
8080

8181
// New constructs a Coder API handler.
8282
func New(options *Options) *API {
83+
if options == nil {
84+
options = &Options{}
85+
}
8386
if options.AgentConnectionUpdateFrequency == 0 {
8487
options.AgentConnectionUpdateFrequency = 3 * time.Second
8588
}
@@ -117,11 +120,8 @@ func New(options *Options) *API {
117120
if options.TailnetCoordinator == nil {
118121
options.TailnetCoordinator = tailnet.NewCoordinator()
119122
}
120-
if options.LicenseHandler == nil {
121-
options.LicenseHandler = licenses()
122-
}
123-
if options.FeaturesService == nil {
124-
options.FeaturesService = &featuresService{}
123+
if options.Auditor == nil {
124+
options.Auditor = audit.NewNop()
125125
}
126126

127127
siteCacheDir := options.CacheDir
@@ -142,14 +142,16 @@ func New(options *Options) *API {
142142
r := chi.NewRouter()
143143
api := &API{
144144
Options: options,
145-
Handler: r,
145+
RootHandler: r,
146146
siteHandler: site.Handler(site.FS(), binFS),
147-
httpAuth: &HTTPAuthorizer{
147+
HTTPAuth: &HTTPAuthorizer{
148148
Authorizer: options.Authorizer,
149149
Logger: options.Logger,
150150
},
151151
metricsCache: metricsCache,
152+
Auditor: atomic.Pointer[audit.Auditor]{},
152153
}
154+
api.Auditor.Store(&options.Auditor)
153155
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
154156
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
155157
oauthConfigs := &httpmw.OAuth2Configs{
@@ -218,6 +220,8 @@ func New(options *Options) *API {
218220
})
219221

220222
r.Route("/api/v2", func(r chi.Router) {
223+
api.APIHandler = r
224+
221225
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
222226
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
223227
Message: "Route not found.",
@@ -473,14 +477,6 @@ func New(options *Options) *API {
473477
r.Get("/resources", api.workspaceBuildResources)
474478
r.Get("/state", api.workspaceBuildState)
475479
})
476-
r.Route("/entitlements", func(r chi.Router) {
477-
r.Use(apiKeyMiddleware)
478-
r.Get("/", api.FeaturesService.EntitlementsAPI)
479-
})
480-
r.Route("/licenses", func(r chi.Router) {
481-
r.Use(apiKeyMiddleware)
482-
r.Mount("/", options.LicenseHandler)
483-
})
484480
})
485481

486482
r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP)
@@ -489,17 +485,20 @@ func New(options *Options) *API {
489485

490486
type API struct {
491487
*Options
488+
Auditor atomic.Pointer[audit.Auditor]
489+
HTTPAuth *HTTPAuthorizer
492490

493-
derpServer *derp.Server
491+
// APIHandler serves "/api/v2"
492+
APIHandler chi.Router
493+
// RootHandler serves "/"
494+
RootHandler chi.Router
494495

495-
Handler chi.Router
496+
derpServer *derp.Server
497+
metricsCache *metricscache.Cache
496498
siteHandler http.Handler
497499
websocketWaitMutex sync.Mutex
498500
websocketWaitGroup sync.WaitGroup
499501
workspaceAgentCache *wsconncache.Cache
500-
httpAuth *HTTPAuthorizer
501-
502-
metricsCache *metricscache.Cache
503502
}
504503

505504
// Close waits for all WebSocket connections to drain before returning.

coderd/coderd_test.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) {
3838
require.Equal(t, buildinfo.Version(), buildInfo.Version, "version")
3939
}
4040

41-
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
42-
func TestAuthorizeAllEndpoints(t *testing.T) {
43-
t.Parallel()
44-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
45-
defer cancel()
46-
a := coderdtest.NewAuthTester(ctx, t, nil)
47-
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
48-
a.Test(ctx, assertRoute, skipRoutes)
49-
}
50-
5141
func TestDERP(t *testing.T) {
5242
t.Parallel()
5343
client := coderdtest.New(t, nil)

0 commit comments

Comments
 (0)