Skip to content

feat: add flag to disable all direct connections instance-wide #7936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 21, 2023
Prev Previous commit
Next Next commit
z
  • Loading branch information
deansheather committed Jun 10, 2023
commit 76e2200af3d8165260e77f1a8e82f0da0def0954
78 changes: 42 additions & 36 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
}

stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
stunAddr.IP = net.ParseIP("127.0.0.1")
t.Cleanup(stunCleanup)

derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp").Leveled(slog.LevelDebug)))
Expand All @@ -306,6 +307,29 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
require.NoError(t, err)
}

region := &tailcfg.DERPRegion{
EmbeddedRelay: true,
RegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()),
RegionCode: options.DeploymentValues.DERP.Server.RegionCode.String(),
RegionName: options.DeploymentValues.DERP.Server.RegionName.String(),
Nodes: []*tailcfg.DERPNode{{
Name: fmt.Sprintf("%db", options.DeploymentValues.DERP.Server.RegionID),
RegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()),
IPv4: "127.0.0.1",
DERPPort: derpPort,
// STUN port is added as a separate node by tailnet.NewDERPMap() if
// direct connections are enabled.
STUNPort: -1,
InsecureForTests: true,
ForceHTTP: options.TLSCertificates == nil,
}},
}
if !options.DeploymentValues.DERP.Server.Enable.Value() {
region = nil
}
derpMap, err := tailnet.NewDERPMap(ctx, region, []string{stunAddr.String()}, "", "", !options.DeploymentValues.DERP.Config.DisableDirect.Value())
require.NoError(t, err)

return func(h http.Handler) {
mutex.Lock()
defer mutex.Unlock()
Expand All @@ -324,42 +348,24 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
Pubsub: options.Pubsub,
GitAuthConfigs: options.GitAuthConfigs,

Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
RealIPConfig: options.RealIPConfig,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
DERPServer: derpServer,
APIRateLimit: options.APIRateLimit,
LoginRateLimit: options.LoginRateLimit,
FilesRateLimit: options.FilesRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
TemplateScheduleStore: &templateScheduleStore,
TLSCertificates: options.TLSCertificates,
TrialGenerator: options.TrialGenerator,
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
EmbeddedRelay: true,
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: stunAddr.Port,
InsecureForTests: true,
ForceHTTP: options.TLSCertificates == nil,
}},
},
},
},
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
RealIPConfig: options.RealIPConfig,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
DERPServer: derpServer,
APIRateLimit: options.APIRateLimit,
LoginRateLimit: options.LoginRateLimit,
FilesRateLimit: options.FilesRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
TemplateScheduleStore: &templateScheduleStore,
TLSCertificates: options.TLSCertificates,
TrialGenerator: options.TrialGenerator,
DERPMap: derpMap,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentValues: options.DeploymentValues,
Expand Down
4 changes: 2 additions & 2 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
StartupScriptTimeout: time.Duration(apiAgent.StartupScriptTimeoutSeconds) * time.Second,
ShutdownScript: apiAgent.ShutdownScript,
ShutdownScriptTimeout: time.Duration(apiAgent.ShutdownScriptTimeoutSeconds) * time.Second,
AllowDirectConnections: api.DeploymentValues.DERP.Config.DisableDirect.Value(),
AllowDirectConnections: !api.DeploymentValues.DERP.Config.DisableDirect.Value(),
Metadata: convertWorkspaceAgentMetadataDesc(metadata),
})
}
Expand Down Expand Up @@ -736,7 +736,7 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request

httpapi.Write(ctx, rw, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
DERPMap: api.DERPMap,
AllowDirectConnections: api.DeploymentValues.DERP.Config.DisableDirect.Value(),
AllowDirectConnections: !api.DeploymentValues.DERP.Config.DisableDirect.Value(),
})
}

Expand Down
77 changes: 77 additions & 0 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package coderd_test

import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -444,6 +445,82 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
require.Equal(t, "test", strings.TrimSpace(string(output)))
}

func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
t.Parallel()

dv := coderdtest.DeploymentValues(t)
err := dv.DERP.Config.DisableDirect.Set("true")
require.NoError(t, err)
require.True(t, dv.DERP.Config.DisableDirect.Value())

client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{
DeploymentValues: dv,
})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.ProvisionComplete,
ProvisionApply: echo.ProvisionApplyWithAgent(authToken),
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()

ctx := testutil.Context(t, testutil.WaitLong)

// Verify that the manifest has AllowDirectConnections set to false.
agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(authToken)
manifest, err := agentClient.Manifest(ctx)
require.NoError(t, err)
require.False(t, manifest.AllowDirectConnections)

agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer agentCloser.Close()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
agentID := resources[0].Agents[0].ID

// Verify that the connection data has no STUN ports and
// AllowDirectConnections set to false.
res, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
var connInfo codersdk.WorkspaceAgentConnectionInfo
err = json.NewDecoder(res.Body).Decode(&connInfo)
require.NoError(t, err)
require.False(t, connInfo.AllowDirectConnections)
for _, region := range connInfo.DERPMap.Regions {
t.Logf("region %s (%v)", region.RegionCode, region.EmbeddedRelay)
for _, node := range region.Nodes {
t.Logf(" node %s (stun %d)", node.Name, node.STUNPort)
require.EqualValues(t, -1, node.STUNPort)
// tailnet.NewDERPMap() will create nodes with "stun" in the name,
// but not if direct is disabled.
require.NotContains(t, node.Name, "stun")
require.False(t, node.STUNOnly)
}
}

conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer conn.Close()
require.True(t, conn.BlockEndpoints())

require.True(t, conn.AwaitReachable(ctx))
_, p2p, _, err := conn.Ping(ctx)
require.NoError(t, err)
require.False(t, p2p)
}

func TestWorkspaceAgentListeningPorts(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 7 additions & 0 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
return c.netMap.DERPMap
}

// BlockEndpoints returns whether or not P2P is blocked.
func (c *Conn) BlockEndpoints() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.blockEndpoints
}

// AwaitReachable pings the provided IP continually until the
// address is reachable. It's the callers responsibility to provide
// a timeout, otherwise this function will block forever.
Expand Down
5 changes: 5 additions & 0 deletions tailnet/derpmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,14 @@ func NewDERPMap(ctx context.Context, region *tailcfg.DERPRegion, stunAddrs []str
}
if !allowSTUN {
for _, region := range derpMap.Regions {
newNodes := make([]*tailcfg.DERPNode, 0, len(region.Nodes))
for _, node := range region.Nodes {
node.STUNPort = -1
if !node.STUNOnly {
newNodes = append(newNodes, node)
}
}
region.Nodes = newNodes
}
}

Expand Down
7 changes: 7 additions & 0 deletions tailnet/derpmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func TestNewDERPMap(t *testing.T) {
{
STUNPort: 12345,
},
{
STUNOnly: true,
STUNPort: 54321,
},
},
},
},
Expand All @@ -113,9 +117,12 @@ func TestNewDERPMap(t *testing.T) {

require.Len(t, derpMap.Regions[1].Nodes, 1)
require.EqualValues(t, -1, derpMap.Regions[1].Nodes[0].STUNPort)
// The STUNOnly node should get removed.
require.Len(t, derpMap.Regions[2].Nodes, 2)
require.EqualValues(t, -1, derpMap.Regions[2].Nodes[0].STUNPort)
require.False(t, derpMap.Regions[2].Nodes[0].STUNOnly)
require.EqualValues(t, -1, derpMap.Regions[2].Nodes[1].STUNPort)
require.False(t, derpMap.Regions[2].Nodes[1].STUNOnly)
// We don't add any nodes ourselves if STUN is disabled.
require.Len(t, derpMap.Regions[3].Nodes, 1)
// ... but we still remove the STUN port from existing nodes in the
Expand Down