diff --git a/agent/agent.go b/agent/agent.go index 066d91a66684a..3bdc46bf64e04 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -47,7 +47,7 @@ const ( type Options struct { EnableWireguard bool - PostPublicKeys PostKeys + UploadWireguardKeys UploadWireguardKeys ListenWireguardPeers ListenWireguardPeers ReconnectingPTYTimeout time.Duration EnvironmentVariables map[string]string @@ -55,7 +55,7 @@ type Options struct { } type Metadata struct { - Addresses []netaddr.IPPrefix `json:"addresses"` + WireguardAddresses []netaddr.IPPrefix `json:"addresses"` OwnerEmail string `json:"owner_email"` OwnerUsername string `json:"owner_username"` EnvironmentVariables map[string]string `json:"environment_variables"` @@ -63,14 +63,14 @@ type Metadata struct { Directory string `json:"directory"` } -type PublicKeys struct { +type WireguardPublicKeys struct { Public key.NodePublic `json:"public"` Disco key.DiscoPublic `json:"disco"` } type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error) -type PostKeys func(ctx context.Context, keys PublicKeys) error -type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.WireguardPeerMessage, func(), error) +type UploadWireguardKeys func(ctx context.Context, keys WireguardPublicKeys) error +type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error) func New(dialer Dialer, options *Options) io.Closer { if options == nil { @@ -88,7 +88,7 @@ func New(dialer Dialer, options *Options) io.Closer { closed: make(chan struct{}), envVars: options.EnvironmentVariables, enableWireguard: options.EnableWireguard, - postKeys: options.PostPublicKeys, + postKeys: options.UploadWireguardKeys, listenWireguardPeers: options.ListenWireguardPeers, } server.init(ctx) @@ -114,8 +114,8 @@ type agent struct { sshServer *ssh.Server enableWireguard bool - wg *peerwg.WireguardNetwork - postKeys PostKeys + network *peerwg.Network + postKeys UploadWireguardKeys listenWireguardPeers ListenWireguardPeers } @@ -160,9 +160,11 @@ func (a *agent) run(ctx context.Context) { }() } - err = a.startWireguard(ctx, metadata.Addresses) - if err != nil { - a.logger.Error(ctx, "start wireguard", slog.Error(err)) + if a.enableWireguard { + err = a.startWireguard(ctx, metadata.WireguardAddresses) + if err != nil { + a.logger.Error(ctx, "start wireguard", slog.Error(err)) + } } for { diff --git a/agent/wireguard.go b/agent/wireguard.go index 42121d11addbb..3b213bf34c004 100644 --- a/agent/wireguard.go +++ b/agent/wireguard.go @@ -11,13 +11,9 @@ import ( ) func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) error { - if a.wg != nil { - _ = a.wg.Close() - a.wg = nil - } - - if !a.enableWireguard { - return nil + if a.network != nil { + _ = a.network.Close() + a.network = nil } // We can't create a wireguard network without these. @@ -25,14 +21,16 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er return xerrors.New("wireguard is enabled, but no addresses were provided or necessary functions were not provided") } - wg, err := peerwg.NewWireguardNetwork(ctx, a.logger.Named("wireguard"), addrs) + wg, err := peerwg.New(a.logger.Named("wireguard"), addrs) if err != nil { return xerrors.Errorf("create wireguard network: %w", err) } - err = a.postKeys(ctx, PublicKeys{ - Public: wg.Private.Public(), - Disco: wg.Disco, + // A new keypair is generated on each agent start. + // This keypair must be sent to Coder to allow for incoming connections. + err = a.postKeys(ctx, WireguardPublicKeys{ + Public: wg.NodePrivateKey.Public(), + Disco: wg.DiscoPublicKey, }) if err != nil { a.logger.Warn(ctx, "post keys", slog.Error(err)) @@ -53,13 +51,13 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er } err := wg.AddPeer(peer) - a.logger.Info(ctx, "added wireguard peer", slog.F("peer", peer.Public.ShortString()), slog.Error(err)) + a.logger.Info(ctx, "added wireguard peer", slog.F("peer", peer.NodePublicKey.ShortString()), slog.Error(err)) } listenClose() } }() - a.wg = wg + a.network = wg return nil } diff --git a/cli/agent.go b/cli/agent.go index 395768d70dbdb..7c9daa8653961 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -178,7 +178,7 @@ func workspaceAgent() *cobra.Command { "CODER_AGENT_TOKEN": client.SessionToken, }, EnableWireguard: wireguard, - PostPublicKeys: client.PostWorkspaceAgentKeys, + UploadWireguardKeys: client.UploadWorkspaceAgentKeys, ListenWireguardPeers: client.WireguardPeerListener, }) <-cmd.Context().Done() diff --git a/cli/wireguardtunnel.go b/cli/wireguardtunnel.go index 1b8acfb511619..d50d6e8dfad40 100644 --- a/cli/wireguardtunnel.go +++ b/cli/wireguardtunnel.go @@ -100,7 +100,7 @@ func wireguardPortForward() *cobra.Command { } ipv6 := peerwg.UUIDToNetaddr(uuid.New()) - wgn, err := peerwg.NewWireguardNetwork(cmd.Context(), + wgn, err := peerwg.New( slog.Make(sloghuman.Sink(os.Stderr)), []netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)}, ) @@ -108,21 +108,21 @@ func wireguardPortForward() *cobra.Command { return xerrors.Errorf("create wireguard network: %w", err) } - err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.WireguardPeerMessage{ - Recipient: workspaceAgent.ID, - Public: wgn.Private.Public(), - Disco: wgn.Disco, - IPv6: ipv6, + err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + Recipient: workspaceAgent.ID, + NodePublicKey: wgn.NodePrivateKey.Public(), + DiscoPublicKey: wgn.DiscoPublicKey, + IPv6: ipv6, }) if err != nil { return xerrors.Errorf("post wireguard peer: %w", err) } - err = wgn.AddPeer(peerwg.WireguardPeerMessage{ - Recipient: workspaceAgent.ID, - Disco: workspaceAgent.DiscoPublicKey, - Public: workspaceAgent.WireguardPublicKey, - IPv6: workspaceAgent.IPv6.IP(), + err = wgn.AddPeer(peerwg.Handshake{ + Recipient: workspaceAgent.ID, + DiscoPublicKey: workspaceAgent.DiscoPublicKey, + NodePublicKey: workspaceAgent.WireguardPublicKey, + IPv6: workspaceAgent.IPv6.IP(), }) if err != nil { return xerrors.Errorf("add workspace agent as peer: %w", err) @@ -177,6 +177,8 @@ func wireguardPortForward() *cobra.Command { }, } + // Hide all wireguard commands for now while we test! + cmd.Hidden = true cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port") @@ -185,7 +187,7 @@ func wireguardPortForward() *cobra.Command { } func listenAndPortForwardWireguard(ctx context.Context, cmd *cobra.Command, - wgn *peerwg.WireguardNetwork, + wgn *peerwg.Network, wg *sync.WaitGroup, spec portForwardSpec, agentIP netaddr.IP, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 4c3549dca9f16..ce234ecf5a55d 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1599,23 +1599,23 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser defer q.mutex.Unlock() agent := database.WorkspaceAgent{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - StartupScript: arg.StartupScript, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - Ipv6: arg.Ipv6, - WireguardPublicKey: arg.WireguardPublicKey, - DiscoPublicKey: arg.DiscoPublicKey, + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + ResourceID: arg.ResourceID, + AuthToken: arg.AuthToken, + AuthInstanceID: arg.AuthInstanceID, + EnvironmentVariables: arg.EnvironmentVariables, + Name: arg.Name, + Architecture: arg.Architecture, + OperatingSystem: arg.OperatingSystem, + Directory: arg.Directory, + StartupScript: arg.StartupScript, + InstanceMetadata: arg.InstanceMetadata, + ResourceMetadata: arg.ResourceMetadata, + WireguardNodeIPv6: arg.WireguardNodeIPv6, + WireguardNodePublicKey: arg.WireguardNodePublicKey, + WireguardDiscoPublicKey: arg.WireguardDiscoPublicKey, } q.provisionerJobAgents = append(q.provisionerJobAgents, agent) @@ -1920,8 +1920,8 @@ func (q *fakeQuerier) UpdateWorkspaceAgentKeysByID(_ context.Context, arg databa continue } - agent.WireguardPublicKey = arg.WireguardPublicKey - agent.DiscoPublicKey = arg.DiscoPublicKey + agent.WireguardNodePublicKey = arg.WireguardNodePublicKey + agent.WireguardDiscoPublicKey = arg.WireguardDiscoPublicKey agent.UpdatedAt = database.Now() q.provisionerJobAgents[index] = agent return nil diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index abb435c413f29..73631fe319ad8 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -293,9 +293,9 @@ CREATE TABLE workspace_agents ( instance_metadata jsonb, resource_metadata jsonb, directory character varying(4096) DEFAULT ''::character varying NOT NULL, - ipv6 inet DEFAULT '::'::inet NOT NULL, - wireguard_public_key character varying(128) DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL, - disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL + wireguard_node_ipv6 inet DEFAULT '::'::inet NOT NULL, + wireguard_node_public_key character varying(128) DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL, + wireguard_disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL ); CREATE TABLE workspace_apps ( diff --git a/coderd/database/migrations/000028_wireguard.down.sql b/coderd/database/migrations/000028_wireguard.down.sql index 8beb4adba9f62..a467217fbbbff 100644 --- a/coderd/database/migrations/000028_wireguard.down.sql +++ b/coderd/database/migrations/000028_wireguard.down.sql @@ -1,4 +1,4 @@ ALTER TABLE workspace_agents - DROP COLUMN ipv6, - DROP COLUMN wireguard_public_key, + DROP COLUMN wireguard_ipv6, + DROP COLUMN node_public_key, DROP COLUMN disco_public_key; diff --git a/coderd/database/migrations/000028_wireguard.up.sql b/coderd/database/migrations/000028_wireguard.up.sql index e15bd97e54731..202656b446106 100644 --- a/coderd/database/migrations/000028_wireguard.up.sql +++ b/coderd/database/migrations/000028_wireguard.up.sql @@ -1,4 +1,4 @@ ALTER TABLE workspace_agents - ADD COLUMN ipv6 inet NOT NULL DEFAULT '::/128', - ADD COLUMN wireguard_public_key varchar(128) NOT NULL DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000', - ADD COLUMN disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'; + ADD COLUMN wireguard_node_ipv6 inet NOT NULL DEFAULT '::/128', + ADD COLUMN wireguard_node_public_key varchar(128) NOT NULL DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000', + ADD COLUMN wireguard_disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'; diff --git a/coderd/database/models.go b/coderd/database/models.go index 836f35b5b7f6d..8664803f7d98e 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -503,26 +503,26 @@ type Workspace struct { } type WorkspaceAgent struct { - ID uuid.UUID `db:"id" json:"id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"` - LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"` - DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"` - ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` - AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` - AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` - Architecture string `db:"architecture" json:"architecture"` - EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` - OperatingSystem string `db:"operating_system" json:"operating_system"` - StartupScript sql.NullString `db:"startup_script" json:"startup_script"` - InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` - ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` - Directory string `db:"directory" json:"directory"` - Ipv6 pqtype.Inet `db:"ipv6" json:"ipv6"` - WireguardPublicKey string `db:"wireguard_public_key" json:"wireguard_public_key"` - DiscoPublicKey string `db:"disco_public_key" json:"disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"` + LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"` + DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"` + ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` + AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` + AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` + Architecture string `db:"architecture" json:"architecture"` + EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` + OperatingSystem string `db:"operating_system" json:"operating_system"` + StartupScript sql.NullString `db:"startup_script" json:"startup_script"` + InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` + ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` + Directory string `db:"directory" json:"directory"` + WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` + WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } type WorkspaceApp struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index faee385ca4a5e..cbb7fc51d69cf 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2843,7 +2843,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE @@ -2873,16 +2873,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ) return i, err } const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE @@ -2910,16 +2910,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ) return i, err } const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE @@ -2949,16 +2949,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ) return i, err } const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE @@ -2992,9 +2992,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids [] &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ); err != nil { return nil, err } @@ -3010,7 +3010,7 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids [] } const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many -SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key FROM workspace_agents WHERE created_at > $1 +SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE created_at > $1 ` func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error) { @@ -3040,9 +3040,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, created &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ); err != nil { return nil, err } @@ -3074,32 +3074,32 @@ INSERT INTO directory, instance_metadata, resource_metadata, - ipv6, - wireguard_public_key, - disco_public_key + wireguard_node_ipv6, + wireguard_node_public_key, + wireguard_disco_public_key ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, ipv6, wireguard_public_key, disco_public_key + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key ` type InsertWorkspaceAgentParams struct { - ID uuid.UUID `db:"id" json:"id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` - AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` - AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` - Architecture string `db:"architecture" json:"architecture"` - EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` - OperatingSystem string `db:"operating_system" json:"operating_system"` - StartupScript sql.NullString `db:"startup_script" json:"startup_script"` - Directory string `db:"directory" json:"directory"` - InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` - ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` - Ipv6 pqtype.Inet `db:"ipv6" json:"ipv6"` - WireguardPublicKey string `db:"wireguard_public_key" json:"wireguard_public_key"` - DiscoPublicKey string `db:"disco_public_key" json:"disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + ResourceID uuid.UUID `db:"resource_id" json:"resource_id"` + AuthToken uuid.UUID `db:"auth_token" json:"auth_token"` + AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"` + Architecture string `db:"architecture" json:"architecture"` + EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"` + OperatingSystem string `db:"operating_system" json:"operating_system"` + StartupScript sql.NullString `db:"startup_script" json:"startup_script"` + Directory string `db:"directory" json:"directory"` + InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` + ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` + WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` + WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) { @@ -3118,9 +3118,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa arg.Directory, arg.InstanceMetadata, arg.ResourceMetadata, - arg.Ipv6, - arg.WireguardPublicKey, - arg.DiscoPublicKey, + arg.WireguardNodeIPv6, + arg.WireguardNodePublicKey, + arg.WireguardDiscoPublicKey, ) var i WorkspaceAgent err := row.Scan( @@ -3141,9 +3141,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa &i.InstanceMetadata, &i.ResourceMetadata, &i.Directory, - &i.Ipv6, - &i.WireguardPublicKey, - &i.DiscoPublicKey, + &i.WireguardNodeIPv6, + &i.WireguardNodePublicKey, + &i.WireguardDiscoPublicKey, ) return i, err } @@ -3182,20 +3182,20 @@ UPDATE workspace_agents SET updated_at = now(), - wireguard_public_key = $2, - disco_public_key = $3 + wireguard_node_public_key = $2, + wireguard_disco_public_key = $3 WHERE id = $1 ` type UpdateWorkspaceAgentKeysByIDParams struct { - ID uuid.UUID `db:"id" json:"id"` - WireguardPublicKey string `db:"wireguard_public_key" json:"wireguard_public_key"` - DiscoPublicKey string `db:"disco_public_key" json:"disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error { - _, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID, arg.ID, arg.WireguardPublicKey, arg.DiscoPublicKey) + _, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID, arg.ID, arg.WireguardNodePublicKey, arg.WireguardDiscoPublicKey) return err } diff --git a/coderd/database/queries/workspaceagents.sql b/coderd/database/queries/workspaceagents.sql index 922f05733217a..37e7754739311 100644 --- a/coderd/database/queries/workspaceagents.sql +++ b/coderd/database/queries/workspaceagents.sql @@ -54,9 +54,9 @@ INSERT INTO directory, instance_metadata, resource_metadata, - ipv6, - wireguard_public_key, - disco_public_key + wireguard_node_ipv6, + wireguard_node_public_key, + wireguard_disco_public_key ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING *; @@ -77,7 +77,7 @@ UPDATE workspace_agents SET updated_at = now(), - wireguard_public_key = $2, - disco_public_key = $3 + wireguard_node_public_key = $2, + wireguard_disco_public_key = $3 WHERE id = $1; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 33fc8a8399dcf..8827cc76083b3 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -34,3 +34,4 @@ rename: gitsshkey: GitSSHKey rbac_roles: RBACRoles ip_address: IPAddress + wireguard_node_ipv6: WireguardNodeIPv6 diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 416584e5275ba..c9dfb92986848 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -746,7 +746,6 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid. agentID := uuid.New() dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{ ID: agentID, - Ipv6: peerwg.UUIDToInet(agentID), CreatedAt: database.Now(), UpdatedAt: database.Now(), ResourceID: resource.ID, @@ -761,8 +760,9 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid. String: prAgent.StartupScript, Valid: prAgent.StartupScript != "", }, - WireguardPublicKey: key.NodePublic{}.String(), - DiscoPublicKey: key.DiscoPublic{}.String(), + WireguardNodeIPv6: peerwg.UUIDToInet(agentID), + WireguardNodePublicKey: key.NodePublic{}.String(), + WireguardDiscoPublicKey: key.DiscoPublic{}.String(), }) if err != nil { return xerrors.Errorf("insert agent: %w", err) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6bba26298f7e1..af23d11dc7922 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -161,17 +161,17 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) return } - ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.Ipv6.IPNet) + ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.WireguardNodeIPv6.IPNet) if !ok { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: "Workspace agent has an invalid ipv6 address.", - Detail: workspaceAgent.Ipv6.IPNet.String(), + Detail: workspaceAgent.WireguardNodeIPv6.IPNet.String(), }) return } httpapi.Write(rw, http.StatusOK, agent.Metadata{ - Addresses: []netaddr.IPPrefix{ipp}, + WireguardAddresses: []netaddr.IPPrefix{ipp}, OwnerEmail: owner.Email, OwnerUsername: owner.Username, EnvironmentVariables: apiAgent.EnvironmentVariables, @@ -487,9 +487,9 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) } err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{ - ID: workspaceAgent.ID, - WireguardPublicKey: keys.Public.String(), - DiscoPublicKey: keys.Disco.String(), + ID: workspaceAgent.ID, + WireguardNodePublicKey: keys.Public.String(), + WireguardDiscoPublicKey: keys.Disco.String(), }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -504,7 +504,7 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) func (api *API) postWorkspaceAgentWireguardPeer(rw http.ResponseWriter, r *http.Request) { var ( - req peerwg.WireguardPeerMessage + req peerwg.Handshake workspaceAgent = httpmw.WorkspaceAgentParam(r) workspace = httpmw.WorkspaceParam(r) ) @@ -570,7 +570,7 @@ func (api *API) workspaceAgentWireguardListener(rw http.ResponseWriter, r *http. // Since we subscribe to all peer broadcasts, we do a light check to // make sure we're the intended recipient without fully decoding the // message. - hint, err := peerwg.WireguardPeerMessageRecipientHint(agentIDBytes, message) + hint, err := peerwg.HandshakeRecipientHint(agentIDBytes, message) if err != nil { api.Logger.Error(ctx, "invalid wireguard peer message", slog.Error(err)) return @@ -710,16 +710,16 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work EnvironmentVariables: envs, Directory: dbAgent.Directory, Apps: apps, - IPv6: inetToNetaddr(dbAgent.Ipv6), + IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6), } - err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardPublicKey)) + err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey)) if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard public key %q: %w", dbAgent.WireguardPublicKey, err) + return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err) } - err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.DiscoPublicKey)) + err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey)) if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.DiscoPublicKey, err) + return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err) } if dbAgent.FirstConnectedAt.Valid { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 13641dc9f7e68..499286d0c91a3 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -255,7 +255,7 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) ( // PostWireguardPeer announces your public keys and IPv6 address to the // specified recipient. -func (c *Client) PostWireguardPeer(ctx context.Context, workspaceID uuid.UUID, peerMsg peerwg.WireguardPeerMessage) error { +func (c *Client) PostWireguardPeer(ctx context.Context, workspaceID uuid.UUID, peerMsg peerwg.Handshake) error { res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/peer?workspace=%s", peerMsg.Recipient, workspaceID.String(), @@ -275,7 +275,7 @@ func (c *Client) PostWireguardPeer(ctx context.Context, workspaceID uuid.UUID, p // WireguardPeerListener listens for wireguard peer messages. Peer messages are // sent when a new client wants to connect. Once receiving a peer message, the // peer should be added to the NetworkMap of the wireguard interface. -func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) (<-chan peerwg.WireguardPeerMessage, func(), error) { +func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error) { serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/wireguardlisten") if err != nil { return nil, nil, xerrors.Errorf("parse url: %w", err) @@ -304,7 +304,7 @@ func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) return nil, nil, readBodyAsError(res) } - ch := make(chan peerwg.WireguardPeerMessage, 1) + ch := make(chan peerwg.Handshake, 1) go func() { defer conn.Close(websocket.StatusGoingAway, "") defer close(ch) @@ -315,7 +315,7 @@ func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) break } - var msg peerwg.WireguardPeerMessage + var msg peerwg.Handshake err = msg.UnmarshalText(message) if err != nil { logger.Error(ctx, "unmarshal wireguard peer message", slog.Error(err)) @@ -329,10 +329,10 @@ func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) return ch, func() { _ = conn.Close(websocket.StatusGoingAway, "") }, nil } -// PostWorkspaceAgentKeys uploads the public keys of the workspace agent that +// UploadWorkspaceAgentKeys uploads the public keys of the workspace agent that // were generated on startup. These keys are used by clients to communicate with // the workspace agent over the wireguard interface. -func (c *Client) PostWorkspaceAgentKeys(ctx context.Context, keys agent.PublicKeys) error { +func (c *Client) UploadWorkspaceAgentKeys(ctx context.Context, keys agent.WireguardPublicKeys) error { res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/keys", keys) if err != nil { return xerrors.Errorf("do request: %w", err) diff --git a/peer/peerwg/peermessage.go b/peer/peerwg/peermessage.go index 8913204e5461d..7fc3bb4b2f4ab 100644 --- a/peer/peerwg/peermessage.go +++ b/peer/peerwg/peermessage.go @@ -10,26 +10,26 @@ import ( "tailscale.com/types/key" ) -const peerMessageSeparator byte = '\n' +const handshakeSeparator byte = '\n' -// WireguardPeerMessage is a message received from a wireguard peer, indicating +// Handshake is a message received from a wireguard peer, indicating // it would like to connect. -type WireguardPeerMessage struct { +type Handshake struct { // Recipient is the uuid of the agent that the message was intended for. Recipient uuid.UUID `json:"recipient"` - // Disco is the disco public key of the peer. - Disco key.DiscoPublic `json:"disco"` - // Public is the public key of the peer. - Public key.NodePublic `json:"public"` + // DiscoPublicKey is the disco public key of the peer. + DiscoPublicKey key.DiscoPublic `json:"disco"` + // NodePublicKey is the public key of the peer. + NodePublicKey key.NodePublic `json:"public"` // IPv6 is the IPv6 address of the peer. IPv6 netaddr.IP `json:"ipv6"` } -// WireguardPeerMessageRecipientHint parses the first part of a serialized -// WireguardPeerMessage to quickly determine if the message is meant for the -// provided agentID. -func WireguardPeerMessageRecipientHint(agentID []byte, msg []byte) (bool, error) { - idx := bytes.Index(msg, []byte{peerMessageSeparator}) +// HandshakeRecipientHint parses the first part of a serialized +// Handshake to quickly determine if the message is meant for the +// provided recipient. +func HandshakeRecipientHint(agentID []byte, msg []byte) (bool, error) { + idx := bytes.Index(msg, []byte{handshakeSeparator}) if idx == -1 { return false, xerrors.Errorf("invalid peer message, no separator") } @@ -37,28 +37,28 @@ func WireguardPeerMessageRecipientHint(agentID []byte, msg []byte) (bool, error) return bytes.Equal(agentID, msg[:idx]), nil } -func (pm *WireguardPeerMessage) UnmarshalText(text []byte) error { - sp := bytes.Split(text, []byte{peerMessageSeparator}) +func (h *Handshake) UnmarshalText(text []byte) error { + sp := bytes.Split(text, []byte{handshakeSeparator}) if len(sp) != 4 { return xerrors.Errorf("expected 4 parts, got %d", len(sp)) } - err := pm.Recipient.UnmarshalText(sp[0]) + err := h.Recipient.UnmarshalText(sp[0]) if err != nil { return xerrors.Errorf("parse recipient: %w", err) } - err = pm.Disco.UnmarshalText(sp[1]) + err = h.DiscoPublicKey.UnmarshalText(sp[1]) if err != nil { return xerrors.Errorf("parse disco: %w", err) } - err = pm.Public.UnmarshalText(sp[2]) + err = h.NodePublicKey.UnmarshalText(sp[2]) if err != nil { return xerrors.Errorf("parse public: %w", err) } - pm.IPv6, err = netaddr.ParseIP(string(sp[3])) + h.IPv6, err = netaddr.ParseIP(string(sp[3])) if err != nil { return xerrors.Errorf("parse ipv6: %w", err) } @@ -66,24 +66,24 @@ func (pm *WireguardPeerMessage) UnmarshalText(text []byte) error { return nil } -func (pm WireguardPeerMessage) MarshalText() ([]byte, error) { +func (h Handshake) MarshalText() ([]byte, error) { const expectedLen = 223 var buf bytes.Buffer buf.Grow(expectedLen) - recp, _ := pm.Recipient.MarshalText() + recp, _ := h.Recipient.MarshalText() _, _ = buf.Write(recp) - _ = buf.WriteByte(peerMessageSeparator) + _ = buf.WriteByte(handshakeSeparator) - disco, _ := pm.Disco.MarshalText() + disco, _ := h.DiscoPublicKey.MarshalText() _, _ = buf.Write(disco) - _ = buf.WriteByte(peerMessageSeparator) + _ = buf.WriteByte(handshakeSeparator) - pub, _ := pm.Public.MarshalText() + pub, _ := h.NodePublicKey.MarshalText() _, _ = buf.Write(pub) - _ = buf.WriteByte(peerMessageSeparator) + _ = buf.WriteByte(handshakeSeparator) - ipv6 := pm.IPv6.StringExpanded() + ipv6 := h.IPv6.StringExpanded() _, _ = buf.WriteString(ipv6) // Ensure we're always allocating exactly enough. diff --git a/peer/peerwg/wireguard.go b/peer/peerwg/wireguard.go index 35f7801bd2c6f..d24a32eb76c4e 100644 --- a/peer/peerwg/wireguard.go +++ b/peer/peerwg/wireguard.go @@ -35,6 +35,14 @@ import ( "cdr.dev/slog" ) +var logf tslogger.Logf = log.Printf + +func init() { + // Globally disable network namespacing. + // All networking happens in userspace. + netns.SetEnabled(false) +} + func UUIDToInet(uid uuid.UUID) pqtype.Inet { uid = privateUUID(uid) @@ -63,40 +71,41 @@ func privateUUID(uid uuid.UUID) uuid.UUID { return uid } -var logf tslogger.Logf = log.Printf - -type WireguardNetwork struct { - mu sync.Mutex - logger slog.Logger - Private key.NodePrivate - Disco key.DiscoPublic - - Engine wgengine.Engine - Netstack *netstack.Impl - Magic *magicsock.Conn +type Network struct { + mu sync.Mutex + logger slog.Logger + listeners map[listenKey]*listener + magicSock *magicsock.Conn netMap *netmap.NetworkMap router *router.Config - listeners map[listenKey]*listener + wgEngine wgengine.Engine + + DiscoPublicKey key.DiscoPublic + Netstack *netstack.Impl + NodePrivateKey key.NodePrivate } -func NewWireguardNetwork(_ context.Context, logger slog.Logger, addrs []netaddr.IPPrefix) (*WireguardNetwork, error) { - var ( - private = key.NewNode() - public = private.Public() - id, stableID = nodeIDs(public) - ) +// New constructs a Wireguard network that filters traffic +// to destinations matching the addresses provided. +func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { + nodePrivateKey := key.NewNode() + nodePublicKey := nodePrivateKey.Public() + id, stableID := nodeIDs(nodePublicKey) netMap := &netmap.NetworkMap{ - NodeKey: public, - PrivateKey: private, - Addresses: addrs, + NodeKey: nodePublicKey, + PrivateKey: nodePrivateKey, + Addresses: addresses, PacketFilter: []filter.Match{{ - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6}, + // Allow any protocol! + IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP}, + // Allow traffic sourced from anywhere. Srcs: []netaddr.IPPrefix{ netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0), netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 0), }, + // Allow traffic to route anywhere. Dsts: []filter.NetPortRange{ { Net: netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0), @@ -116,63 +125,70 @@ func NewWireguardNetwork(_ context.Context, logger slog.Logger, addrs []netaddr. Caps: []filter.CapMatch{}, }}, } + // Identify itself as a node on the network with the addresses provided. netMap.SelfNode = &tailcfg.Node{ ID: id, StableID: stableID, - Key: public, + Key: nodePublicKey, Addresses: netMap.Addresses, AllowedIPs: append(netMap.Addresses, netaddr.MustParseIPPrefix("::/0")), Endpoints: []string{}, DERP: DefaultDerpHome, } - linkMon, err := monitor.New(logf) + wgMonitor, err := monitor.New(logf) if err != nil { return nil, xerrors.Errorf("create link monitor: %w", err) } - netns.SetEnabled(false) dialer := new(tsdial.Dialer) dialer.Logf = logf - e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ - LinkMonitor: linkMon, + // Create a wireguard engine in userspace. + engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ + LinkMonitor: wgMonitor, Dialer: dialer, }) if err != nil { return nil, xerrors.Errorf("create wgengine: %w", err) } - ig, _ := e.(wgengine.InternalsGetter) - tunDev, magicConn, dnsMgr, ok := ig.GetInternals() + // This is taken from Tailscale: + // https://github.com/tailscale/tailscale/blob/0f05b2c13ff0c305aa7a1655fa9c17ed969d65be/tsnet/tsnet.go#L247-L255 + // nolint + tunDev, magicConn, dnsManager, ok := engine.(wgengine.InternalsGetter).GetInternals() if !ok { return nil, xerrors.New("could not get wgengine internals") } - // This can't error. - _ = magicConn.SetPrivateKey(private) + // Update the keys for the magic connection! + err = magicConn.SetPrivateKey(nodePrivateKey) + if err != nil { + return nil, xerrors.Errorf("set node private key: %w", err) + } netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() - ns, err := netstack.Create(logf, tunDev, e, magicConn, dialer, dnsMgr) + // Create the networking stack. + // This is called to route connections. + netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager) if err != nil { return nil, xerrors.Errorf("create netstack: %w", err) } - - ns.ProcessLocalIPs = true - ns.ProcessSubnets = true + netStack.ProcessLocalIPs = true + netStack.ProcessSubnets = true dialer.UseNetstackForIP = func(ip netaddr.IP) bool { - _, ok := e.PeerForIP(ip) + _, ok := engine.PeerForIP(ip) return ok } dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { - return ns.DialContextTCP(ctx, dst) + return netStack.DialContextTCP(ctx, dst) } - - err = ns.Start() + err = netStack.Start() if err != nil { return nil, xerrors.Errorf("start netstack: %w", err) } - e = wgengine.NewWatchdog(e) + engine = wgengine.NewWatchdog(engine) + // Update the wireguard configuration to allow traffic to flow. cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID) if err != nil { return nil, xerrors.Errorf("create wgcfg: %w", err) @@ -181,14 +197,13 @@ func NewWireguardNetwork(_ context.Context, logger slog.Logger, addrs []netaddr. rtr := &router.Config{ LocalAddrs: cfg.Addresses, } - - err = e.Reconfig(cfg, rtr, &dns.Config{}, &tailcfg.Debug{}) + err = engine.Reconfig(cfg, rtr, &dns.Config{}, &tailcfg.Debug{}) if err != nil { return nil, xerrors.Errorf("reconfig: %w", err) } - e.SetDERPMap(DerpMap) - e.SetNetworkMap(copyNetMap(netMap)) + engine.SetDERPMap(DerpMap) + engine.SetNetworkMap(copyNetMap(netMap)) ipb := netaddr.IPSetBuilder{} for _, addr := range netMap.Addresses { @@ -198,56 +213,47 @@ func NewWireguardNetwork(_ context.Context, logger slog.Logger, addrs []netaddr. iplb := netaddr.IPSetBuilder{} ipl, _ := iplb.IPSet() - e.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf)) - - wn := &WireguardNetwork{ - logger: logger, - Private: private, - Disco: magicConn.DiscoPublicKey(), - Engine: e, - Netstack: ns, - Magic: magicConn, - netMap: netMap, - router: rtr, - listeners: map[listenKey]*listener{}, + engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf)) + + wn := &Network{ + logger: logger, + NodePrivateKey: nodePrivateKey, + DiscoPublicKey: magicConn.DiscoPublicKey(), + wgEngine: engine, + Netstack: netStack, + magicSock: magicConn, + netMap: netMap, + router: rtr, + listeners: map[listenKey]*listener{}, } - ns.ForwardTCPIn = wn.forwardTCP + netStack.ForwardTCPIn = wn.forwardTCP return wn, nil } -// nodeIDs generates Tailscale node IDs for the provided public key. -func nodeIDs(public key.NodePublic) (tailcfg.NodeID, tailcfg.StableNodeID) { - idhash := fnv.New64() - pub, _ := public.MarshalText() - _, _ = idhash.Write(pub) - - return tailcfg.NodeID(idhash.Sum64()), tailcfg.StableNodeID(pub) -} - -// forwardTCP handles incoming TCP connections from wireguard. -func (wn *WireguardNetwork) forwardTCP(c net.Conn, port uint16) { - wn.mu.Lock() - ln, ok := wn.listeners[listenKey{"tcp", "", fmt.Sprint(port)}] - wn.mu.Unlock() +// forwardTCP handles incoming connections from Wireguard in userspace. +func (n *Network) forwardTCP(conn net.Conn, port uint16) { + n.mu.Lock() + listener, ok := n.listeners[listenKey{"tcp", "", fmt.Sprint(port)}] + n.mu.Unlock() if !ok { // No listener added, forward to host. - wn.forwardTCPLocal(c, port) + n.forwardTCPToLocalHandler(conn, port) return } - t := time.NewTimer(time.Second) - defer t.Stop() + timer := time.NewTimer(time.Second) + defer timer.Stop() select { - case ln.conn <- c: - case <-t.C: - _ = c.Close() + case listener.conn <- conn: + case <-timer.C: + _ = conn.Close() } } -// forwardTCPLocal forwards the provided net.Conn to the matching port on the -// host. -func (wn *WireguardNetwork) forwardTCPLocal(c net.Conn, port uint16) { +// forwardTCPToLocalHandler forwards the provided net.Conn to the +// matching port bound to localhost. +func (n *Network) forwardTCPToLocalHandler(c net.Conn, port uint16) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() defer c.Close() @@ -256,7 +262,7 @@ func (wn *WireguardNetwork) forwardTCPLocal(c net.Conn, port uint16) { var stdDialer net.Dialer server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) if err != nil { - wn.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err)) + n.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err)) return } defer server.Close() @@ -272,84 +278,72 @@ func (wn *WireguardNetwork) forwardTCPLocal(c net.Conn, port uint16) { }() err = <-connClosed if err != nil { - wn.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err)) + n.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err)) } - wn.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) -} - -func (wn *WireguardNetwork) Close() error { - _ = wn.Netstack.Close() - wn.Engine.Close() - - return nil + n.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) } -// AddPeer adds a peer to the network from a WireguardPeerMessage. After adding -// a peer, they may connect to you. -func (wn *WireguardNetwork) AddPeer(peer WireguardPeerMessage) error { - wn.mu.Lock() - defer wn.mu.Unlock() +// AddPeer allows connections from another Wireguard instance with the +// handshake credentials. +func (n *Network) AddPeer(handshake Handshake) error { + n.mu.Lock() + defer n.mu.Unlock() // If the peer already exists in the network map, do nothing. - for _, p := range wn.netMap.Peers { - if p.Key == peer.Public { - wn.logger.Debug(context.Background(), "peer already in netmap", slog.F("peer", peer.Public.ShortString())) + for _, p := range n.netMap.Peers { + if p.Key == handshake.NodePublicKey { + n.logger.Debug(context.Background(), "peer already in netmap", slog.F("peer", handshake.NodePublicKey.ShortString())) return nil } } // The Tailscale engine owns this slice, so we need to copy to make // modifications. - peers := append(([]*tailcfg.Node)(nil), wn.netMap.Peers...) + peers := append(([]*tailcfg.Node)(nil), n.netMap.Peers...) - id, stableID := nodeIDs(peer.Public) + id, stableID := nodeIDs(handshake.NodePublicKey) peers = append(peers, &tailcfg.Node{ ID: id, StableID: stableID, - Name: peer.Public.String() + ".com", - Key: peer.Public, - DiscoKey: peer.Disco, - Addresses: []netaddr.IPPrefix{netaddr.IPPrefixFrom(peer.IPv6, 128)}, - AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefixFrom(peer.IPv6, 128)}, + Name: handshake.NodePublicKey.String() + ".com", + Key: handshake.NodePublicKey, + DiscoKey: handshake.DiscoPublicKey, + Addresses: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)}, + AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)}, DERP: DefaultDerpHome, Endpoints: []string{DefaultDerpHome}, }) - wn.netMap.Peers = peers + n.netMap.Peers = peers - cfg, err := nmcfg.WGCfg(wn.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL")) + cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL")) if err != nil { return xerrors.Errorf("create wgcfg: %w", err) } - err = wn.Engine.Reconfig(cfg, wn.router, &dns.Config{}, &tailcfg.Debug{}) + err = n.wgEngine.Reconfig(cfg, n.router, &dns.Config{}, &tailcfg.Debug{}) if err != nil { return xerrors.Errorf("reconfig: %w", err) } // Always give the Tailscale engine a copy of our network map. - wn.Engine.SetNetworkMap(copyNetMap(wn.netMap)) + n.wgEngine.SetNetworkMap(copyNetMap(n.netMap)) return nil } -func copyNetMap(nm *netmap.NetworkMap) *netmap.NetworkMap { - nmCopy := *nm - return &nmCopy -} - // Ping sends a discovery ping to the provided peer. -func (wn *WireguardNetwork) Ping(peer WireguardPeerMessage) *ipnstate.PingResult { +// The peer address must be connected before a successful ping will work. +func (n *Network) Ping(ip netaddr.IP) *ipnstate.PingResult { ch := make(chan *ipnstate.PingResult) - wn.Engine.Ping(peer.IPv6, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { + n.wgEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { ch <- pr }) - return <-ch } -// Listen returns a net.Listener that can be used to accept connections from the -// wireguard network at the specified address. -func (wn *WireguardNetwork) Listen(network, addr string) (net.Listener, error) { +// Listener returns a net.Listener in userspace that can be used to accept +// connections from the Wireguard network to the specified address. +func (n *Network) Listen(network, addr string) (net.Listener, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, xerrors.Errorf("split addr host port: %w", err) @@ -357,24 +351,31 @@ func (wn *WireguardNetwork) Listen(network, addr string) (net.Listener, error) { lkey := listenKey{network, host, port} ln := &listener{ - wn: wn, + wn: n, key: lkey, addr: addr, conn: make(chan net.Conn, 1), } - wn.mu.Lock() - defer wn.mu.Unlock() + n.mu.Lock() + defer n.mu.Unlock() - if _, ok := wn.listeners[lkey]; ok { + if _, ok := n.listeners[lkey]; ok { return nil, xerrors.Errorf("listener already open for %s, %s", network, addr) } - wn.listeners[lkey] = ln + n.listeners[lkey] = ln return ln, nil } +func (n *Network) Close() error { + _ = n.Netstack.Close() + n.wgEngine.Close() + + return nil +} + type listenKey struct { network string host string @@ -382,7 +383,7 @@ type listenKey struct { } type listener struct { - wn *WireguardNetwork + wn *Network key listenKey addr string conn chan net.Conn @@ -413,3 +414,17 @@ type addr struct{ ln *listener } func (a addr) Network() string { return a.ln.key.network } func (a addr) String() string { return a.ln.addr } + +// nodeIDs generates Tailscale node IDs for the provided public key. +func nodeIDs(public key.NodePublic) (tailcfg.NodeID, tailcfg.StableNodeID) { + idhash := fnv.New64() + pub, _ := public.MarshalText() + _, _ = idhash.Write(pub) + + return tailcfg.NodeID(idhash.Sum64()), tailcfg.StableNodeID(pub) +} + +func copyNetMap(nm *netmap.NetworkMap) *netmap.NetworkMap { + nmCopy := *nm + return &nmCopy +}