diff --git a/.github/workflows/coder.yaml b/.github/workflows/coder.yaml index 4bdce128b7001..d6fed1b850567 100644 --- a/.github/workflows/coder.yaml +++ b/.github/workflows/coder.yaml @@ -169,10 +169,6 @@ jobs: terraform_version: 1.1.2 terraform_wrapper: false - - name: Install socat - if: runner.os == 'Linux' - run: sudo apt-get install -y socat - - name: Test with Mock Database shell: bash env: diff --git a/agent/agent.go b/agent/agent.go index 1ad4f7fd6ad91..b5ae020353ab7 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -32,23 +32,23 @@ type Options struct { Logger slog.Logger } -type Dialer func(ctx context.Context, options *peer.ConnOptions) (*peerbroker.Listener, error) +type Dialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) -func New(dialer Dialer, options *peer.ConnOptions) io.Closer { +func New(dialer Dialer, logger slog.Logger) io.Closer { ctx, cancelFunc := context.WithCancel(context.Background()) server := &agent{ - clientDialer: dialer, - options: options, - closeCancel: cancelFunc, - closed: make(chan struct{}), + dialer: dialer, + logger: logger, + closeCancel: cancelFunc, + closed: make(chan struct{}), } server.init(ctx) return server } type agent struct { - clientDialer Dialer - options *peer.ConnOptions + dialer Dialer + logger slog.Logger connCloseWait sync.WaitGroup closeCancel context.CancelFunc @@ -64,7 +64,7 @@ func (a *agent) run(ctx context.Context) { // An exponential back-off occurs when the connection is failing to dial. // This is to prevent server spam in case of a coderd outage. for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - peerListener, err = a.clientDialer(ctx, a.options) + peerListener, err = a.dialer(ctx, a.logger) if err != nil { if errors.Is(err, context.Canceled) { return @@ -72,10 +72,10 @@ func (a *agent) run(ctx context.Context) { if a.isClosed() { return } - a.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + a.logger.Warn(context.Background(), "failed to dial", slog.Error(err)) continue } - a.options.Logger.Info(context.Background(), "connected") + a.logger.Info(context.Background(), "connected") break } select { @@ -90,7 +90,7 @@ func (a *agent) run(ctx context.Context) { if a.isClosed() { return } - a.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) + a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) a.run(ctx) return } @@ -105,10 +105,9 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { go func() { select { case <-a.closed: - _ = conn.Close() case <-conn.Closed(): } - <-conn.Closed() + _ = conn.Close() a.connCloseWait.Done() }() for { @@ -117,7 +116,7 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { if errors.Is(err, peer.ErrClosed) || a.isClosed() { return } - a.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) + a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) return } @@ -125,7 +124,7 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { case "ssh": go a.sshServer.HandleConn(channel.NetConn()) default: - a.options.Logger.Warn(ctx, "unhandled protocol from channel", + a.logger.Warn(ctx, "unhandled protocol from channel", slog.F("protocol", channel.Protocol()), slog.F("label", channel.Label()), ) @@ -145,7 +144,7 @@ func (a *agent) init(ctx context.Context) { if err != nil { panic(err) } - sshLogger := a.options.Logger.Named("ssh-server") + sshLogger := a.logger.Named("ssh-server") forwardHandler := &ssh.ForwardedTCPHandler{} a.sshServer = &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ @@ -158,7 +157,7 @@ func (a *agent) init(ctx context.Context) { Handler: func(session ssh.Session) { err := a.handleSSHSession(session) if err != nil { - a.options.Logger.Warn(ctx, "ssh session failed", slog.Error(err)) + a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) _ = session.Exit(1) return } @@ -194,7 +193,7 @@ func (a *agent) init(ctx context.Context) { "sftp": func(session ssh.Session) { server, err := sftp.NewServer(session) if err != nil { - a.options.Logger.Debug(session.Context(), "initialize sftp server", slog.Error(err)) + a.logger.Debug(session.Context(), "initialize sftp server", slog.Error(err)) return } defer server.Close() @@ -202,7 +201,7 @@ func (a *agent) init(ctx context.Context) { if errors.Is(err, io.EOF) { return } - a.options.Logger.Debug(session.Context(), "sftp server exited with error", slog.Error(err)) + a.logger.Debug(session.Context(), "sftp server exited with error", slog.Error(err)) }, }, } @@ -250,7 +249,7 @@ func (a *agent) handleSSHSession(session ssh.Session) error { for win := range windowSize { err = ptty.Resize(uint16(win.Width), uint16(win.Height)) if err != nil { - a.options.Logger.Warn(context.Background(), "failed to resize tty", slog.Error(err)) + a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err)) } } }() diff --git a/agent/agent_test.go b/agent/agent_test.go index 5fad6435da061..fbb1432b725ee 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -170,11 +170,9 @@ func setupSSHSession(t *testing.T) *ssh.Session { func setupAgent(t *testing.T) *agent.Conn { client, server := provisionersdk.TransportPipe() - closer := agent.New(func(ctx context.Context, opts *peer.ConnOptions) (*peerbroker.Listener, error) { - return peerbroker.Listen(server, nil, opts) - }, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), - }) + closer := agent.New(func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, nil) + }, slogtest.Make(t, nil).Leveled(slog.LevelDebug)) t.Cleanup(func() { _ = client.Close() _ = server.Close() diff --git a/cli/agent.go b/cli/agent.go index 2bdd7bca225fe..cbe7ab1483a98 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -16,7 +16,6 @@ import ( "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" "github.com/coder/retry" ) @@ -110,9 +109,7 @@ func workspaceAgent() *cobra.Command { return xerrors.Errorf("writing agent session token to config: %w", err) } - closer := agent.New(client.ListenWorkspaceAgent, &peer.ConnOptions{ - Logger: logger, - }) + closer := agent.New(client.ListenWorkspaceAgent, logger) <-cmd.Context().Done() return closer.Close() }, diff --git a/cli/agent_test.go b/cli/agent_test.go index c289a32968a81..d8bede0e89007 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -61,7 +61,7 @@ func TestWorkspaceAgent(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil, nil) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() _, err = dialer.Ping() @@ -115,7 +115,7 @@ func TestWorkspaceAgent(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil, nil) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() _, err = dialer.Ping() diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 56f8b4f83ec9a..43217d345c8b6 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -19,7 +19,6 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/pty/ptytest" @@ -72,9 +71,7 @@ func TestConfigSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, slogtest.Make(t, nil)) t.Cleanup(func() { _ = agentCloser.Close() }) @@ -82,7 +79,7 @@ func TestConfigSSH(t *testing.T) { require.NoError(t, err) _ = tempFile.Close() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil, nil) + agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) require.NoError(t, err) defer agentConn.Close() diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index 6ebf142b9e54b..fcaf8da9e0fff 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -74,7 +74,7 @@ func TestGitSSH(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil, nil) + dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() _, err = dialer.Ping() diff --git a/cli/ssh.go b/cli/ssh.go index 9f4c0183b756c..627a0a041ee4d 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -10,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/mattn/go-isatty" - "github.com/pion/webrtc/v3" "github.com/spf13/cobra" gossh "golang.org/x/crypto/ssh" "golang.org/x/term" @@ -99,9 +98,7 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, nil) + conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) if err != nil { return err } diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d3bd0503607ec..9fed1fee8aaa9 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -17,7 +17,6 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/peer" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/pty/ptytest" @@ -70,9 +69,7 @@ func TestSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) agentClient.SessionToken = agentToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), - }) + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, slogtest.Make(t, nil).Leveled(slog.LevelDebug)) t.Cleanup(func() { _ = agentCloser.Close() }) @@ -115,9 +112,7 @@ func TestSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) agentClient.SessionToken = agentToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), - }) + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, slogtest.Make(t, nil).Leveled(slog.LevelDebug)) t.Cleanup(func() { _ = agentCloser.Close() }) diff --git a/cli/start.go b/cli/start.go index 50d24bf4f3a3f..6bcd16503a9fc 100644 --- a/cli/start.go +++ b/cli/start.go @@ -18,6 +18,7 @@ import ( "github.com/briandowns/spinner" "github.com/coreos/go-systemd/daemon" + "github.com/pion/turn/v2" "github.com/spf13/cobra" "golang.org/x/xerrors" "google.golang.org/api/idtoken" @@ -34,6 +35,7 @@ import ( "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/devtunnel" "github.com/coder/coder/coderd/gitsshkey" + "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/terraform" "github.com/coder/coder/provisionerd" @@ -56,11 +58,13 @@ func start() *cobra.Command { tlsEnable bool tlsKeyFile string tlsMinVersion string + turnRelayAddress string skipTunnel bool traceDatadog bool secureAuthCookie bool sshKeygenAlgorithmRaw string ) + root := &cobra.Command{ Use: "start", RunE: func(cmd *cobra.Command, args []string) error { @@ -156,6 +160,14 @@ func start() *cobra.Command { return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err) } + turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP(turnRelayAddress), + Address: turnRelayAddress, + }) + if err != nil { + return xerrors.Errorf("create turn server: %w", err) + } + options := &coderd.Options{ AccessURL: accessURLParsed, Logger: logger.Named("coderd"), @@ -164,6 +176,7 @@ func start() *cobra.Command { GoogleTokenValidator: validator, SecureAuthCookie: secureAuthCookie, SSHKeygenAlgorithm: sshKeygenAlgorithm, + TURNServer: turnServer, } _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL) @@ -376,6 +389,8 @@ func start() *cobra.Command { cliflag.BoolVarP(root.Flags(), &skipTunnel, "skip-tunnel", "", "CODER_DEV_SKIP_TUNNEL", false, "Skip serving dev mode through an exposed tunnel for simple setup.") _ = root.Flags().MarkHidden("skip-tunnel") cliflag.BoolVarP(root.Flags(), &traceDatadog, "trace-datadog", "", "CODER_TRACE_DATADOG", false, "Send tracing data to a datadog agent") + cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1", + "Specifies the address to bind TURN connections.") cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false, "Specifies if the 'Secure' property is set on browser session cookies") cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519", "Specifies the algorithm to use for generating ssh keys. "+ `Accepted values are "ed25519", "ecdsa", or "rsa4096"`) diff --git a/coderd/coderd.go b/coderd/coderd.go index 0f3c4537f6a71..bb9a6e2bdb641 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/pion/webrtc/v3" "google.golang.org/api/idtoken" chitrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-chi/chi.v5" @@ -20,23 +21,25 @@ import ( "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" ) // Options are requires parameters for Coder to start. type Options struct { - AgentConnectionUpdateFrequency time.Duration - AccessURL *url.URL - Logger slog.Logger - Database database.Store - Pubsub database.Pubsub - - AWSCertificates awsidentity.Certificates - GoogleTokenValidator *idtoken.Validator + AccessURL *url.URL + Logger slog.Logger + Database database.Store + Pubsub database.Pubsub - SecureAuthCookie bool - SSHKeygenAlgorithm gitsshkey.Algorithm + AgentConnectionUpdateFrequency time.Duration + AWSCertificates awsidentity.Certificates + GoogleTokenValidator *idtoken.Validator + ICEServers []webrtc.ICEServer + SecureAuthCookie bool + SSHKeygenAlgorithm gitsshkey.Algorithm + TURNServer *turnconn.Server } // New constructs the Coder API into an HTTP handler. @@ -175,6 +178,8 @@ func New(options *Options) (http.Handler, func()) { r.Use(httpmw.ExtractWorkspaceAgent(options.Database)) r.Get("/", api.workspaceAgentListen) r.Get("/gitsshkey", api.agentGitSSHKey) + r.Get("/turn", api.workspaceAgentTurn) + r.Get("/iceservers", api.workspaceAgentICEServers) }) r.Route("/{workspaceagent}", func(r chi.Router) { r.Use( @@ -183,6 +188,8 @@ func New(options *Options) (http.Handler, func()) { ) r.Get("/", api.workspaceAgent) r.Get("/dial", api.workspaceAgentDial) + r.Get("/turn", api.workspaceAgentTurn) + r.Get("/iceservers", api.workspaceAgentICEServers) }) }) r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 7f30e992d23d6..67e3febcf5cab 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/database/postgres" "github.com/coder/coder/coderd/gitsshkey" + "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" "github.com/coder/coder/provisioner/echo" @@ -91,9 +92,8 @@ func New(t *testing.T, options *Options) *codersdk.Client { } srv := httptest.NewUnstartedServer(nil) + ctx, cancelFunc := context.WithCancel(context.Background()) srv.Config.BaseContext = func(_ net.Listener) context.Context { - ctx, cancelFunc := context.WithCancel(context.Background()) - t.Cleanup(cancelFunc) return ctx } srv.Start() @@ -106,6 +106,9 @@ func New(t *testing.T, options *Options) *codersdk.Client { options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } + turnServer, err := turnconn.New(nil) + require.NoError(t, err) + // We set the handler after server creation for the access URL. srv.Config.Handler, closeWait = coderd.New(&coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, @@ -117,8 +120,11 @@ func New(t *testing.T, options *Options) *codersdk.Client { AWSCertificates: options.AWSInstanceIdentity, GoogleTokenValidator: options.GoogleInstanceIdentity, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, + TURNServer: turnServer, }) t.Cleanup(func() { + cancelFunc() + _ = turnServer.Close() srv.Close() closeWait() }) diff --git a/coderd/httpmw/workspaceagent.go b/coderd/httpmw/workspaceagent.go index f90d8ff79f9a4..05e7fe213c242 100644 --- a/coderd/httpmw/workspaceagent.go +++ b/coderd/httpmw/workspaceagent.go @@ -38,7 +38,7 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { token, err := uuid.Parse(cookie.Value) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("parse token: %s", err), + Message: fmt.Sprintf("parse token %q: %s", cookie.Value, err), }) return } diff --git a/coderd/turnconn/turnconn.go b/coderd/turnconn/turnconn.go new file mode 100644 index 0000000000000..29462559ce3e7 --- /dev/null +++ b/coderd/turnconn/turnconn.go @@ -0,0 +1,203 @@ +package turnconn + +import ( + "io" + "net" + "sync" + + "github.com/pion/logging" + "github.com/pion/turn/v2" + "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" + "golang.org/x/xerrors" +) + +var ( + // reservedAddress is a magic address that's used exclusively + // for proxying via Coder. We don't proxy all TURN connections, + // because that'd exclude the possibility of a customer using + // their own TURN server. + reservedAddress = "127.0.0.1:12345" + credential = "coder" + localhost = &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + } + + // Proxy is a an ICE Server that uses a special hostname + // to indicate traffic should be proxied. + Proxy = webrtc.ICEServer{ + URLs: []string{"turns:" + reservedAddress}, + Username: "coder", + Credential: credential, + } +) + +// New constructs a new TURN server binding to the relay address provided. +// The relay address is used to broadcast the location of an accepted connection. +func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) { + if relayAddress == nil { + relayAddress = &turn.RelayAddressGeneratorStatic{ + RelayAddress: localhost.IP, + Address: "127.0.0.1", + } + } + logger := logging.NewDefaultLoggerFactory() + logger.DefaultLogLevel = logging.LogLevelDebug + server := &Server{ + conns: make(chan net.Conn, 1), + closed: make(chan struct{}), + } + server.listener = &listener{ + srv: server, + } + var err error + server.turn, err = turn.NewServer(turn.ServerConfig{ + AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + // TURN connections require credentials. It's not important + // for our use-case, because our listener is entirely in-memory. + return turn.GenerateAuthKey(Proxy.Username, "", credential), true + }, + ListenerConfigs: []turn.ListenerConfig{{ + Listener: server.listener, + RelayAddressGenerator: relayAddress, + }}, + LoggerFactory: logger, + }) + if err != nil { + return nil, xerrors.Errorf("create server: %w", err) + } + + return server, nil +} + +// Server accepts and connects TURN allocations. +// +// This is a thin wrapper around pion/turn that pipes +// connections directly to the in-memory handler. +type Server struct { + listener *listener + turn *turn.Server + + closeMutex sync.Mutex + closed chan (struct{}) + conns chan (net.Conn) +} + +// Accept consumes a new connection into the TURN server. +// A unique remote address must exist per-connection. +// pion/turn indexes allocations based on the address. +func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn { + if localAddress == nil { + localAddress = localhost + } + conn := &Conn{ + Conn: nc, + remoteAddress: remoteAddress, + localAddress: localAddress, + closed: make(chan struct{}), + } + s.conns <- conn + return conn +} + +// Close ends the TURN server. +func (s *Server) Close() error { + s.closeMutex.Lock() + defer s.closeMutex.Unlock() + if s.isClosed() { + return nil + } + err := s.turn.Close() + close(s.conns) + close(s.closed) + return err +} + +func (s *Server) isClosed() bool { + select { + case <-s.closed: + return true + default: + return false + } +} + +// listener implements net.Listener for the TURN +// server to consume. +type listener struct { + srv *Server +} + +func (l *listener) Accept() (net.Conn, error) { + conn, ok := <-l.srv.conns + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (*listener) Close() error { + return nil +} + +func (*listener) Addr() net.Addr { + return nil +} + +type Conn struct { + net.Conn + closed chan struct{} + localAddress *net.TCPAddr + remoteAddress *net.TCPAddr +} + +func (c *Conn) LocalAddr() net.Addr { + return c.localAddress +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remoteAddress +} + +// Closed returns a channel which is closed when +// the connection is. +func (c *Conn) Closed() <-chan struct{} { + return c.closed +} + +func (c *Conn) Close() error { + err := c.Conn.Close() + select { + case <-c.closed: + default: + close(c.closed) + } + return err +} + +type dialer func(network, addr string) (c net.Conn, err error) + +func (d dialer) Dial(network, addr string) (c net.Conn, err error) { + return d(network, addr) +} + +// ProxyDialer accepts a proxy function that's called when the connection +// address matches the reserved host in the "Proxy" ICE server. +// +// This should be passed to WebRTC connections as an ICE dialer. +func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer { + return dialer(func(network, addr string) (net.Conn, error) { + if addr != reservedAddress { + return proxy.Direct.Dial(network, addr) + } + netConn, err := proxyFunc() + if err != nil { + return nil, err + } + return &Conn{ + localAddress: localhost, + closed: make(chan struct{}), + Conn: netConn, + }, nil + }) +} diff --git a/coderd/turnconn/turnconn_test.go b/coderd/turnconn/turnconn_test.go new file mode 100644 index 0000000000000..346bfb2d420c1 --- /dev/null +++ b/coderd/turnconn/turnconn_test.go @@ -0,0 +1,106 @@ +package turnconn_test + +import ( + "net" + "sync" + "testing" + + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/turnconn" + "github.com/coder/coder/peer" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestTURNConn(t *testing.T) { + t.Parallel() + turnServer, err := turnconn.New(nil) + require.NoError(t, err) + defer turnServer.Close() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + clientDialer, clientTURN := net.Pipe() + turnServer.Accept(clientTURN, &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 16000, + }, nil) + require.NoError(t, err) + clientSettings := webrtc.SettingEngine{} + clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) + clientSettings.SetRelayAcceptanceMinWait(0) + clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { + return clientDialer, nil + })) + client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ + SettingEngine: clientSettings, + Logger: logger.Named("client"), + }) + require.NoError(t, err) + + serverDialer, serverTURN := net.Pipe() + turnServer.Accept(serverTURN, &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 16001, + }, nil) + require.NoError(t, err) + serverSettings := webrtc.SettingEngine{} + serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6}) + serverSettings.SetRelayAcceptanceMinWait(0) + serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) { + return serverDialer, nil + })) + server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{ + SettingEngine: serverSettings, + Logger: logger.Named("server"), + }) + require.NoError(t, err) + exchange(t, client, server) + + _, err = client.Ping() + require.NoError(t, err) +} + +func exchange(t *testing.T, client, server *peer.Conn) { + var wg sync.WaitGroup + wg.Add(2) + t.Cleanup(func() { + _ = client.Close() + _ = server.Close() + + wg.Wait() + }) + go func() { + defer wg.Done() + for { + select { + case c := <-server.LocalCandidate(): + client.AddRemoteCandidate(c) + case c := <-server.LocalSessionDescription(): + client.SetRemoteSessionDescription(c) + case <-server.Closed(): + return + } + } + }() + go func() { + defer wg.Done() + for { + select { + case c := <-client.LocalCandidate(): + server.AddRemoteCandidate(c) + case c := <-client.LocalSessionDescription(): + server.SetRemoteSessionDescription(c) + case <-client.Closed(): + return + } + } + }() +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index b00b9657bc3e7..5ad5c40992edb 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -5,7 +5,9 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" + "strconv" "time" "github.com/hashicorp/yamux" @@ -219,6 +221,59 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) { } } +func (api *api) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) { + httpapi.Write(rw, http.StatusOK, api.ICEServers) +} + +// workspaceAgentTurn proxies a WebSocket connection to the TURN server. +func (api *api) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) { + api.websocketWaitMutex.Lock() + api.websocketWaitGroup.Add(1) + api.websocketWaitMutex.Unlock() + defer api.websocketWaitGroup.Done() + + localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) + remoteAddress := &net.TCPAddr{ + IP: net.ParseIP(r.RemoteAddr), + } + // By default requests have the remote address and port. + host, port, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("get remote address: %s", err), + }) + return + } + remoteAddress.IP = net.ParseIP(host) + remoteAddress.Port, err = strconv.Atoi(port) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("remote address %q has no parsable port: %s", r.RemoteAddr, err), + }) + return + } + + wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("accept websocket: %s", err), + }) + return + } + defer func() { + _ = wsConn.Close(websocket.StatusNormalClosure, "") + }() + netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary) + api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) + select { + case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed(): + case <-r.Context().Done(): + } + api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress)) +} + func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) { var envs map[string]string if dbAgent.EnvironmentVariables.Valid { diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 8905b7277e5bf..14a889285ab5e 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "cdr.dev/slog" @@ -89,16 +90,65 @@ func TestWorkspaceAgentListen(t *testing.T) { agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken - agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug)) t.Cleanup(func() { _ = agentCloser.Close() }) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() }) + _, err = conn.Ping() + require.NoError(t, err) +} + +func TestWorkspaceAgentTURN(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, client) + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + daemonCloser.Close() + + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, slogtest.Make(t, nil)) + t.Cleanup(func() { + _ = agentCloser.Close() + }) + resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + opts := &peer.ConnOptions{ + Logger: slogtest.Make(t, nil).Named("client"), + } + // Force a TURN connection! + opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4}) + conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, opts) require.NoError(t, err) t.Cleanup(func() { _ = conn.Close() diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 5a83d0b265556..1deb1fdb94f73 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/cookiejar" @@ -12,11 +13,15 @@ import ( "github.com/google/uuid" "github.com/hashicorp/yamux" "github.com/pion/webrtc/v3" + "golang.org/x/net/proxy" "golang.org/x/xerrors" "nhooyr.io/websocket" + "cdr.dev/slog" + "github.com/coder/coder/agent" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/turnconn" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" @@ -134,9 +139,9 @@ func (c *Client) AuthWorkspaceAWSInstanceIdentity(ctx context.Context) (Workspac return resp, json.NewDecoder(res.Body).Decode(&resp) } -// ListenWorkspaceAgent connects as a workspace agent. -// It obtains the agent ID based off the session token. -func (c *Client) ListenWorkspaceAgent(ctx context.Context, opts *peer.ConnOptions) (*peerbroker.Listener, error) { +// ListenWorkspaceAgent connects as a workspace agent identifying with the session token. +// On each inbound connection request, connection info is fetched. +func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) { serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) @@ -169,15 +174,36 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context, opts *peer.ConnOption if err != nil { return nil, xerrors.Errorf("multiplex client: %w", err) } - return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, error) { - return []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, nil - }, opts) + return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { + // This can be cached if it adds to latency too much. + res, err := c.request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil) + if err != nil { + return nil, nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, nil, readBodyAsError(res) + } + var iceServers []webrtc.ICEServer + err = json.NewDecoder(res.Body).Decode(&iceServers) + if err != nil { + return nil, nil, err + } + + options := webrtc.SettingEngine{} + options.SetSrflxAcceptanceMinWait(0) + options.SetRelayAcceptanceMinWait(0) + options.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, "/api/v2/workspaceagents/me/turn")) + iceServers = append(iceServers, turnconn.Proxy) + return iceServers, &peer.ConnOptions{ + SettingEngine: options, + Logger: logger, + }, nil + }) } // DialWorkspaceAgent creates a connection to the specified resource. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, iceServers []webrtc.ICEServer, opts *peer.ConnOptions) (*agent.Conn, error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (*agent.Conn, error) { serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String())) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) @@ -215,7 +241,30 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, iceS if err != nil { return nil, xerrors.Errorf("negotiate connection: %w", err) } - peerConn, err := peerbroker.Dial(stream, iceServers, opts) + + res, err = c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/iceservers", agentID.String()), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var iceServers []webrtc.ICEServer + err = json.NewDecoder(res.Body).Decode(&iceServers) + if err != nil { + return nil, err + } + + if options == nil { + options = &peer.ConnOptions{} + } + options.SettingEngine.SetSrflxAcceptanceMinWait(0) + options.SettingEngine.SetRelayAcceptanceMinWait(0) + options.SettingEngine.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, fmt.Sprintf("/api/v2/workspaceagents/%s/turn", agentID.String()))) + iceServers = append(iceServers, turnconn.Proxy) + + peerConn, err := peerbroker.Dial(stream, iceServers, options) if err != nil { return nil, xerrors.Errorf("dial peer: %w", err) } @@ -238,3 +287,24 @@ func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAge var workspaceAgent WorkspaceAgent return workspaceAgent, json.NewDecoder(res.Body).Decode(&workspaceAgent) } + +func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, path string) proxy.Dialer { + return turnconn.ProxyDialer(func() (net.Conn, error) { + turnURL, err := c.URL.Parse(path) + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + conn, res, err := websocket.Dial(ctx, turnURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil + }) +} diff --git a/go.mod b/go.mod index 97f401ad7e0b3..fee0628557fa4 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/pion/datachannel v1.5.2 github.com/pion/logging v0.2.2 github.com/pion/transport v0.13.0 + github.com/pion/turn/v2 v2.0.8 github.com/pion/webrtc/v3 v3.1.29 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/sftp v1.13.4 @@ -94,6 +95,7 @@ require ( golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 + golang.org/x/net v0.0.0-20220401154927-543a649e0bdd golang.org/x/oauth2 v0.0.0-20220309155454-6242fa91716a golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 @@ -186,7 +188,7 @@ require ( github.com/pelletier/go-toml/v2 v2.0.0-beta.7.0.20220408132554-2377ac4bc04c // indirect github.com/philhofer/fwd v1.1.1 // indirect github.com/pion/dtls/v2 v2.1.3 // indirect - github.com/pion/ice/v2 v2.2.3 // indirect + github.com/pion/ice/v2 v2.2.4 // indirect github.com/pion/interceptor v0.1.10 // indirect github.com/pion/mdns v0.0.5 // indirect github.com/pion/randutil v0.1.0 // indirect @@ -196,7 +198,6 @@ require ( github.com/pion/sdp/v3 v3.0.4 // indirect github.com/pion/srtp/v2 v2.0.5 // indirect github.com/pion/stun v0.3.5 // indirect - github.com/pion/turn/v2 v2.0.8 // indirect github.com/pion/udp v0.1.1 // indirect github.com/pires/go-proxyproto v0.5.0 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -219,7 +220,6 @@ require ( github.com/zclconf/go-cty v1.10.0 // indirect github.com/zeebo/errs v1.2.2 // indirect go.opencensus.io v0.23.0 // indirect - golang.org/x/net v0.0.0-20220401154927-543a649e0bdd // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 127c93d5b9cb7..685d3ead2998b 100644 --- a/go.sum +++ b/go.sum @@ -1355,8 +1355,9 @@ github.com/pion/datachannel v1.5.2 h1:piB93s8LGmbECrpO84DnkIVWasRMk3IimbcXkTQLE6 github.com/pion/datachannel v1.5.2/go.mod h1:FTGQWaHrdCwIJ1rw6xBIfZVkslikjShim5yr05XFuCQ= github.com/pion/dtls/v2 v2.1.3 h1:3UF7udADqous+M2R5Uo2q/YaP4EzUoWKdfX2oscCUio= github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus= -github.com/pion/ice/v2 v2.2.3 h1:kBVhmtMcI1L3bWDepilO9kKpCGpLQeppCuVxVS8obhE= github.com/pion/ice/v2 v2.2.3/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE= +github.com/pion/ice/v2 v2.2.4 h1:sTHT39ywr5uqzyEMT7thEhOWsNOcdkHSZBbgQohFuZU= +github.com/pion/ice/v2 v2.2.4/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE= github.com/pion/interceptor v0.1.10 h1:DJ2GjMGm4XGIQgMJxuEpdaExdY/6RdngT7Uh4oVmquU= github.com/pion/interceptor v0.1.10/go.mod h1:Lh3JSl/cbJ2wP8I3ccrjh1K/deRGRn3UlSPuOTiHb6U= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= diff --git a/peer/conn.go b/peer/conn.go index 38466fa3cbf1b..23db734010c25 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -50,9 +50,9 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp } opts.SettingEngine.DetachDataChannels() - factory := logging.NewDefaultLoggerFactory() - factory.DefaultLogLevel = logging.LogLevelDisabled - opts.SettingEngine.LoggerFactory = factory + logger := logging.NewDefaultLoggerFactory() + logger.DefaultLogLevel = logging.LogLevelDisabled + opts.SettingEngine.LoggerFactory = logger api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine)) rtc, err := api.NewPeerConnection(webrtc.Configuration{ ICEServers: servers, diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go index 48cb7a59a5876..efd4e6917ac41 100644 --- a/peerbroker/dial_test.go +++ b/peerbroker/dial_test.go @@ -32,13 +32,13 @@ func TestDial(t *testing.T) { defer server.Close() settingEngine := webrtc.SettingEngine{} - listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, error) { + listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { return []webrtc.ICEServer{{ - URLs: []string{"stun:stun.l.google.com:19302"}, - }}, nil - }, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), - SettingEngine: settingEngine, + URLs: []string{"stun:stun.l.google.com:19302"}, + }}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), + SettingEngine: settingEngine, + }, nil }) require.NoError(t, err) diff --git a/peerbroker/listen.go b/peerbroker/listen.go index c68dfafa19af0..34c91ea6e51a4 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -17,22 +17,21 @@ import ( "github.com/coder/coder/peerbroker/proto" ) -// ICEServersFunc returns ICEServers when a new connection is requested. -type ICEServersFunc func(ctx context.Context) ([]webrtc.ICEServer, error) +// ConnSettingsFunc returns initialization options for a connection +type ConnSettingsFunc func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) // Listen consumes the transport as the server-side of the PeerBroker dRPC service. // The Accept function must be serviced, or new connections will hang. -func Listen(connListener net.Listener, iceServersFunc ICEServersFunc, opts *peer.ConnOptions) (*Listener, error) { - if iceServersFunc == nil { - iceServersFunc = func(ctx context.Context) ([]webrtc.ICEServer, error) { - return []webrtc.ICEServer{}, nil +func Listen(connListener net.Listener, connSettingsFunc ConnSettingsFunc) (*Listener, error) { + if connSettingsFunc == nil { + connSettingsFunc = func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { + return []webrtc.ICEServer{}, nil, nil } } ctx, cancelFunc := context.WithCancel(context.Background()) listener := &Listener{ connectionChannel: make(chan *peer.Conn), connectionListener: connListener, - iceServersFunc: iceServersFunc, closeFunc: cancelFunc, closed: make(chan struct{}), @@ -40,7 +39,7 @@ func Listen(connListener net.Listener, iceServersFunc ICEServersFunc, opts *peer mux := drpcmux.New() err := proto.DRPCRegisterPeerBroker(mux, &peerBrokerService{ - connOptions: opts, + connSettingsFunc: connSettingsFunc, listener: listener, }) @@ -59,7 +58,6 @@ func Listen(connListener net.Listener, iceServersFunc ICEServersFunc, opts *peer type Listener struct { connectionChannel chan *peer.Conn connectionListener net.Listener - iceServersFunc ICEServersFunc closeFunc context.CancelFunc closed chan struct{} @@ -112,17 +110,16 @@ func (l *Listener) isClosed() bool { type peerBrokerService struct { listener *Listener - connOptions *peer.ConnOptions + connSettingsFunc ConnSettingsFunc } // NegotiateConnection negotiates a WebRTC connection. func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { - iceServers, err := b.listener.iceServersFunc(stream.Context()) + iceServers, connOptions, err := b.connSettingsFunc(stream.Context()) if err != nil { - return xerrors.Errorf("get ice servers: %w", err) + return xerrors.Errorf("get connection settings: %w", err) } - // Start with no ICE servers. They can be sent by the client if provided. - peerConn, err := peer.Server(iceServers, b.connOptions) + peerConn, err := peer.Server(iceServers, connOptions) if err != nil { return xerrors.Errorf("create peer connection: %w", err) } diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go index 622f2a1c0b7fb..81582a91d4b84 100644 --- a/peerbroker/listen_test.go +++ b/peerbroker/listen_test.go @@ -23,7 +23,7 @@ func TestListen(t *testing.T) { defer client.Close() defer server.Close() - listener, err := peerbroker.Listen(server, nil, nil) + listener, err := peerbroker.Listen(server, nil) require.NoError(t, err) api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) @@ -43,7 +43,7 @@ func TestListen(t *testing.T) { defer client.Close() defer server.Close() - listener, err := peerbroker.Listen(server, nil, nil) + listener, err := peerbroker.Listen(server, nil) require.NoError(t, err) go listener.Close() _, err = listener.Accept() diff --git a/peerbroker/proxy_test.go b/peerbroker/proxy_test.go index f036305b46957..72627c9d51a1d 100644 --- a/peerbroker/proxy_test.go +++ b/peerbroker/proxy_test.go @@ -29,8 +29,10 @@ func TestProxy(t *testing.T) { defer listenerClient.Close() defer listenerServer.Close() - listener, err := peerbroker.Listen(listenerServer, nil, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), + listener, err := peerbroker.Listen(listenerServer, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) { + return nil, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), + }, nil }) require.NoError(t, err)