Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Centralize p2p Dial logic #326

Merged
merged 1 commit into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ func (s *Server) Run(ctx context.Context) error {
}),
).Run(
func() error {
ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15)
defer cancelFunc()
s.log.Info(ctx, "connecting to coder", slog.F("url", s.listenURL.String()))
conn, resp, err := websocket.Dial(ctx, s.listenURL.String(), nil)
if err != nil && resp == nil {
Expand All @@ -71,7 +69,8 @@ func (s *Server) Run(ctx context.Context) error {
Response: resp,
}
}
nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)

nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
session, err := yamux.Server(nc, nil)
if err != nil {
return fmt.Errorf("open: %w", err)
Expand Down
15 changes: 10 additions & 5 deletions agent/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ import (
"io"
"net"

"cdr.dev/coder-cli/xwebrtc"

"cdr.dev/slog"
"github.com/hashicorp/yamux"
"github.com/pion/webrtc/v3"
"golang.org/x/xerrors"

"cdr.dev/coder-cli/internal/x/xwebrtc"
"cdr.dev/coder-cli/pkg/proto"
)

Expand Down Expand Up @@ -128,6 +129,10 @@ func (s *stream) processMessage(msg proto.Message) {
}

func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
if channel.Protocol() == "control" {
return
}

if channel.Protocol() == "ping" {
channel.OnOpen(func() {
rw, err := channel.Detach()
Expand All @@ -149,7 +154,7 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
return
}

prto, port, err := xwebrtc.ParseProxyDataChannel(channel)
prto, addr, err := xwebrtc.ParseProxyDataChannel(channel)
if err != nil {
s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err))
return
Expand All @@ -159,14 +164,14 @@ func (s *stream) processDataChannel(channel *webrtc.DataChannel) {
return
}

conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port))
conn, err := net.Dial(prto, addr)
if err != nil {
s.fatal(fmt.Errorf("failed to dial client port: %d", port))
s.fatal(fmt.Errorf("failed to dial client addr: %s", addr))
return
}

channel.OnOpen(func() {
s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port))
s.logger.Debug(context.Background(), "proxying data channel", slog.F("addr", addr))
rw, err := channel.Detach()
if err != nil {
_ = channel.Close()
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/rjeczalik/notify v0.9.2
github.com/spf13/cobra v1.1.3
go.coder.com/retry v1.2.0
golang.org/x/net v0.0.0-20210420210106-798c2154c571
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
Expand Down
223 changes: 39 additions & 184 deletions internal/cmd/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,21 @@ package cmd

import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/url"
"os"
"strconv"
"time"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"nhooyr.io/websocket"

"cdr.dev/coder-cli/coder-sdk"
"cdr.dev/coder-cli/internal/x/xcobra"
"cdr.dev/coder-cli/internal/x/xwebrtc"
"cdr.dev/coder-cli/pkg/proto"
"cdr.dev/coder-cli/xwebrtc"
)

func tunnelCmd() *cobra.Command {
Expand All @@ -41,26 +36,26 @@ coder tunnel my-dev 3000 3000

remotePort, err := strconv.ParseUint(args[1], 10, 16)
if err != nil {
log.Fatal(ctx, "parse remote port", slog.Error(err))
return xerrors.Errorf("parse remote port: %w", err)
}

var localPort uint64
if args[2] != "stdio" {
localPort, err = strconv.ParseUint(args[2], 10, 16)
if err != nil {
log.Fatal(ctx, "parse local port", slog.Error(err))
return xerrors.Errorf("parse local port: %w", err)
}
}

sdk, err := newClient(ctx)
if err != nil {
return err
return xerrors.Errorf("getting coder client: %w", err)
}
baseURL := sdk.BaseURL()

envs, err := getEnvs(ctx, sdk, coder.Me)
if err != nil {
return err
return xerrors.Errorf("get workspaces: %w", err)
}

var envID string
Expand All @@ -74,20 +69,19 @@ coder tunnel my-dev 3000 3000
return xerrors.Errorf("No workspace found by name '%s'", args[0])
}

c := &client{
id: envID,
stdio: args[2] == "stdio",
localPort: uint16(localPort),
remotePort: uint16(remotePort),
ctx: context.Background(),
logger: log.Leveled(slog.LevelDebug),
brokerAddr: baseURL,
token: sdk.Token(),
c := &tunnneler{
log: log.Leveled(slog.LevelDebug),
brokerAddr: &baseURL,
token: sdk.Token(),
workspaceID: envID,
stdio: args[2] == "stdio",
localPort: uint16(localPort),
remotePort: uint16(remotePort),
}

err = c.start()
err = c.start(ctx)
if err != nil {
log.Fatal(ctx, err.Error())
return xerrors.Errorf("running tunnel: %w", err)
}

return nil
Expand All @@ -97,197 +91,58 @@ coder tunnel my-dev 3000 3000
return cmd
}

type client struct {
ctx context.Context
brokerAddr url.URL
token string
logger slog.Logger
id string
remotePort uint16
localPort uint16
stdio bool
type tunnneler struct {
log slog.Logger
brokerAddr *url.URL
token string
workspaceID string
remotePort uint16
localPort uint16
stdio bool
}

func (c *client) start() error {
url := fmt.Sprintf("%s%s%s%s%s", c.brokerAddr.String(), "/api/private/envagent/", c.id, "/connect?session_token=", c.token)
turnScheme := "turns"
if c.brokerAddr.Scheme == "http" {
turnScheme = "turn"
}
tcpProxy := fmt.Sprintf("%s:%s:5349?transport=tcp", turnScheme, c.brokerAddr.Host)
c.logger.Info(c.ctx, "connecting to broker", slog.F("url", url), slog.F("tcp-proxy", tcpProxy))
conn, resp, err := websocket.Dial(c.ctx, url, nil)
if err != nil && resp == nil {
return fmt.Errorf("dial: %w", err)
}
if err != nil && resp != nil {
return &coder.HTTPError{
Response: resp,
}
}
nconn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)

// Only enabled under a private feature flag for now,
// so insecure connections are entirely fine to allow.
servers := []webrtc.ICEServer{{
URLs: []string{tcpProxy},
Username: "insecure",
Credential: "pass",
CredentialType: webrtc.ICECredentialTypePassword,
}}
rtc, err := xwebrtc.NewPeerConnection(servers)
if err != nil {
return fmt.Errorf("create connection: %w", err)
}

rtc.OnNegotiationNeeded(func() {
c.logger.Debug(context.Background(), "negotiation needed...")
})

rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
c.logger.Info(context.Background(), "connection state changed", slog.F("state", pcs))
})

channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
if err != nil {
return fmt.Errorf("create data channel: %w", err)
}
flushCandidates := proto.ProxyICECandidates(rtc, nconn)

localDesc, err := rtc.CreateOffer(&webrtc.OfferOptions{})
if err != nil {
return fmt.Errorf("create offer: %w", err)
}

err = rtc.SetLocalDescription(localDesc)
if err != nil {
return fmt.Errorf("set local desc: %w", err)
}

c.logger.Debug(context.Background(), "writing offer")
b, _ := json.Marshal(&proto.Message{
Offer: &localDesc,
Servers: servers,
})
_, err = nconn.Write(b)
func (c *tunnneler) start(ctx context.Context) error {
wd, err := xwebrtc.NewWorkspaceDialer(ctx, c.log, c.brokerAddr, c.token, c.workspaceID)
if err != nil {
return fmt.Errorf("write offer: %w", err)
}
flushCandidates()

go func() {
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
if err != nil {
c.logger.Fatal(context.Background(), "waiting for data channel open", slog.Error(err))
}
_ = conn.Close(websocket.StatusNormalClosure, "rtc connected")
}()

decoder := json.NewDecoder(nconn)
for {
var msg proto.Message
err = decoder.Decode(&msg)
if err == io.EOF {
break
}
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
break
}
if err != nil {
return fmt.Errorf("read msg: %w", err)
}
if msg.Candidate != "" {
c.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate))
err = proto.AcceptICECandidate(rtc, &msg)
if err != nil {
return fmt.Errorf("accept ice: %w", err)
}
}
if msg.Answer != nil {
c.logger.Debug(context.Background(), "got answer", slog.F("answer", msg.Answer))
err = rtc.SetRemoteDescription(*msg.Answer)
if err != nil {
return fmt.Errorf("set remote: %w", err)
}
}
return xerrors.Errorf("creating workspace dialer: %w", wd)
}

// Once we're open... let's test out the ping.
pingProto := "ping"
pingChannel, err := rtc.CreateDataChannel("pinger", &webrtc.DataChannelInit{
Protocol: &pingProto,
})
nc, err := wd.DialContext(ctx, xwebrtc.NetworkTCP, fmt.Sprintf("localhost:%d", c.remotePort))
if err != nil {
return fmt.Errorf("create ping channel")
return xerrors.Errorf("dial: %w", err)
}
pingChannel.OnOpen(func() {
defer func() {
_ = pingChannel.Close()
}()
t1 := time.Now()
rw, _ := pingChannel.Detach()
defer func() {
_ = rw.Close()
}()
_, _ = rw.Write([]byte("hello"))
b := make([]byte, 64)
_, _ = rw.Read(b)
c.logger.Info(c.ctx, "your latency directly to the agent", slog.F("ms", time.Since(t1).Milliseconds()))
})

// proxy via stdio
if c.stdio {
// At this point the RTC is connected and data channel is opened...
rw, err := channel.Detach()
if err != nil {
return fmt.Errorf("detach channel: %w", err)
}
go func() {
_, _ = io.Copy(rw, os.Stdin)
_, _ = io.Copy(nc, os.Stdin)
}()
_, err = io.Copy(os.Stdout, rw)
_, err = io.Copy(os.Stdout, nc)
if err != nil {
return fmt.Errorf("copy: %w", err)
return xerrors.Errorf("copy: %w", err)
}
return nil
}

// proxy via tcp listener
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", c.localPort))
if err != nil {
return fmt.Errorf("listen: %w", err)
return xerrors.Errorf("listen: %w", err)
}

for {
conn, err := listener.Accept()
lc, err := listener.Accept()
if err != nil {
return fmt.Errorf("accept: %w", err)
return xerrors.Errorf("accept: %w", err)
}
go func() {
defer func() {
_ = conn.Close()
}()
channel, err := xwebrtc.NewProxyDataChannel(rtc, "forwarder", "tcp", c.remotePort)
if err != nil {
c.logger.Warn(context.Background(), "create data channel for proxying", slog.Error(err))
return
}
defer func() {
_ = channel.Close()
_ = lc.Close()
}()
err = xwebrtc.WaitForDataChannelOpen(context.Background(), channel)
if err != nil {
c.logger.Warn(context.Background(), "wait for data channel open", slog.Error(err))
return
}
rw, err := channel.Detach()
if err != nil {
c.logger.Warn(context.Background(), "detach channel", slog.Error(err))
return
}

go func() {
_, _ = io.Copy(conn, rw)
_, _ = io.Copy(lc, nc)
}()
_, _ = io.Copy(rw, conn)
_, _ = io.Copy(nc, lc)
}()
}
}
Loading