@@ -412,6 +412,24 @@ func TestCoordinator(t *testing.T) {
412
412
_ = testutil .RequireRecvCtx (ctx , t , clientErrChan )
413
413
_ = testutil .RequireRecvCtx (ctx , t , closeClientChan )
414
414
})
415
+
416
+ t .Run ("AgentAck" , func (t * testing.T ) {
417
+ t .Parallel ()
418
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
419
+ coordinator := tailnet .NewCoordinator (logger )
420
+ ctx := testutil .Context (t , testutil .WaitShort )
421
+
422
+ clientID := uuid .New ()
423
+ agentID := uuid .New ()
424
+
425
+ aReq , _ := coordinator .Coordinate (ctx , agentID , agentID .String (), tailnet.AgentCoordinateeAuth {ID : agentID })
426
+ _ , cRes := coordinator .Coordinate (ctx , clientID , clientID .String (), tailnet.ClientCoordinateeAuth {AgentID : agentID })
427
+
428
+ aReq <- & proto.CoordinateRequest {TunnelAck : & proto.CoordinateRequest_Ack {Id : clientID [:]}}
429
+ ack := testutil .RequireRecvCtx (ctx , t , cRes )
430
+ require .NotNil (t , ack .TunnelAck )
431
+ require .Equal (t , agentID [:], ack .TunnelAck .Id )
432
+ })
415
433
}
416
434
417
435
// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
@@ -638,6 +656,61 @@ func TestRemoteCoordination(t *testing.T) {
638
656
}
639
657
}
640
658
659
+ func TestRemoteCoordination_Ack (t * testing.T ) {
660
+ t .Parallel ()
661
+ ctx := testutil .Context (t , testutil .WaitShort )
662
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
663
+ clientID := uuid.UUID {1 }
664
+ agentID := uuid.UUID {2 }
665
+ mCoord := tailnettest .NewMockCoordinator (gomock .NewController (t ))
666
+ fConn := & fakeCoordinatee {}
667
+
668
+ reqs := make (chan * proto.CoordinateRequest , 100 )
669
+ resps := make (chan * proto.CoordinateResponse , 100 )
670
+ mCoord .EXPECT ().Coordinate (gomock .Any (), clientID , gomock .Any (), tailnet.ClientCoordinateeAuth {agentID }).
671
+ Times (1 ).Return (reqs , resps )
672
+
673
+ var coord tailnet.Coordinator = mCoord
674
+ coordPtr := atomic.Pointer [tailnet.Coordinator ]{}
675
+ coordPtr .Store (& coord )
676
+ svc , err := tailnet .NewClientService (
677
+ logger .Named ("svc" ), & coordPtr ,
678
+ time .Hour ,
679
+ func () * tailcfg.DERPMap { panic ("not implemented" ) },
680
+ )
681
+ require .NoError (t , err )
682
+ sC , cC := net .Pipe ()
683
+
684
+ serveErr := make (chan error , 1 )
685
+ go func () {
686
+ err := svc .ServeClient (ctx , proto .CurrentVersion .String (), sC , clientID , agentID )
687
+ serveErr <- err
688
+ }()
689
+
690
+ client , err := tailnet .NewDRPCClient (cC , logger )
691
+ require .NoError (t , err )
692
+ protocol , err := client .Coordinate (ctx )
693
+ require .NoError (t , err )
694
+
695
+ uut := tailnet .NewRemoteCoordination (logger .Named ("coordination" ), protocol , fConn , agentID )
696
+ defer uut .Close ()
697
+
698
+ testutil .RequireSendCtx (ctx , t , resps , & proto.CoordinateResponse {
699
+ TunnelAck : & proto.CoordinateResponse_Ack {Id : agentID [:]},
700
+ })
701
+
702
+ testutil .RequireRecvCtx (ctx , t , uut .AwaitAck ())
703
+
704
+ require .NoError (t , uut .Close ())
705
+
706
+ select {
707
+ case err := <- uut .Error ():
708
+ require .ErrorContains (t , err , "stream terminated by sending close" )
709
+ default :
710
+ // OK!
711
+ }
712
+ }
713
+
641
714
// coordinationTest tests that a coordination behaves correctly
642
715
func coordinationTest (
643
716
ctx context.Context , t * testing.T ,
0 commit comments