@@ -159,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
159
159
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
160
160
defer cancel ()
161
161
//nolint:dogsled
162
- conn , _ , stats , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
162
+ conn , agentClient , stats , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
163
163
sshClient , err := conn .SSHClient (ctx )
164
164
require .NoError (t , err )
165
165
defer sshClient .Close ()
@@ -189,6 +189,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
189
189
_ = stdin .Close ()
190
190
err = session .Wait ()
191
191
require .NoError (t , err )
192
+
193
+ assertConnectionReport (t , agentClient , proto .Connection_VSCODE , 0 , "" )
192
194
})
193
195
194
196
t .Run ("TracksJetBrains" , func (t * testing.T ) {
@@ -225,7 +227,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
225
227
remotePort := sc .Text ()
226
228
227
229
//nolint:dogsled
228
- conn , _ , stats , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
230
+ conn , agentClient , stats , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
229
231
sshClient , err := conn .SSHClient (ctx )
230
232
require .NoError (t , err )
231
233
@@ -261,6 +263,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
261
263
}, testutil .WaitLong , testutil .IntervalFast ,
262
264
"never saw stats after conn closes" ,
263
265
)
266
+
267
+ assertConnectionReport (t , agentClient , proto .Connection_JETBRAINS , 0 , "" )
264
268
})
265
269
}
266
270
@@ -918,7 +922,7 @@ func TestAgent_SFTP(t *testing.T) {
918
922
home = "/" + strings .ReplaceAll (home , "\\ " , "/" )
919
923
}
920
924
//nolint:dogsled
921
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
925
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
922
926
sshClient , err := conn .SSHClient (ctx )
923
927
require .NoError (t , err )
924
928
defer sshClient .Close ()
@@ -941,6 +945,10 @@ func TestAgent_SFTP(t *testing.T) {
941
945
require .NoError (t , err )
942
946
_ , err = os .Stat (tempFile )
943
947
require .NoError (t , err )
948
+
949
+ // Close the client to trigger disconnect event.
950
+ _ = client .Close ()
951
+ assertConnectionReport (t , agentClient , proto .Connection_SSH , 0 , "" )
944
952
}
945
953
946
954
func TestAgent_SCP (t * testing.T ) {
@@ -950,7 +958,7 @@ func TestAgent_SCP(t *testing.T) {
950
958
defer cancel ()
951
959
952
960
//nolint:dogsled
953
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
961
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
954
962
sshClient , err := conn .SSHClient (ctx )
955
963
require .NoError (t , err )
956
964
defer sshClient .Close ()
@@ -963,6 +971,10 @@ func TestAgent_SCP(t *testing.T) {
963
971
require .NoError (t , err )
964
972
_ , err = os .Stat (tempFile )
965
973
require .NoError (t , err )
974
+
975
+ // Close the client to trigger disconnect event.
976
+ scpClient .Close ()
977
+ assertConnectionReport (t , agentClient , proto .Connection_SSH , 0 , "" )
966
978
}
967
979
968
980
func TestAgent_FileTransferBlocked (t * testing.T ) {
@@ -987,7 +999,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
987
999
defer cancel ()
988
1000
989
1001
//nolint:dogsled
990
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1002
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
991
1003
o .BlockFileTransfer = true
992
1004
})
993
1005
sshClient , err := conn .SSHClient (ctx )
@@ -996,6 +1008,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
996
1008
_ , err = sftp .NewClient (sshClient )
997
1009
require .Error (t , err )
998
1010
assertFileTransferBlocked (t , err .Error ())
1011
+
1012
+ assertConnectionReport (t , agentClient , proto .Connection_SSH , agentssh .BlockedFileTransferErrorCode , "" )
999
1013
})
1000
1014
1001
1015
t .Run ("SCP with go-scp package" , func (t * testing.T ) {
@@ -1005,7 +1019,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
1005
1019
defer cancel ()
1006
1020
1007
1021
//nolint:dogsled
1008
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1022
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1009
1023
o .BlockFileTransfer = true
1010
1024
})
1011
1025
sshClient , err := conn .SSHClient (ctx )
@@ -1018,6 +1032,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
1018
1032
err = scpClient .CopyFile (context .Background (), strings .NewReader ("hello world" ), tempFile , "0755" )
1019
1033
require .Error (t , err )
1020
1034
assertFileTransferBlocked (t , err .Error ())
1035
+
1036
+ assertConnectionReport (t , agentClient , proto .Connection_SSH , agentssh .BlockedFileTransferErrorCode , "" )
1021
1037
})
1022
1038
1023
1039
t .Run ("Forbidden commands" , func (t * testing.T ) {
@@ -1031,7 +1047,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
1031
1047
defer cancel ()
1032
1048
1033
1049
//nolint:dogsled
1034
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1050
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1035
1051
o .BlockFileTransfer = true
1036
1052
})
1037
1053
sshClient , err := conn .SSHClient (ctx )
@@ -1053,6 +1069,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
1053
1069
msg , err := io .ReadAll (stdout )
1054
1070
require .NoError (t , err )
1055
1071
assertFileTransferBlocked (t , string (msg ))
1072
+
1073
+ assertConnectionReport (t , agentClient , proto .Connection_SSH , agentssh .BlockedFileTransferErrorCode , "" )
1056
1074
})
1057
1075
}
1058
1076
})
@@ -1661,8 +1679,16 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
1661
1679
defer cancel ()
1662
1680
1663
1681
//nolint:dogsled
1664
- conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
1682
+ conn , agentClient , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 )
1665
1683
id := uuid .New ()
1684
+
1685
+ // Test that the connection is reported. This must be tested in the
1686
+ // first connection because we care about verifying all of these.
1687
+ netConn0 , err := conn .ReconnectingPTY (ctx , id , 80 , 80 , "bash --norc" )
1688
+ require .NoError (t , err )
1689
+ _ = netConn0 .Close ()
1690
+ assertConnectionReport (t , agentClient , proto .Connection_RECONNECTING_PTY , 0 , "" )
1691
+
1666
1692
// --norc disables executing .bashrc, which is often used to customize the bash prompt
1667
1693
netConn1 , err := conn .ReconnectingPTY (ctx , id , 80 , 80 , "bash --norc" )
1668
1694
require .NoError (t , err )
@@ -2691,3 +2717,35 @@ func requireEcho(t *testing.T, conn net.Conn) {
2691
2717
require .NoError (t , err )
2692
2718
require .Equal (t , "test" , string (b ))
2693
2719
}
2720
+
2721
+ func assertConnectionReport (t testing.TB , agentClient * agenttest.Client , connectionType proto.Connection_Type , status int , reason string ) {
2722
+ t .Helper ()
2723
+
2724
+ var reports []* proto.ReportConnectionRequest
2725
+ if ! assert .Eventually (t , func () bool {
2726
+ reports = agentClient .GetConnectionReports ()
2727
+ return len (reports ) >= 2
2728
+ }, testutil .WaitMedium , testutil .IntervalFast , "waiting for 2 connection reports or more; got %d" , len (reports )) {
2729
+ return
2730
+ }
2731
+
2732
+ assert .Len (t , reports , 2 , "want 2 connection reports" )
2733
+
2734
+ assert .Equal (t , proto .Connection_CONNECT , reports [0 ].GetConnection ().GetAction (), "first report should be connect" )
2735
+ assert .Equal (t , proto .Connection_DISCONNECT , reports [1 ].GetConnection ().GetAction (), "second report should be disconnect" )
2736
+ assert .Equal (t , connectionType , reports [0 ].GetConnection ().GetType (), "connect type should be %s" , connectionType )
2737
+ assert .Equal (t , connectionType , reports [1 ].GetConnection ().GetType (), "disconnect type should be %s" , connectionType )
2738
+ t1 := reports [0 ].GetConnection ().GetTimestamp ().AsTime ()
2739
+ t2 := reports [1 ].GetConnection ().GetTimestamp ().AsTime ()
2740
+ assert .True (t , t1 .Before (t2 ) || t1 .Equal (t2 ), "connect timestamp should be before or equal to disconnect timestamp" )
2741
+ assert .NotEmpty (t , reports [0 ].GetConnection ().GetIp (), "connect ip should not be empty" )
2742
+ assert .NotEmpty (t , reports [1 ].GetConnection ().GetIp (), "disconnect ip should not be empty" )
2743
+ assert .Equal (t , 0 , int (reports [0 ].GetConnection ().GetStatusCode ()), "connect status code should be 0" )
2744
+ assert .Equal (t , status , int (reports [1 ].GetConnection ().GetStatusCode ()), "disconnect status code should be %d" , status )
2745
+ assert .Equal (t , "" , reports [0 ].GetConnection ().GetReason (), "connect reason should be empty" )
2746
+ if reason != "" {
2747
+ assert .Contains (t , reports [1 ].GetConnection ().GetReason (), reason , "disconnect reason should contain %s" , reason )
2748
+ } else {
2749
+ t .Logf ("connection report disconnect reason: %s" , reports [1 ].GetConnection ().GetReason ())
2750
+ }
2751
+ }
0 commit comments