Skip to content

Commit e9a28be

Browse files
committed
feat: add port scanning to agent
1 parent 8cd7d4f commit e9a28be

File tree

5 files changed

+161
-4
lines changed

5 files changed

+161
-4
lines changed

agent/agent.go

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"time"
2222

2323
"github.com/armon/circbuf"
24+
"github.com/cakturk/go-netstat/netstat"
2425
"github.com/gliderlabs/ssh"
2526
"github.com/google/uuid"
2627
"github.com/pkg/sftp"
@@ -37,13 +38,15 @@ import (
3738
)
3839

3940
const (
41+
ProtocolNetstat = "netstat"
4042
ProtocolReconnectingPTY = "reconnecting-pty"
4143
ProtocolSSH = "ssh"
4244
ProtocolDial = "dial"
4345
)
4446

4547
type Options struct {
4648
ReconnectingPTYTimeout time.Duration
49+
NetstatInterval time.Duration
4750
EnvironmentVariables map[string]string
4851
Logger slog.Logger
4952
}
@@ -65,10 +68,14 @@ func New(dialer Dialer, options *Options) io.Closer {
6568
if options.ReconnectingPTYTimeout == 0 {
6669
options.ReconnectingPTYTimeout = 5 * time.Minute
6770
}
71+
if options.NetstatInterval == 0 {
72+
options.NetstatInterval = 5 * time.Second
73+
}
6874
ctx, cancelFunc := context.WithCancel(context.Background())
6975
server := &agent{
7076
dialer: dialer,
7177
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
78+
netstatInterval: options.NetstatInterval,
7279
logger: options.Logger,
7380
closeCancel: cancelFunc,
7481
closed: make(chan struct{}),
@@ -85,6 +92,8 @@ type agent struct {
8592
reconnectingPTYs sync.Map
8693
reconnectingPTYTimeout time.Duration
8794

95+
netstatInterval time.Duration
96+
8897
connCloseWait sync.WaitGroup
8998
closeCancel context.CancelFunc
9099
closeMutex sync.Mutex
@@ -225,6 +234,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
225234
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
226235
case ProtocolDial:
227236
go a.handleDial(ctx, channel.Label(), channel.NetConn())
237+
case ProtocolNetstat:
238+
go a.handleNetstat(ctx, channel.Label(), channel.NetConn())
228239
default:
229240
a.logger.Warn(ctx, "unhandled protocol from channel",
230241
slog.F("protocol", channel.Protocol()),
@@ -359,12 +370,10 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
359370
if err != nil {
360371
return nil, xerrors.Errorf("getting os executable: %w", err)
361372
}
362-
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
363-
cmd.Env = append(cmd.Env, fmt.Sprintf(`PATH=%s%c%s`, os.Getenv("PATH"), filepath.ListSeparator, filepath.Dir(executablePath)))
364373
// Git on Windows resolves with UNIX-style paths.
365374
// If using backslashes, it's unable to find the executable.
366-
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
367-
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
375+
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
376+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
368377
// These prevent the user from having to specify _anything_ to successfully commit.
369378
// Both author and committer must be set!
370379
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, metadata.OwnerEmail))
@@ -707,6 +716,87 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
707716
Bicopy(ctx, conn, nconn)
708717
}
709718

719+
type NetstatPort struct {
720+
Name string `json:"name"`
721+
Port uint16 `json:"port"`
722+
}
723+
724+
type NetstatResponse struct {
725+
Ports []NetstatPort `json:"ports"`
726+
Error string `json:"error,omitempty"`
727+
Took time.Duration `json:"took"`
728+
}
729+
730+
func (a *agent) handleNetstat(ctx context.Context, label string, conn net.Conn) {
731+
write := func(resp NetstatResponse) error {
732+
b, err := json.Marshal(resp)
733+
if err != nil {
734+
a.logger.Warn(ctx, "write netstat response", slog.F("label", label), slog.Error(err))
735+
return xerrors.Errorf("marshal agent netstat response: %w", err)
736+
}
737+
_, err = conn.Write(b)
738+
if err != nil {
739+
a.logger.Warn(ctx, "write netstat response", slog.F("label", label), slog.Error(err))
740+
}
741+
return err
742+
}
743+
744+
scan := func() ([]NetstatPort, error) {
745+
if runtime.GOOS != "linux" && runtime.GOOS != "windows" {
746+
return nil, xerrors.New(fmt.Sprintf("Port scanning is not supported on %s", runtime.GOOS))
747+
}
748+
749+
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
750+
return s.State == netstat.Listen
751+
})
752+
if err != nil {
753+
return nil, err
754+
}
755+
756+
ports := []NetstatPort{}
757+
for _, tab := range tabs {
758+
ports = append(ports, NetstatPort{
759+
Name: tab.Process.Name,
760+
Port: tab.LocalAddr.Port,
761+
})
762+
}
763+
return ports, nil
764+
}
765+
766+
scanAndWrite := func() {
767+
start := time.Now()
768+
ports, err := scan()
769+
response := NetstatResponse{
770+
Ports: ports,
771+
Took: time.Since(start),
772+
}
773+
if err != nil {
774+
response.Error = err.Error()
775+
}
776+
_ = write(response)
777+
}
778+
779+
scanAndWrite()
780+
781+
// Using a timer instead of a ticker to ensure delay between calls otherwise
782+
// if nestat took longer than the interval we would constantly run it.
783+
timer := time.NewTimer(a.netstatInterval)
784+
go func() {
785+
defer conn.Close()
786+
defer timer.Stop()
787+
788+
for {
789+
select {
790+
case <-ctx.Done():
791+
return
792+
case <-timer.C:
793+
scanAndWrite()
794+
timer.Reset(a.netstatInterval)
795+
}
796+
}
797+
}()
798+
}
799+
710800
// isClosed returns whether the API is closed or not.
711801
func (a *agent) isClosed() bool {
712802
select {

agent/agent_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,57 @@ func TestAgent(t *testing.T) {
373373
require.ErrorContains(t, err, "no such file")
374374
require.Nil(t, netConn)
375375
})
376+
377+
t.Run("Netstat", func(t *testing.T) {
378+
t.Parallel()
379+
380+
var ports []agent.NetstatPort
381+
listen := func() {
382+
listener, err := net.Listen("tcp", "127.0.0.1:0")
383+
require.NoError(t, err)
384+
t.Cleanup(func() {
385+
_ = listener.Close()
386+
})
387+
388+
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
389+
require.True(t, valid)
390+
391+
name, err := os.Executable()
392+
require.NoError(t, err)
393+
394+
ports = append(ports, agent.NetstatPort{
395+
Name: filepath.Base(name),
396+
Port: uint16(tcpAddr.Port),
397+
})
398+
}
399+
400+
conn := setupAgent(t, agent.Metadata{}, 0)
401+
netConn, err := conn.Netstat(context.Background())
402+
require.NoError(t, err)
403+
t.Cleanup(func() {
404+
_ = netConn.Close()
405+
})
406+
407+
decoder := json.NewDecoder(netConn)
408+
409+
expectNetstat := func() {
410+
var res agent.NetstatResponse
411+
err = decoder.Decode(&res)
412+
require.NoError(t, err)
413+
414+
if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
415+
require.Subset(t, res.Ports, ports)
416+
} else {
417+
require.Equal(t, fmt.Sprintf("Port scanning is not supported on %s", runtime.GOOS), res.Error)
418+
}
419+
}
420+
421+
listen()
422+
expectNetstat()
423+
424+
listen()
425+
expectNetstat()
426+
})
376427
}
377428

378429
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
@@ -420,6 +471,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
420471
}, &agent.Options{
421472
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
422473
ReconnectingPTYTimeout: ptyTimeout,
474+
NetstatInterval: 100 * time.Millisecond,
423475
})
424476
t.Cleanup(func() {
425477
_ = client.Close()

agent/conn.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne
112112
return channel.NetConn(), nil
113113
}
114114

115+
// Netstat returns a connection that serves a list of listening ports.
116+
func (c *Conn) Netstat(ctx context.Context) (net.Conn, error) {
117+
channel, err := c.CreateChannel(ctx, "netstat", &peer.ChannelOptions{
118+
Protocol: ProtocolNetstat,
119+
})
120+
if err != nil {
121+
return nil, xerrors.Errorf("netsat: %w", err)
122+
}
123+
return channel.NetConn(), nil
124+
}
125+
115126
func (c *Conn) Close() error {
116127
_ = c.Negotiator.DRPCConn().Close()
117128
return c.Conn.Close()

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ require (
126126
storj.io/drpc v0.0.30
127127
)
128128

129+
require github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5
130+
129131
require (
130132
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
131133
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ github.com/bugsnag/osext v0.0.0-20130617224835-0dd3f918b21b/go.mod h1:obH5gd0Bsq
240240
github.com/bugsnag/panicwrap v0.0.0-20151223152923-e2c28503fcd0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
241241
github.com/bytecodealliance/wasmtime-go v0.35.0 h1:VZjaZ0XOY0qp9TQfh0CQj9zl/AbdeXePVTALy8V1sKs=
242242
github.com/bytecodealliance/wasmtime-go v0.35.0/go.mod h1:q320gUxqyI8yB+ZqRuaJOEnGkAnHh6WtJjMaT2CW4wI=
243+
github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5 h1:BjkPE3785EwPhhyuFkbINB+2a1xATwk8SNDWnJiD41g=
244+
github.com/cakturk/go-netstat v0.0.0-20200220111822-e5b49efee7a5/go.mod h1:jtAfVaU/2cu1+wdSRPWE2c1N2qeAA3K4RH9pYgqwets=
243245
github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
244246
github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
245247
github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4=

0 commit comments

Comments
 (0)