@@ -6,19 +6,24 @@ import (
6
6
"net"
7
7
"net/http"
8
8
"net/http/httptest"
9
+ "sync"
10
+ "sync/atomic"
9
11
"testing"
10
12
"time"
11
13
12
- "nhooyr.io/websocket"
13
-
14
- "cdr.dev/slog"
15
- "cdr.dev/slog/sloggers/slogtest"
16
-
17
14
"github.com/google/uuid"
18
15
"github.com/stretchr/testify/assert"
19
16
"github.com/stretchr/testify/require"
17
+ "go.uber.org/mock/gomock"
18
+ "nhooyr.io/websocket"
19
+ "tailscale.com/tailcfg"
20
+ "tailscale.com/types/key"
20
21
22
+ "cdr.dev/slog"
23
+ "cdr.dev/slog/sloggers/slogtest"
21
24
"github.com/coder/coder/v2/tailnet"
25
+ "github.com/coder/coder/v2/tailnet/proto"
26
+ "github.com/coder/coder/v2/tailnet/tailnettest"
22
27
"github.com/coder/coder/v2/tailnet/test"
23
28
"github.com/coder/coder/v2/testutil"
24
29
)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
400
405
require .True (t , ok )
401
406
return client , server
402
407
}
408
+
409
+ func TestInMemoryCoordination (t * testing.T ) {
410
+ t .Parallel ()
411
+ ctx := testutil .Context (t , testutil .WaitShort )
412
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
413
+ clientID := uuid.UUID {1 }
414
+ agentID := uuid.UUID {2 }
415
+ mCoord := tailnettest .NewMockCoordinator (gomock .NewController (t ))
416
+ fConn := & fakeCoordinatee {}
417
+
418
+ reqs := make (chan * proto.CoordinateRequest , 100 )
419
+ resps := make (chan * proto.CoordinateResponse , 100 )
420
+ mCoord .EXPECT ().Coordinate (gomock .Any (), clientID , gomock .Any (), tailnet.ClientTunnelAuth {agentID }).
421
+ Times (1 ).Return (reqs , resps )
422
+
423
+ uut := tailnet .NewInMemoryCoordination (ctx , logger , clientID , agentID , mCoord , fConn )
424
+ defer uut .Close ()
425
+
426
+ coordinationTest (ctx , t , uut , fConn , reqs , resps , agentID )
427
+
428
+ select {
429
+ case err := <- uut .Error ():
430
+ require .NoError (t , err )
431
+ default :
432
+ // OK!
433
+ }
434
+ }
435
+
436
+ func TestRemoteCoordination (t * testing.T ) {
437
+ t .Parallel ()
438
+ ctx := testutil .Context (t , testutil .WaitShort )
439
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
440
+ clientID := uuid.UUID {1 }
441
+ agentID := uuid.UUID {2 }
442
+ mCoord := tailnettest .NewMockCoordinator (gomock .NewController (t ))
443
+ fConn := & fakeCoordinatee {}
444
+
445
+ reqs := make (chan * proto.CoordinateRequest , 100 )
446
+ resps := make (chan * proto.CoordinateResponse , 100 )
447
+ mCoord .EXPECT ().Coordinate (gomock .Any (), clientID , gomock .Any (), tailnet.ClientTunnelAuth {agentID }).
448
+ Times (1 ).Return (reqs , resps )
449
+
450
+ var coord tailnet.Coordinator = mCoord
451
+ coordPtr := atomic.Pointer [tailnet.Coordinator ]{}
452
+ coordPtr .Store (& coord )
453
+ svc , err := tailnet .NewClientService (
454
+ logger .Named ("svc" ), & coordPtr ,
455
+ time .Hour ,
456
+ func () * tailcfg.DERPMap { panic ("not implemented" ) },
457
+ )
458
+ require .NoError (t , err )
459
+ sC , cC := net .Pipe ()
460
+
461
+ serveErr := make (chan error , 1 )
462
+ go func () {
463
+ err := svc .ServeClient (ctx , tailnet .CurrentVersion .String (), sC , clientID , agentID )
464
+ serveErr <- err
465
+ }()
466
+
467
+ client , err := tailnet .NewDRPCClient (cC )
468
+ require .NoError (t , err )
469
+ protocol , err := client .Coordinate (ctx )
470
+ require .NoError (t , err )
471
+
472
+ uut := tailnet .NewRemoteCoordination (logger .Named ("coordination" ), protocol , fConn , agentID )
473
+ defer uut .Close ()
474
+
475
+ coordinationTest (ctx , t , uut , fConn , reqs , resps , agentID )
476
+
477
+ select {
478
+ case err := <- uut .Error ():
479
+ require .ErrorContains (t , err , "stream terminated by sending close" )
480
+ default :
481
+ // OK!
482
+ }
483
+ }
484
+
485
+ // coordinationTest tests that a coordination behaves correctly
486
+ func coordinationTest (
487
+ ctx context.Context , t * testing.T ,
488
+ uut tailnet.Coordination , fConn * fakeCoordinatee ,
489
+ reqs chan * proto.CoordinateRequest , resps chan * proto.CoordinateResponse ,
490
+ agentID uuid.UUID ,
491
+ ) {
492
+ // It should add the tunnel, since we configured as a client
493
+ req := testutil .RequireRecvCtx (ctx , t , reqs )
494
+ require .Equal (t , agentID [:], req .GetAddTunnel ().GetId ())
495
+
496
+ // when we call the callback, it should send a node update
497
+ require .NotNil (t , fConn .callback )
498
+ fConn .callback (& tailnet.Node {PreferredDERP : 1 })
499
+
500
+ req = testutil .RequireRecvCtx (ctx , t , reqs )
501
+ require .Equal (t , int32 (1 ), req .GetUpdateSelf ().GetNode ().GetPreferredDerp ())
502
+
503
+ // When we send a peer update, it should update the coordinatee
504
+ nk , err := key .NewNode ().Public ().MarshalBinary ()
505
+ require .NoError (t , err )
506
+ dk , err := key .NewDisco ().Public ().MarshalText ()
507
+ require .NoError (t , err )
508
+ updates := []* proto.CoordinateResponse_PeerUpdate {
509
+ {
510
+ Id : agentID [:],
511
+ Kind : proto .CoordinateResponse_PeerUpdate_NODE ,
512
+ Node : & proto.Node {
513
+ Id : 2 ,
514
+ Key : nk ,
515
+ Disco : string (dk ),
516
+ },
517
+ },
518
+ }
519
+ testutil .RequireSendCtx (ctx , t , resps , & proto.CoordinateResponse {PeerUpdates : updates })
520
+ require .Eventually (t , func () bool {
521
+ fConn .Lock ()
522
+ defer fConn .Unlock ()
523
+ return len (fConn .updates ) > 0
524
+ }, testutil .WaitShort , testutil .IntervalFast )
525
+ require .Len (t , fConn .updates [0 ], 1 )
526
+ require .Equal (t , agentID [:], fConn .updates [0 ][0 ].Id )
527
+
528
+ err = uut .Close ()
529
+ require .NoError (t , err )
530
+ uut .Error ()
531
+
532
+ // When we close, it should gracefully disconnect
533
+ req = testutil .RequireRecvCtx (ctx , t , reqs )
534
+ require .NotNil (t , req .Disconnect )
535
+
536
+ // It should set all peers lost on the coordinatee
537
+ require .Equal (t , 1 , fConn .setAllPeersLostCalls )
538
+ }
539
+
540
+ type fakeCoordinatee struct {
541
+ sync.Mutex
542
+ callback func (* tailnet.Node )
543
+ updates [][]* proto.CoordinateResponse_PeerUpdate
544
+ setAllPeersLostCalls int
545
+ }
546
+
547
+ func (f * fakeCoordinatee ) UpdatePeers (updates []* proto.CoordinateResponse_PeerUpdate ) error {
548
+ f .Lock ()
549
+ defer f .Unlock ()
550
+ f .updates = append (f .updates , updates )
551
+ return nil
552
+ }
553
+
554
+ func (f * fakeCoordinatee ) SetAllPeersLost () {
555
+ f .Lock ()
556
+ defer f .Unlock ()
557
+ f .setAllPeersLostCalls ++
558
+ }
559
+
560
+ func (f * fakeCoordinatee ) SetNodeCallback (callback func (* tailnet.Node )) {
561
+ f .Lock ()
562
+ defer f .Unlock ()
563
+ f .callback = callback
564
+ }
0 commit comments