Skip to content

feat: add single tailnet support to moons #8587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func New(options *Options) *API {
options.Authorizer,
options.Logger.Named("authz_querier"),
)
experiments := initExperiments(
experiments := ReadExperiments(
options.Logger, options.DeploymentValues.Experiments.Value(),
)
if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil {
Expand Down Expand Up @@ -370,7 +370,9 @@ func New(options *Options) *API {
options.Logger,
options.DERPServer,
options.DERPMap,
&api.TailnetCoordinator,
func(context.Context) (tailnet.MultiAgentConn, error) {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
)
if err != nil {
Expand Down Expand Up @@ -1081,7 +1083,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
}

// nolint:revive
func initExperiments(log slog.Logger, raw []string) codersdk.Experiments {
func ReadExperiments(log slog.Logger, raw []string) codersdk.Experiments {
exps := make([]codersdk.Experiment, 0, len(raw))
for _, v := range raw {
switch v {
Expand Down
1 change: 1 addition & 0 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
TemplateScheduleStore: &templateScheduleStore,
TLSCertificates: options.TLSCertificates,
TrialGenerator: options.TrialGenerator,
TailnetCoordinator: options.Coordinator,
DERPMap: derpMap,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
Expand Down
21 changes: 21 additions & 0 deletions coderd/httpapi/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,24 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) {
}
}
}

// Heartbeat loops to ping a WebSocket to keep it alive. It kills the connection
// on ping failure.
func HeartbeatClose(ctx context.Context, exit func(), conn *websocket.Conn) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
err := conn.Ping(ctx)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "Ping failed")
exit()
return
}
}
}
2 changes: 1 addition & 1 deletion coderd/httpmw/groupparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

groupID, parsed := parseUUID(rw, r, "group")
groupID, parsed := ParseUUIDParam(rw, r, "group")
if !parsed {
return
}
Expand Down
4 changes: 2 additions & 2 deletions coderd/httpmw/httpmw.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"github.com/coder/coder/codersdk"
)

// parseUUID consumes a url parameter and parses it as a UUID.
func parseUUID(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
// ParseUUIDParam consumes a url parameter and parses it as a UUID.
func ParseUUIDParam(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
rawID := chi.URLParam(r, param)
if rawID == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Expand Down
4 changes: 2 additions & 2 deletions coderd/httpmw/httpmw_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestParseUUID_Valid(t *testing.T) {
ctx.URLParams.Add(testParam, testWorkspaceAgentID)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))

parsed, ok := parseUUID(rw, r, "workspaceagent")
parsed, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.True(t, ok, "UUID should be parsed")
assert.Equal(t, testWorkspaceAgentID, parsed.String())
}
Expand All @@ -44,7 +44,7 @@ func TestParseUUID_Invalid(t *testing.T) {
ctx.URLParams.Add(testParam, "wrong-id")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))

_, ok := parseUUID(rw, r, "workspaceagent")
_, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.False(t, ok, "UUID should not be parsed")
assert.Equal(t, http.StatusBadRequest, rw.Code)

Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/organizationparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
orgID, ok := parseUUID(rw, r, "organization")
orgID, ok := ParseUUIDParam(rw, r, "organization")
if !ok {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/templateparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateID, parsed := parseUUID(rw, r, "template")
templateID, parsed := ParseUUIDParam(rw, r, "template")
if !parsed {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/templateversionparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersionID, parsed := parseUUID(rw, r, "templateversion")
templateVersionID, parsed := ParseUUIDParam(rw, r, "templateversion")
if !parsed {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/workspaceagentparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
agentUUID, parsed := parseUUID(rw, r, "workspaceagent")
agentUUID, parsed := ParseUUIDParam(rw, r, "workspaceagent")
if !parsed {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/workspacebuildparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handl
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceBuildID, parsed := parseUUID(rw, r, "workspacebuild")
workspaceBuildID, parsed := ParseUUIDParam(rw, r, "workspacebuild")
if !parsed {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/workspaceparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceID, parsed := parseUUID(rw, r, "workspace")
workspaceID, parsed := ParseUUIDParam(rw, r, "workspace")
if !parsed {
return
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/workspaceresourceparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Ha
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resourceUUID, parsed := parseUUID(rw, r, "workspaceresource")
resourceUUID, parsed := ParseUUIDParam(rw, r, "workspaceresource")
if !parsed {
return
}
Expand Down
92 changes: 53 additions & 39 deletions coderd/tailnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)

var tailnetTransport *http.Transport
Expand All @@ -41,7 +42,7 @@ func NewServerTailnet(
logger slog.Logger,
derpServer *derp.Server,
derpMap *tailcfg.DERPMap,
coord *atomic.Pointer[tailnet.Coordinator],
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
cache *wsconncache.Cache,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
Expand All @@ -56,20 +57,23 @@ func NewServerTailnet(

serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
coord: coord,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
tn.transport.MaxIdleConns = 0
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
agentConn, err := getMultiAgent(ctx)
if err != nil {
return nil, xerrors.Errorf("get initial multi agent: %w", err)
}
tn.agentConn.Store(&agentConn)

err = tn.getAgentConn().UpdateSelf(conn.Node())
Expand All @@ -86,19 +90,21 @@ func NewServerTailnet(
// This is set to allow local DERP traffic to be proxied through memory
// instead of needing to hit the external access URL. Don't use the ctx
// given in this callback, it's only valid while connecting.
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
if derpServer != nil {
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
}

go tn.watchAgentUpdates()
go tn.expireOldAgents()
Expand Down Expand Up @@ -167,30 +173,38 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
}

func (s *ServerTailnet) reinitCoordinator() {
s.nodesMu.Lock()
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
s.agentConn.Store(&agentConn)

// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
s.nodesMu.Lock()
agentConn, err := s.getMultiAgent(s.ctx)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
s.nodesMu.Unlock()
s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
continue
}
s.agentConn.Store(&agentConn)

// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.nodesMu.Unlock()
return
}
s.nodesMu.Unlock()
}

type ServerTailnet struct {
ctx context.Context
cancel func()

logger slog.Logger
conn *tailnet.Conn
coord *atomic.Pointer[tailnet.Coordinator]
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
logger slog.Logger
conn *tailnet.Conn
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agent tailnetNodes the server wants to keep a
// connection to. It contains the last time the agent was connected to.
agentNodes map[uuid.UUID]time.Time
Expand Down
5 changes: 1 addition & 4 deletions coderd/tailnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http/httptest"
"net/netip"
"net/url"
"sync/atomic"
"testing"

"github.com/google/uuid"
Expand Down Expand Up @@ -133,9 +132,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
DERPMap: derpMap,
}

var coordPtr atomic.Pointer[tailnet.Coordinator]
coord := tailnet.NewCoordinator(logger)
coordPtr.Store(&coord)
t.Cleanup(func() {
_ = coord.Close()
})
Expand Down Expand Up @@ -194,7 +191,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
logger,
derpServer,
manifest.DERPMap,
&coordPtr,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
cache,
)
require.NoError(t, err)
Expand Down
7 changes: 2 additions & 5 deletions codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@ import (
"time"

"cloud.google.com/go/compute/metadata"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"

"github.com/coder/retry"

"cdr.dev/slog"

"github.com/google/uuid"

"github.com/coder/coder/codersdk"
"github.com/coder/retry"
)

// New returns a client that is used to interact with the
Expand Down
2 changes: 1 addition & 1 deletion enterprise/cli/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"strings"
"time"

"github.com/google/uuid"
"golang.org/x/xerrors"

"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
"github.com/google/uuid"
)

var jwtRegexp = regexp.MustCompile(`^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`)
Expand Down
2 changes: 2 additions & 0 deletions enterprise/cli/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/coder/coder/cli"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
Expand Down Expand Up @@ -220,6 +221,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd {

proxy, err := wsproxy.New(ctx, &wsproxy.Options{
Logger: logger,
Experiments: coderd.ReadExperiments(logger, cfg.Experiments.Value()),
HTTPClient: httpClient,
DashboardURL: primaryAccessURL.Value(),
AccessURL: cfg.AccessURL.Value(),
Expand Down
Loading