Skip to content

Commit 7fe91a2

Browse files
committed
Fix conncache with interface
1 parent 5ba96b5 commit 7fe91a2

File tree

7 files changed

+204
-20
lines changed

7 files changed

+204
-20
lines changed

agent/agent.go

+3
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,9 @@ func (a *agent) Close() error {
865865
}
866866
close(a.closed)
867867
a.closeCancel()
868+
if a.network != nil {
869+
_ = a.network.Close()
870+
}
868871
_ = a.sshServer.Close()
869872
a.connCloseWait.Wait()
870873
return nil

agent/conn.go

+53
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ import (
77
"io"
88
"net"
99
"net/url"
10+
"strconv"
1011
"strings"
1112
"time"
1213

1314
"golang.org/x/crypto/ssh"
1415
"golang.org/x/xerrors"
16+
"inet.af/netaddr"
1517

1618
"github.com/coder/coder/peer"
1719
"github.com/coder/coder/peerbroker/proto"
20+
"github.com/coder/coder/tailnet"
1821
)
1922

2023
// ReconnectingPTYRequest is sent from the client to the server
@@ -130,3 +133,53 @@ func (c *WebRTCConn) Close() error {
130133
_ = c.Negotiator.DRPCConn().Close()
131134
return c.Conn.Close()
132135
}
136+
137+
type TailnetConn struct {
138+
Target netaddr.IP
139+
*tailnet.Server
140+
}
141+
142+
func (c *TailnetConn) Closed() <-chan struct{} {
143+
return nil
144+
}
145+
146+
func (c *TailnetConn) Ping() (time.Duration, error) {
147+
return 0, nil
148+
}
149+
150+
func (c *TailnetConn) CloseWithError(err error) error {
151+
return c.Close()
152+
}
153+
154+
func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
155+
return nil, xerrors.New("not implemented")
156+
}
157+
158+
func (c *TailnetConn) SSH() (net.Conn, error) {
159+
return c.DialContextTCP(context.Background(), netaddr.IPPortFrom(c.Target, 12212))
160+
}
161+
162+
// SSHClient calls SSH to create a client that uses a weak cipher
163+
// for high throughput.
164+
func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
165+
netConn, err := c.SSH()
166+
if err != nil {
167+
return nil, xerrors.Errorf("ssh: %w", err)
168+
}
169+
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
170+
// SSH host validation isn't helpful, because obtaining a peer
171+
// connection already signifies user-intent to dial a workspace.
172+
// #nosec
173+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
174+
})
175+
if err != nil {
176+
return nil, xerrors.Errorf("ssh conn: %w", err)
177+
}
178+
return ssh.NewClient(sshConn, channels, requests), nil
179+
}
180+
181+
func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
182+
_, rawPort, _ := net.SplitHostPort(addr)
183+
port, _ := strconv.Atoi(rawPort)
184+
return c.Server.DialContextTCP(ctx, netaddr.IPPortFrom(c.Target, uint16(port)))
185+
}

coderd/coderdtest/coderdtest.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,17 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API)
181181
Nodes: []*tailcfg.DERPNode{{
182182
Name: "1a",
183183
RegionID: 1,
184-
HostName: serverURL.Host,
184+
IPv4: "127.0.0.1",
185185
DERPPort: derpPort,
186186
STUNPort: -1,
187187
InsecureForTests: true,
188188
HTTPForTests: true,
189+
}, {
190+
Name: "1b",
191+
RegionID: 1,
192+
STUNOnly: true,
193+
HostName: "stun.l.google.com",
194+
STUNPort: 19302,
189195
}},
190196
},
191197
},

coderd/workspaceagents.go

+10-14
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@ import (
1515
"github.com/google/uuid"
1616
"github.com/hashicorp/yamux"
1717
"github.com/tabbed/pqtype"
18-
"go4.org/mem"
1918
"golang.org/x/xerrors"
2019
"inet.af/netaddr"
2120
"nhooyr.io/websocket"
22-
"tailscale.com/types/key"
21+
"nhooyr.io/websocket/wsjson"
2322

2423
"cdr.dev/slog"
2524
"github.com/coder/coder/agent"
@@ -549,7 +548,6 @@ func (api *API) workspaceAgentNode(rw http.ResponseWriter, r *http.Request) {
549548
return
550549
}
551550
defer conn.Close(websocket.StatusNormalClosure, "")
552-
ctx, nc := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
553551
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
554552
subCancel, err := api.Pubsub.Subscribe("tailnet", func(ctx context.Context, message []byte) {
555553
// Since we subscribe to all peer broadcasts, we do a light check to
@@ -560,25 +558,24 @@ func (api *API) workspaceAgentNode(rw http.ResponseWriter, r *http.Request) {
560558
return
561559
}
562560
// We aren't the intended recipient.
563-
if !bytes.Equal(message[:len(agentIDBytes)-1], agentIDBytes) {
561+
if !bytes.Equal(message[:len(agentIDBytes)], agentIDBytes) {
564562
return
565563
}
566-
_, _ = nc.Write(message)
564+
_ = conn.Write(ctx, websocket.MessageText, message[len(agentIDBytes):])
567565
})
568566
if err != nil {
569-
api.Logger.Error(ctx, "pubsub listen", slog.Error(err))
567+
api.Logger.Error(context.Background(), "pubsub listen", slog.Error(err))
570568
return
571569
}
572570
defer subCancel()
573571

574-
decoder := json.NewDecoder(nc)
575572
for {
576573
var node tailnet.Node
577-
err = decoder.Decode(&node)
574+
err = wsjson.Read(r.Context(), conn, &node)
578575
if err != nil {
579576
return
580577
}
581-
err := api.Database.UpdateWorkspaceAgentNetworkByID(ctx, database.UpdateWorkspaceAgentNetworkByIDParams{
578+
err := api.Database.UpdateWorkspaceAgentNetworkByID(r.Context(), database.UpdateWorkspaceAgentNetworkByIDParams{
582579
ID: workspaceAgent.ID,
583580
NodePublicKey: sql.NullString{
584581
String: node.Key.String(),
@@ -623,7 +620,8 @@ func (api *API) postWorkspaceAgentNode(rw http.ResponseWriter, r *http.Request)
623620
})
624621
return
625622
}
626-
data = append(workspaceAgent.ID[:], data...)
623+
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
624+
data = append(agentIDBytes, data...)
627625
err = api.Pubsub.Publish("tailnet", data)
628626
if err != nil {
629627
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@@ -696,15 +694,13 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
696694
}
697695

698696
if dbAgent.NodePublicKey.Valid {
699-
var err error
700-
workspaceAgent.NodePublicKey, err = key.ParseNodePublicUntyped(mem.S(dbAgent.NodePublicKey.String))
697+
err := workspaceAgent.NodePublicKey.UnmarshalText([]byte(dbAgent.NodePublicKey.String))
701698
if err != nil {
702699
return codersdk.WorkspaceAgent{}, xerrors.Errorf("parse node public key: %w", err)
703700
}
704701
}
705702
if dbAgent.DiscoPublicKey.Valid {
706-
var err error
707-
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.DiscoPublicKey.String))
703+
err := workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.DiscoPublicKey.String))
708704
if err != nil {
709705
return codersdk.WorkspaceAgent{}, xerrors.Errorf("parse disco public key: %w", err)
710706
}

coderd/workspaceagents_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bufio"
55
"context"
66
"encoding/json"
7+
"fmt"
78
"runtime"
89
"strings"
910
"testing"
@@ -253,6 +254,68 @@ func TestWorkspaceAgentTURN(t *testing.T) {
253254
require.NoError(t, err)
254255
}
255256

257+
func TestWorkspaceAgentTailnet(t *testing.T) {
258+
t.Parallel()
259+
client, coderAPI := coderdtest.NewWithAPI(t, nil)
260+
user := coderdtest.CreateFirstUser(t, client)
261+
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
262+
authToken := uuid.NewString()
263+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
264+
Parse: echo.ParseComplete,
265+
ProvisionDryRun: echo.ProvisionComplete,
266+
Provision: []*proto.Provision_Response{{
267+
Type: &proto.Provision_Response_Complete{
268+
Complete: &proto.Provision_Complete{
269+
Resources: []*proto.Resource{{
270+
Name: "example",
271+
Type: "aws_instance",
272+
Agents: []*proto.Agent{{
273+
Id: uuid.NewString(),
274+
Auth: &proto.Agent_Token{
275+
Token: authToken,
276+
},
277+
}},
278+
}},
279+
},
280+
},
281+
}},
282+
})
283+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
284+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
285+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
286+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
287+
daemonCloser.Close()
288+
289+
agentClient := codersdk.New(client.URL)
290+
agentClient.SessionToken = authToken
291+
agentCloser := agent.New(agent.Options{
292+
FetchMetadata: agentClient.WorkspaceAgentMetadata,
293+
WebRTCDialer: agentClient.ListenWorkspaceAgent,
294+
EnableTailnet: true,
295+
NodeDialer: agentClient.WorkspaceAgentNodeBroker,
296+
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
297+
})
298+
t.Cleanup(func() {
299+
_ = agentCloser.Close()
300+
})
301+
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
302+
303+
time.Sleep(3 * time.Second)
304+
305+
conn, err := client.DialWorkspaceAgentTailnet(context.Background(), resources[0].Agents[0].ID, slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug))
306+
require.NoError(t, err)
307+
t.Cleanup(func() {
308+
_ = conn.Close()
309+
})
310+
sshClient, err := conn.SSHClient()
311+
require.NoError(t, err)
312+
session, err := sshClient.NewSession()
313+
require.NoError(t, err)
314+
output, err := session.CombinedOutput("echo test")
315+
require.NoError(t, err)
316+
fmt.Printf("Output: %s\n", output)
317+
}
318+
256319
func TestWorkspaceAgentPTY(t *testing.T) {
257320
t.Parallel()
258321
if runtime.GOOS == "windows" {

coderd/wsconncache/wsconncache.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
3232
}
3333

3434
// Dialer creates a new agent connection by ID.
35-
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
35+
type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error)
3636

3737
// Conn wraps an agent connection with a reusable HTTP transport.
3838
type Conn struct {
39-
*agent.Conn
39+
agent.Conn
4040

4141
locks atomic.Uint64
4242
timeoutMutex sync.Mutex

codersdk/workspaceagents.go

+66-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ import (
1515
"github.com/pion/webrtc/v3"
1616
"golang.org/x/net/proxy"
1717
"golang.org/x/xerrors"
18+
"inet.af/netaddr"
1819
"nhooyr.io/websocket"
1920
"nhooyr.io/websocket/wsjson"
21+
"tailscale.com/tailcfg"
2022

2123
"cdr.dev/slog"
2224

@@ -302,6 +304,67 @@ func (c *Client) WorkspaceAgentNodeBroker(ctx context.Context) (agent.NodeBroker
302304
return &workspaceAgentNodeBroker{conn}, nil
303305
}
304306

307+
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, agentID uuid.UUID, logger slog.Logger) (agent.Conn, error) {
308+
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/derpmap", agentID), nil)
309+
if err != nil {
310+
return nil, err
311+
}
312+
defer res.Body.Close()
313+
if res.StatusCode != http.StatusOK {
314+
return nil, readBodyAsError(res)
315+
}
316+
var derpMap tailcfg.DERPMap
317+
err = json.NewDecoder(res.Body).Decode(&derpMap)
318+
if err != nil {
319+
return nil, xerrors.Errorf("decode derpmap: %w", err)
320+
}
321+
ip := tailnet.IP()
322+
323+
server, err := tailnet.New(&tailnet.Options{
324+
Addresses: []netaddr.IPPrefix{netaddr.IPPrefixFrom(ip, 128)},
325+
DERPMap: &derpMap,
326+
Logger: logger,
327+
})
328+
if err != nil {
329+
return nil, xerrors.Errorf("create tailnet: %w", err)
330+
}
331+
server.SetNodeCallback(func(node *tailnet.Node) {
332+
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/node", agentID), node)
333+
if err != nil {
334+
logger.Error(ctx, "update node", slog.Error(err), slog.F("node", node))
335+
return
336+
}
337+
defer res.Body.Close()
338+
if res.StatusCode != http.StatusOK {
339+
logger.Error(ctx, "update node", slog.F("status_code", res.StatusCode), slog.F("node", node))
340+
}
341+
})
342+
workspaceAgent, err := c.WorkspaceAgent(ctx, agentID)
343+
if err != nil {
344+
return nil, xerrors.Errorf("get workspace agent: %w", err)
345+
}
346+
ipRanges := make([]netaddr.IPPrefix, 0, len(workspaceAgent.IPAddresses))
347+
for _, address := range workspaceAgent.IPAddresses {
348+
ipRanges = append(ipRanges, netaddr.IPPrefixFrom(address, 128))
349+
}
350+
agentNode := &tailnet.Node{
351+
Key: workspaceAgent.NodePublicKey,
352+
DiscoKey: workspaceAgent.DiscoPublicKey,
353+
PreferredDERP: workspaceAgent.PreferredDERP,
354+
Addresses: ipRanges,
355+
AllowedIPs: ipRanges,
356+
}
357+
logger.Debug(ctx, "adding agent node", slog.F("node", agentNode))
358+
err = server.UpdateNodes([]*tailnet.Node{agentNode})
359+
if err != nil {
360+
return nil, xerrors.Errorf("update nodes: %w", err)
361+
}
362+
return &agent.TailnetConn{
363+
Target: workspaceAgent.IPAddresses[0],
364+
Server: server,
365+
}, nil
366+
}
367+
305368
// DialWorkspaceAgent creates a connection to the specified resource.
306369
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) {
307370
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String()))
@@ -447,9 +510,9 @@ type workspaceAgentNodeBroker struct {
447510
}
448511

449512
func (w *workspaceAgentNodeBroker) Read(ctx context.Context) (*tailnet.Node, error) {
450-
var node *tailnet.Node
451-
err := wsjson.Read(ctx, w.conn, node)
452-
return node, err
513+
var node tailnet.Node
514+
err := wsjson.Read(ctx, w.conn, &node)
515+
return &node, err
453516
}
454517

455518
func (w *workspaceAgentNodeBroker) Write(ctx context.Context, node *tailnet.Node) error {

0 commit comments

Comments
 (0)