Skip to content

Commit 970b717

Browse files
committed
add required authz system inovcations in enterprise/coderd
1 parent 2a0746d commit 970b717

File tree

4 files changed

+39
-12
lines changed

4 files changed

+39
-12
lines changed

enterprise/coderd/coderd.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"cdr.dev/slog"
1919
"github.com/coder/coder/coderd"
2020
agplaudit "github.com/coder/coder/coderd/audit"
21+
"github.com/coder/coder/coderd/database/dbauthz"
2122
"github.com/coder/coder/coderd/httpapi"
2223
"github.com/coder/coder/coderd/httpmw"
2324
"github.com/coder/coder/coderd/rbac"
@@ -181,7 +182,8 @@ func New(ctx context.Context, options *Options) (*API, error) {
181182
ServerName: options.AccessURL.Hostname(),
182183
}
183184
var err error
184-
api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
185+
// nolint:gocritic // ReplicaManager needs system permissions.
186+
api.replicaManager, err = replicasync.New(dbauthz.AsSystemRestricted(ctx), options.Logger, options.Database, options.Pubsub, &replicasync.Options{
185187
ID: api.AGPL.ID,
186188
RelayAddress: options.DERPServerRelayAddress,
187189
RegionID: int32(options.DERPServerRegionID),

enterprise/coderd/license/license.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"cdr.dev/slog"
1313

1414
"github.com/coder/coder/coderd/database"
15+
"github.com/coder/coder/coderd/database/dbauthz"
1516
"github.com/coder/coder/codersdk"
1617
)
1718

@@ -39,12 +40,14 @@ func Entitlements(
3940
}
4041
}
4142

42-
licenses, err := db.GetUnexpiredLicenses(ctx)
43+
// nolint:gocritic // Getting unexpired licenses is a system function.
44+
licenses, err := db.GetUnexpiredLicenses(dbauthz.AsSystemRestricted(ctx))
4345
if err != nil {
4446
return entitlements, err
4547
}
4648

47-
activeUserCount, err := db.GetActiveUserCount(ctx)
49+
// nolint:gocritic // Getting active user count is a system function.
50+
activeUserCount, err := db.GetActiveUserCount(dbauthz.AsSystemRestricted(ctx))
4851
if err != nil {
4952
return entitlements, xerrors.Errorf("query active user count: %w", err)
5053
}

enterprise/coderd/provisionerdaemons_test.go

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/coder/coder/coderd/coderdtest"
1313
"github.com/coder/coder/coderd/provisionerdserver"
14+
"github.com/coder/coder/coderd/rbac"
1415
"github.com/coder/coder/codersdk"
1516
"github.com/coder/coder/enterprise/coderd/coderdenttest"
1617
"github.com/coder/coder/enterprise/coderd/license"
@@ -20,6 +21,22 @@ import (
2021

2122
func TestProvisionerDaemonServe(t *testing.T) {
2223
t.Parallel()
24+
t.Run("OK", func(t *testing.T) {
25+
t.Parallel()
26+
client := coderdenttest.New(t, nil)
27+
user := coderdtest.CreateFirstUser(t, client)
28+
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
29+
Features: license.Features{
30+
codersdk.FeatureExternalProvisionerDaemons: 1,
31+
},
32+
})
33+
srv, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
34+
codersdk.ProvisionerTypeEcho,
35+
}, map[string]string{})
36+
require.NoError(t, err)
37+
srv.DRPCConn().Close()
38+
})
39+
2340
t.Run("NoLicense", func(t *testing.T) {
2441
t.Parallel()
2542
client := coderdenttest.New(t, nil)
@@ -42,11 +59,16 @@ func TestProvisionerDaemonServe(t *testing.T) {
4259
codersdk.FeatureExternalProvisionerDaemons: 1,
4360
},
4461
})
45-
srv, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
62+
another, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID, rbac.RoleOrgAdmin(user.OrganizationID))
63+
_, err := another.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
4664
codersdk.ProvisionerTypeEcho,
47-
}, map[string]string{})
48-
require.NoError(t, err)
49-
srv.DRPCConn().Close()
65+
}, map[string]string{
66+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
67+
})
68+
require.Error(t, err)
69+
var apiError *codersdk.Error
70+
require.ErrorAs(t, err, &apiError)
71+
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
5072
})
5173

5274
t.Run("OrganizationNoPerms", func(t *testing.T) {

enterprise/replicasync/replicasync.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919

2020
"github.com/coder/coder/buildinfo"
2121
"github.com/coder/coder/coderd/database"
22-
"github.com/coder/coder/coderd/database/dbauthz"
2322
)
2423

2524
var PubsubEvent = "replica"
@@ -63,7 +62,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data
6362
return nil, xerrors.Errorf("ping database: %w", err)
6463
}
6564
// nolint:gocritic // Inserting a replica is a system function.
66-
replica, err := db.InsertReplica(dbauthz.AsSystemRestricted(ctx), database.InsertReplicaParams{
65+
replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{
6766
ID: options.ID,
6867
CreatedAt: database.Now(),
6968
StartedAt: database.Now(),
@@ -144,7 +143,7 @@ func (m *Manager) loop(ctx context.Context) {
144143
return
145144
case <-deleteTicker.C:
146145
// nolint:gocritic // Deleting a replica is a system function
147-
err := m.db.DeleteReplicasUpdatedBefore(dbauthz.AsSystemRestricted(ctx), m.updateInterval())
146+
err := m.db.DeleteReplicasUpdatedBefore(ctx, m.updateInterval())
148147
if err != nil {
149148
m.logger.Warn(ctx, "delete old replicas", slog.Error(err))
150149
}
@@ -222,7 +221,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
222221
// Expect replicas to update once every three times the interval...
223222
// If they don't, assume death!
224223
// nolint:gocritic // Reading replicas is a system function
225-
replicas, err := m.db.GetReplicasUpdatedAfter(dbauthz.AsSystemRestricted(ctx), m.updateInterval())
224+
replicas, err := m.db.GetReplicasUpdatedAfter(ctx, m.updateInterval())
226225
if err != nil {
227226
return xerrors.Errorf("get replicas: %w", err)
228227
}
@@ -280,6 +279,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
280279

281280
m.mutex.Lock()
282281
defer m.mutex.Unlock()
282+
// nolint:gocritic // Updating a replica is a system function.
283283
replica, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
284284
ID: m.self.ID,
285285
UpdatedAt: database.Now(),
@@ -371,7 +371,7 @@ func (m *Manager) Close() error {
371371
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
372372
defer cancelFunc()
373373
// nolint:gocritic // Updating a replica is a sytsem function.
374-
_, err := m.db.UpdateReplica(dbauthz.AsSystemRestricted(ctx), database.UpdateReplicaParams{
374+
_, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
375375
ID: m.self.ID,
376376
UpdatedAt: database.Now(),
377377
StartedAt: m.self.StartedAt,

0 commit comments

Comments
 (0)