Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,20 @@ func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) {
return q.db.GetActiveUserCount(ctx)
}

func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAgent, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return []database.TailnetAgent{}, err
}
return q.db.GetAllTailnetAgents(ctx)
}

func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return []database.TailnetClient{}, err
}
return q.db.GetAllTailnetClients(ctx)
}

func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetAppSecurityKey(ctx)
Expand Down
8 changes: 8 additions & 0 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,14 @@ func (q *FakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
return active, nil
}

func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAgent, error) {
return nil, ErrUnimplemented
}

func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) {
return nil, ErrUnimplemented
}

func (q *FakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down
14 changes: 14 additions & 0 deletions coderd/database/dbmetrics/dbmetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 68 additions & 0 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions coderd/database/queries/tailnet.sql
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,20 @@ SELECT *
FROM tailnet_agents
WHERE id = $1;

-- name: GetAllTailnetAgents :many
SELECT *
FROM tailnet_agents;

-- name: GetTailnetClientsForAgent :many
SELECT *
FROM tailnet_clients
WHERE agent_id = $1;

-- name: GetAllTailnetClients :many
SELECT *
FROM tailnet_clients
ORDER BY agent_id;

-- name: UpsertTailnetCoordinator :one
INSERT INTO
tailnet_coordinators (
Expand Down
4 changes: 3 additions & 1 deletion enterprise/tailnet/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,5 +704,7 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.mutex.RLock()
defer c.mutex.RUnlock()

agpl.CoordinatorHTTPDebug(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache)(w, r)
agpl.CoordinatorHTTPDebug(
agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
)(w, r)
}
125 changes: 116 additions & 9 deletions enterprise/tailnet/pgcoord.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"nhooyr.io/websocket"

Expand Down Expand Up @@ -307,6 +308,9 @@ type binding struct {
node *agpl.Node
}

func (b *binding) isAgent() bool { return b.client == uuid.Nil }
func (b *binding) isClient() bool { return b.client != uuid.Nil }

// binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff.
type binder struct {
ctx context.Context
Expand Down Expand Up @@ -386,19 +390,19 @@ func (b *binder) writeOne(bnd binding) error {
if err != nil {
// this is very bad news, but it should never happen because the node was Unmarshalled by this process
// earlier.
b.logger.Error(b.ctx, "failed to marshall node", slog.Error(err))
b.logger.Error(b.ctx, "failed to marshal node", slog.Error(err))
return err
}
}

switch {
case bnd.client == uuid.Nil && len(nodeRaw) > 0:
case bnd.isAgent() && len(nodeRaw) > 0:
_, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{
ID: bnd.agent,
CoordinatorID: b.coordinatorID,
Node: nodeRaw,
})
case bnd.client == uuid.Nil && len(nodeRaw) == 0:
case bnd.isAgent() && len(nodeRaw) == 0:
_, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{
ID: bnd.agent,
CoordinatorID: b.coordinatorID,
Expand All @@ -407,14 +411,14 @@ func (b *binder) writeOne(bnd binding) error {
// treat deletes as idempotent
err = nil
}
case bnd.client != uuid.Nil && len(nodeRaw) > 0:
case bnd.isClient() && len(nodeRaw) > 0:
_, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{
ID: bnd.client,
CoordinatorID: b.coordinatorID,
AgentID: bnd.agent,
Node: nodeRaw,
})
case bnd.client != uuid.Nil && len(nodeRaw) == 0:
case bnd.isClient() && len(nodeRaw) == 0:
_, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{
ID: bnd.client,
CoordinatorID: b.coordinatorID,
Expand Down Expand Up @@ -927,6 +931,27 @@ func (q *querier) updateAll() {
}
}

func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAgent, map[uuid.UUID][]database.TailnetClient, error) {
agents, err := q.store.GetAllTailnetAgents(ctx)
if err != nil {
return nil, nil, xerrors.Errorf("get all tailnet agents: %w", err)
}
agentsMap := map[uuid.UUID]database.TailnetAgent{}
for _, agent := range agents {
agentsMap[agent.ID] = agent
}
clients, err := q.store.GetAllTailnetClients(ctx)
if err != nil {
return nil, nil, xerrors.Errorf("get all tailnet clients: %w", err)
}
clientsMap := map[uuid.UUID][]database.TailnetClient{}
for _, client := range clients {
clientsMap[client.AgentID] = append(clientsMap[client.AgentID], client)
}

return agentsMap, clientsMap, nil
}

func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) {
parts := strings.Split(msg, ",")
if len(parts) != 2 {
Expand Down Expand Up @@ -1289,8 +1314,90 @@ func (h *heartbeats) cleanup() {
h.logger.Debug(h.ctx, "cleaned up old coordinators")
}

func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) {
// TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed.
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Coder Enterprise PostgreSQL distributed tailnet coordinator"))
func (c *pgCoord) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
debug, err := c.htmlDebug(ctx)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(err.Error()))
return
}

agpl.CoordinatorHTTPDebug(debug)(w, r)
}

func (c *pgCoord) htmlDebug(ctx context.Context) (agpl.HTMLDebug, error) {
now := time.Now()
data := agpl.HTMLDebug{}
agents, clients, err := c.querier.getAll(ctx)
if err != nil {
return data, xerrors.Errorf("get all agents and clients: %w", err)
}

for _, agent := range agents {
htmlAgent := &agpl.HTMLAgent{
ID: agent.ID,
// Name: ??, TODO: get agent names
LastWriteAge: now.Sub(agent.UpdatedAt).Round(time.Second),
}
for _, conn := range clients[agent.ID] {
htmlAgent.Connections = append(htmlAgent.Connections, &agpl.HTMLClient{
ID: conn.ID,
Name: conn.ID.String(),
LastWriteAge: now.Sub(conn.UpdatedAt).Round(time.Second),
})
data.Nodes = append(data.Nodes, &agpl.HTMLNode{
ID: conn.ID,
Node: conn.Node,
})
}
slices.SortFunc(htmlAgent.Connections, func(a, b *agpl.HTMLClient) bool {
return a.Name < b.Name
})

data.Agents = append(data.Agents, htmlAgent)
data.Nodes = append(data.Nodes, &agpl.HTMLNode{
ID: agent.ID,
// Name: ??, TODO: get agent names
Node: agent.Node,
})
}
slices.SortFunc(data.Agents, func(a, b *agpl.HTMLAgent) bool {
return a.Name < b.Name
})

for agentID, conns := range clients {
if len(conns) == 0 {
continue
}

if _, ok := agents[agentID]; ok {
continue
}
agent := &agpl.HTMLAgent{
Name: "unknown",
ID: agentID,
}
for _, conn := range conns {
agent.Connections = append(agent.Connections, &agpl.HTMLClient{
Name: conn.ID.String(),
ID: conn.ID,
LastWriteAge: now.Sub(conn.UpdatedAt).Round(time.Second),
})
data.Nodes = append(data.Nodes, &agpl.HTMLNode{
ID: conn.ID,
Node: conn.Node,
})
}
slices.SortFunc(agent.Connections, func(a, b *agpl.HTMLClient) bool {
return a.Name < b.Name
})

data.MissingAgents = append(data.MissingAgents, agent)
}
slices.SortFunc(data.MissingAgents, func(a, b *agpl.HTMLAgent) bool {
return a.Name < b.Name
})

return data, nil
}
Loading