diff --git a/README.md b/README.md index 2a65e19..ee6e38c 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ You can download the binary for your platform from [Releases](https://github.com Example: ```shell -GOHPTS_RELEASE=v1.9.3; wget -v https://github.com/shadowy-pycoder/go-http-proxy-to-socks/releases/download/$GOHPTS_RELEASE/gohpts-$GOHPTS_RELEASE-linux-amd64.tar.gz -O gohpts && tar xvzf gohpts && mv -f gohpts-$GOHPTS_RELEASE-linux-amd64 gohpts && ./gohpts -h +GOHPTS_RELEASE=v1.9.4; wget -v https://github.com/shadowy-pycoder/go-http-proxy-to-socks/releases/download/$GOHPTS_RELEASE/gohpts-$GOHPTS_RELEASE-linux-amd64.tar.gz -O gohpts && tar xvzf gohpts && mv -f gohpts-$GOHPTS_RELEASE-linux-amd64 gohpts && ./gohpts -h ``` Alternatively, you can install it using `go install` command (requires Go [1.24](https://go.dev/doc/install) or later): diff --git a/cmd/gohpts/cli.go b/cmd/gohpts/cli.go index f8019fb..e40d7b9 100644 --- a/cmd/gohpts/cli.go +++ b/cmd/gohpts/cli.go @@ -95,7 +95,7 @@ func root(args []string) error { ) flags.StringVar(&conf.Interface, "i", "", "Bind proxy to specific network interface") flags.BoolFunc("I", "Display list of network interfaces and exit", func(flagValue string) error { - if err := network.DisplayInterfaces(); err != nil { + if err := network.DisplayInterfaces(false); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", app, err) os.Exit(2) } diff --git a/colorize.go b/colorize.go new file mode 100644 index 0000000..50d4247 --- /dev/null +++ b/colorize.go @@ -0,0 +1,406 @@ +package gohpts + +import ( + "bufio" + "bytes" + "fmt" + "math/rand" + "net" + "net/http" + "regexp" + "strings" + "time" + + "github.com/google/uuid" + "github.com/shadowy-pycoder/colors" + "github.com/shadowy-pycoder/mshark/layers" +) + +var ( + ipPortPattern = regexp.MustCompile( + `\b(?:\d{1,3}\.){3}\d{1,3}(?::(6553[0-5]|655[0-2]\d|65[0-4]\d{2}|6[0-4]\d{3}|[1-5]?\d{1,4}))?\b`, + ) + domainPattern = regexp.MustCompile( + `\b(?:[a-zA-Z0-9-]{1,63}\.)+(?:com|net|org|io|co|uk|ru|de|edu|gov|info|biz|dev|app|ai)(?::(6553[0-5]|655[0-2]\d|65[0-4]\d{2}|6[0-4]\d{3}|[1-5]?\d{1,4}))?\b`, + ) + jwtPattern = regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\b`) + authPattern = regexp.MustCompile( + `(?i)(?:"|')?(authorization|auth[_-]?token|access[_-]?token|api[_-]?key|secret|token)(?:"|')?\s*[:=]\s*(?:"|')?([^\s"'&]+)`, + ) + credsPattern = regexp.MustCompile( + `(?i)(?:"|')?(username|user|login|email|password|pass|pwd)(?:"|')?\s*[:=]\s*(?:"|')?([^\s"'&]+)`, + ) + macPattern = regexp.MustCompile(`(?i)([a-z0-9_]+_[0-9a-f]{2}(?::[0-9a-f]{2}){2}|(?:[0-9a-f]{2}[:-]){5}[0-9a-f]{2})`) +) + +var rColors = []func(string) *colors.Color{ + colors.Beige, + colors.Blue, + colors.Gray, + colors.Green, + colors.LightBlue, + colors.Magenta, + colors.Red, + colors.Yellow, + colors.BeigeBg, + colors.BlueBg, + colors.GrayBg, + colors.GreenBg, + colors.LightBlueBg, + colors.MagentaBg, + colors.RedBgDark, + colors.YellowBg, +} + +func randColor() func(string) *colors.Color { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + randIndex := r.Intn(len(rColors)) + return rColors[randIndex] +} + +func getID(nocolor bool) string { + id := uuid.New() + if nocolor { + return colors.WrapBrackets(id.String()) + } + return randColor()(colors.WrapBrackets(id.String())).String() +} + +// https://stackoverflow.com/a/1094933/1333724 +func prettifyBytes(b int64) string { + bf := float64(b) + for _, unit := range []string{"", "K", "M", "G", "T", "P", "E", "Z"} { + if bf < 1000.0 { + return fmt.Sprintf("%3.1f%sB", bf, unit) + } + bf /= 1000.0 + } + return fmt.Sprintf("%.1fYB", bf) +} + +func colorizeStatus(code int, status string, bg bool) string { + if bg { + if code < 300 { + status = colors.GreenBg(status).String() + } else if code < 400 { + status = colors.YellowBg(status).String() + } else { + status = colors.RedBgDark(status).String() + } + } else { + if code < 300 { + status = colors.Green(status).String() + } else if code < 400 { + status = colors.Yellow(status).String() + } else { + status = colors.Red(status).String() + } + } + return status +} + +func colorizeHTTP( + req *http.Request, + resp *http.Response, + reqBodySaved, respBodySaved *[]byte, + id string, + ts, + body, + nocolor bool, +) string { + var sb strings.Builder + if ts { + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + } + if nocolor { + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" %s %s %s ", req.Method, req.URL, req.Proto)) + if req.UserAgent() != "" { + sb.WriteString(colors.WrapBrackets(req.UserAgent())) + } + if req.ContentLength > 0 { + sb.WriteString(fmt.Sprintf(" Len: %d", req.ContentLength)) + } + sb.WriteString(" → ") + sb.WriteString(fmt.Sprintf("%s %s ", resp.Proto, resp.Status)) + if resp.ContentLength > 0 { + sb.WriteString(fmt.Sprintf("Len: %d", resp.ContentLength)) + } + if body && len(*reqBodySaved) > 0 { + b := colorizeBody(reqBodySaved, nocolor) + if b != "" { + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" req_body: %s", b)) + } + } + if body && len(*respBodySaved) > 0 { + b := colorizeBody(respBodySaved, nocolor) + if b != "" { + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" resp_body: %s", b)) + } + } + } else { + sb.WriteString(id) + sb.WriteString(colors.Gray(fmt.Sprintf(" %s ", req.Method)).String()) + sb.WriteString(colors.YellowBg(fmt.Sprintf("%s ", req.URL)).String()) + sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", req.Proto)).String()) + if req.UserAgent() != "" { + sb.WriteString(colors.Gray(colors.WrapBrackets(req.UserAgent())).String()) + } + if req.ContentLength > 0 { + sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", req.ContentLength)).String()) + } + sb.WriteString(colors.MagentaBg(" → ").String()) + sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", resp.Proto)).String()) + sb.WriteString(colorizeStatus(resp.StatusCode, fmt.Sprintf("%s ", resp.Status), true)) + if resp.ContentLength > 0 { + sb.WriteString(colors.BeigeBg(fmt.Sprintf("Len: %d", resp.ContentLength)).String()) + } + if body && len(*reqBodySaved) > 0 { + b := colorizeBody(reqBodySaved, nocolor) + if b != "" { + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(colors.RedBgDark(" req_body: ").String()) + sb.WriteString(b) + } + } + if body && len(*respBodySaved) > 0 { + b := colorizeBody(respBodySaved, nocolor) + if b != "" { + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(colors.RedBgDark(" resp_body: ").String()) + sb.WriteString(b) + } + } + } + return sb.String() +} + +func colorizeTLS(req *layers.TLSClientHello, resp *layers.TLSServerHello, id string, nocolor bool) string { + var sb strings.Builder + if nocolor { + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" %s ", req.TypeDesc)) + if req.Length > 0 { + sb.WriteString(fmt.Sprintf(" Len: %d", req.Length)) + } + if req.ServerName != nil && req.ServerName.SNName != "" { + sb.WriteString(fmt.Sprintf(" SNI: %s", req.ServerName.SNName)) + } + if req.Version != nil && req.Version.Desc != "" { + sb.WriteString(fmt.Sprintf(" Ver: %s", req.Version.Desc)) + } + if req.ALPN != nil { + sb.WriteString(fmt.Sprintf(" ALPN: %v", req.ALPN)) + } + sb.WriteString(" → ") + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" %s ", resp.TypeDesc)) + if resp.Length > 0 { + sb.WriteString(fmt.Sprintf(" Len: %d", resp.Length)) + } + if resp.SessionID != "" { + sb.WriteString(fmt.Sprintf(" SID: %s", resp.SessionID)) + } + if resp.CipherSuite != nil && resp.CipherSuite.Desc != "" { + sb.WriteString(fmt.Sprintf(" CS: %s", resp.CipherSuite.Desc)) + } + if resp.SupportedVersion != nil && resp.SupportedVersion.Desc != "" { + sb.WriteString(fmt.Sprintf(" Ver: %s", resp.SupportedVersion.Desc)) + } + if resp.ExtensionLength > 0 { + sb.WriteString(fmt.Sprintf(" ExtLen: %d", resp.ExtensionLength)) + } + } else { + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(colors.Magenta(fmt.Sprintf(" %s ", req.TypeDesc)).Bold()) + if req.Length > 0 { + sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", req.Length)).String()) + } + if req.ServerName != nil && req.ServerName.SNName != "" { + sb.WriteString(colors.YellowBg(fmt.Sprintf(" SNI: %s", req.ServerName.SNName)).String()) + } + if req.Version != nil && req.Version.Desc != "" { + sb.WriteString(colors.GreenBg(fmt.Sprintf(" Ver: %s", req.Version.Desc)).String()) + } + if req.ALPN != nil { + sb.WriteString(colors.BlueBg(fmt.Sprintf(" ALPN: %v", req.ALPN)).String()) + } + sb.WriteString(colors.MagentaBg(" → ").String()) + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(colors.LightBlue(fmt.Sprintf(" %s ", resp.TypeDesc)).Bold()) + if resp.Length > 0 { + sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", resp.Length)).String()) + } + if resp.SessionID != "" { + sb.WriteString(colors.Gray(fmt.Sprintf(" SID: %s", resp.SessionID)).String()) + } + if resp.CipherSuite != nil && resp.CipherSuite.Desc != "" { + sb.WriteString(colors.Yellow(fmt.Sprintf(" CS: %s", resp.CipherSuite.Desc)).Bold()) + } + if resp.SupportedVersion != nil && resp.SupportedVersion.Desc != "" { + sb.WriteString(colors.GreenBg(fmt.Sprintf(" Ver: %s", resp.SupportedVersion.Desc)).String()) + } + if resp.ExtensionLength > 0 { + sb.WriteString(colors.BeigeBg(fmt.Sprintf(" ExtLen: %d", resp.ExtensionLength)).String()) + } + } + return sb.String() +} + +func highlightPatterns(line string, nocolor bool) (string, bool) { + matched := false + + // TODO: make this configurable + // line, matched = replace(line, ipPortPattern, colors.YellowBg, matched, nocolor) + // line, matched = replace(line, domainPattern, colors.YellowBg, matched, nocolor) + line, matched = replace(line, jwtPattern, colors.Magenta, matched, nocolor) + line, matched = replace(line, authPattern, colors.Magenta, matched, nocolor) + line, matched = replace(line, credsPattern, colors.GreenBg, matched, nocolor) + + return line, matched +} + +func replace(line string, re *regexp.Regexp, color func(string) *colors.Color, matched, nocolor bool) (string, bool) { + if re.MatchString(line) { + matched = true + if !nocolor { + line = re.ReplaceAllStringFunc(line, func(s string) string { + return color(s).String() + }) + } + } + return line, matched +} + +func colorizeBody(b *[]byte, nocolor bool) string { + matches := make([]string, 0, 3) + scanner := bufio.NewScanner(bytes.NewReader(*b)) + for scanner.Scan() { + line := scanner.Text() + if highlighted, ok := highlightPatterns(line, nocolor); ok { + matches = append(matches, strings.Trim(highlighted, "\r\n\t ")) + } + } + return strings.Join(matches, "\n") +} + +func colorizeTimestamp(ts time.Time, nocolor bool) string { + if nocolor { + return colors.WrapBrackets(ts.Format(time.TimeOnly)) + } + return colors.Gray(colors.WrapBrackets(ts.Format(time.TimeOnly))).String() +} + +func colorizeLogMessage(line string, nocolor bool) string { + if nocolor { + return line + } + result := ipPortPattern.ReplaceAllStringFunc(line, func(match string) string { + return colors.Gray(match).String() + }) + result = domainPattern.ReplaceAllStringFunc(result, func(match string) string { + return colors.Yellow(match).String() + }) + result = macPattern.ReplaceAllStringFunc(result, func(match string) string { + return colors.Yellow(match).String() + }) + return result +} + +func colorizeErrMessage(line string, nocolor bool) string { + if nocolor { + return line + } + result := ipPortPattern.ReplaceAllStringFunc(line, func(match string) string { + return colors.Red(match).String() + }) + result = domainPattern.ReplaceAllStringFunc(result, func(match string) string { + return colors.Red(match).String() + }) + result = strings.ReplaceAll(result, "->", "→ ") + return result +} + +func colorizeChainType(chainType string, nocolor bool) string { + if nocolor { + return colors.WrapBrackets(chainType) + } + return colors.WrapBrackets(colors.LightBlueBg(chainType).String()) +} + +func colorizeConnections(srcRemote, srcLocal, dstRemote, dstLocal net.Addr, id string, r *http.Request, nocolor bool) string { + var sb strings.Builder + if nocolor { + sb.WriteString(id) + sb.WriteString( + fmt.Sprintf( + " Src: %s→ %s → Dst: %s→ %s", + srcRemote, + srcLocal, + dstLocal, + dstRemote, + ), + ) + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(fmt.Sprintf(" %s %s %s ", r.Method, r.Host, r.Proto)) + } else { + sb.WriteString(id) + sb.WriteString(colors.Green(fmt.Sprintf(" Src: %s→ %s", srcRemote, srcLocal)).String()) + sb.WriteString(colors.Magenta(" → ").String()) + sb.WriteString(colors.Blue(fmt.Sprintf("Dst: %s→ %s", dstLocal, dstRemote)).String()) + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s ", colorizeTimestamp(time.Now(), nocolor))) + sb.WriteString(id) + sb.WriteString(colors.Gray(fmt.Sprintf(" %s ", r.Method)).String()) + sb.WriteString(colors.YellowBg(fmt.Sprintf("%s ", r.Host)).String()) + sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", r.Proto)).String()) + } + return sb.String() +} + +func colorizeConnectionsTransparent( + srcRemote, srcLocal, dstRemote, dstLocal net.Addr, + dst, + id string, + nocolor bool, +) string { + var sb strings.Builder + if nocolor { + sb.WriteString(id) + sb.WriteString( + fmt.Sprintf( + " Src: %s→ %s → Dst: %s→ %s Orig Dst: %s", + srcRemote, + srcLocal, + dstLocal, + dstRemote, + dst, + ), + ) + } else { + sb.WriteString(id) + sb.WriteString(colors.Green(fmt.Sprintf(" Src: %s→ %s", srcRemote, srcLocal)).String()) + sb.WriteString(colors.Magenta(" → ").String()) + sb.WriteString(colors.Blue(fmt.Sprintf("Dst: %s→ %s ", dstLocal, dstRemote)).String()) + sb.WriteString(colors.BeigeBg(fmt.Sprintf("Orig Dst: %s", dst)).String()) + } + return sb.String() +} diff --git a/go.mod b/go.mod index f526f67..6bd3717 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/google/uuid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/shadowy-pycoder/colors v0.0.1 - github.com/shadowy-pycoder/mshark v0.0.9 + github.com/shadowy-pycoder/mshark v0.0.10 golang.org/x/net v0.40.0 golang.org/x/sys v0.33.0 golang.org/x/term v0.32.0 diff --git a/go.sum b/go.sum index 921720b..09a471b 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/shadowy-pycoder/colors v0.0.1 h1:weCj/YIOupqy4BSP8KuVzr20fC+cuAv/tArz7bhhkP4= github.com/shadowy-pycoder/colors v0.0.1/go.mod h1:lkrJS1PY2oVigNLTT6pkbF7B/v0YcU2LD5PZnss1Q4U= -github.com/shadowy-pycoder/mshark v0.0.9 h1:mMHmkqUpkSlkt74DaSkNjhvO0nJ0AxZiYPH6QbllB9A= -github.com/shadowy-pycoder/mshark v0.0.9/go.mod h1:FqbHFdsx0zMnrZZH0+oPzaFcleP4O+tUWv8i5gxo87k= +github.com/shadowy-pycoder/mshark v0.0.10 h1:pLMIsgfvnO0oKeBNdy0fTGQsx//6scCPT52g93CqyT4= +github.com/shadowy-pycoder/mshark v0.0.10/go.mod h1:FqbHFdsx0zMnrZZH0+oPzaFcleP4O+tUWv8i5gxo87k= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= diff --git a/gohpts.go b/gohpts.go index 1866948..4de72a5 100644 --- a/gohpts.go +++ b/gohpts.go @@ -2,14 +2,12 @@ package gohpts import ( - "bufio" "bytes" "compress/gzip" "context" "crypto/sha256" "crypto/subtle" "crypto/tls" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -19,9 +17,7 @@ import ( "net" "net/http" "os" - "os/exec" "os/signal" - "regexp" "runtime" "slices" "strconv" @@ -31,9 +27,7 @@ import ( "time" "github.com/goccy/go-yaml" - "github.com/google/uuid" "github.com/rs/zerolog" - "github.com/shadowy-pycoder/colors" "github.com/shadowy-pycoder/mshark/arpspoof" "github.com/shadowy-pycoder/mshark/layers" "github.com/shadowy-pycoder/mshark/network" @@ -56,36 +50,8 @@ var ( supportedChainTypes = []string{"strict", "dynamic", "random", "round_robin"} SupportedTProxyModes = []string{"redirect", "tproxy"} errInvalidWrite = errors.New("invalid write result") - ipPortPattern = regexp.MustCompile( - `\b(?:\d{1,3}\.){3}\d{1,3}(?::(6553[0-5]|655[0-2]\d|65[0-4]\d{2}|6[0-4]\d{3}|[1-5]?\d{1,4}))?\b`, - ) - domainPattern = regexp.MustCompile( - `\b(?:[a-zA-Z0-9-]{1,63}\.)+(?:com|net|org|io|co|uk|ru|de|edu|gov|info|biz|dev|app|ai)(?::(6553[0-5]|655[0-2]\d|65[0-4]\d{2}|6[0-4]\d{3}|[1-5]?\d{1,4}))?\b`, - ) - jwtPattern = regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\b`) - authPattern = regexp.MustCompile( - `(?i)(?:"|')?(authorization|auth[_-]?token|access[_-]?token|api[_-]?key|secret|token)(?:"|')?\s*[:=]\s*(?:"|')?([^\s"'&]+)`, - ) - credsPattern = regexp.MustCompile( - `(?i)(?:"|')?(username|user|login|email|password|pass|pwd)(?:"|')?\s*[:=]\s*(?:"|')?([^\s"'&]+)`, - ) - macPattern = regexp.MustCompile(`(?i)([a-z0-9_]+_[0-9a-f]{2}(?::[0-9a-f]{2}){2}|(?:[0-9a-f]{2}[:-]){5}[0-9a-f]{2})`) ) -// Hop-by-hop headers -// https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1 -var hopHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "TE", - "Trailer", - "Transfer-Encoding", - "Upgrade", -} - type Config struct { AddrHTTP string AddrSOCKS string @@ -112,6 +78,51 @@ type Config struct { Body bool } +type logWriter struct { + file *os.File +} + +func (writer logWriter) Write(bytes []byte) (int, error) { + return fmt.Fprintf(writer.file, "%s ERR %s", time.Now().Format(time.RFC3339), string(bytes)) +} + +type jsonLogWriter struct { + file *os.File +} + +func (writer jsonLogWriter) Write(bytes []byte) (int, error) { + return fmt.Fprintf(writer.file, "{\"level\":\"error\",\"time\":\"%s\",\"message\":\"%s\"}\n", + time.Now().Format(time.RFC3339), strings.TrimRight(string(bytes), "\n")) +} + +type proxyEntry struct { + Address string `yaml:"address"` + Username string `yaml:"username,omitempty"` + Password string `yaml:"password,omitempty"` +} + +func (pe proxyEntry) String() string { + return pe.Address +} + +type server struct { + Address string `yaml:"address"` + Interface string `yaml:"interface,omitempty"` + Username string `yaml:"username,omitempty"` + Password string `yaml:"password,omitempty"` + CertFile string `yaml:"cert_file,omitempty"` + KeyFile string `yaml:"key_file,omitempty"` +} +type chain struct { + Type string `yaml:"type"` + Length int `yaml:"length"` +} + +type serverConfig struct { + Chain chain `yaml:"chain"` + ProxyList []proxyEntry `yaml:"proxy_list"` + Server server `yaml:"server"` +} type proxyapp struct { httpServer *http.Server sockClient *http.Client @@ -138,629 +149,452 @@ type proxyapp struct { nocolor bool body bool json bool + debug bool closeConn chan bool mu sync.RWMutex availProxyList []proxyEntry } -var rColors = []func(string) *colors.Color{ - colors.Beige, - colors.Blue, - colors.Gray, - colors.Green, - colors.LightBlue, - colors.Magenta, - colors.Red, - colors.Yellow, - colors.BeigeBg, - colors.BlueBg, - colors.GrayBg, - colors.GreenBg, - colors.LightBlueBg, - colors.MagentaBg, - colors.RedBgDark, - colors.YellowBg, -} - -func randColor() func(string) *colors.Color { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - randIndex := r.Intn(len(rColors)) - return rColors[randIndex] -} - -func (p *proxyapp) getID() string { - id := uuid.New() - if p.nocolor { - return colors.WrapBrackets(id.String()) +func New(conf *Config) *proxyapp { + var logger, snifflogger zerolog.Logger + var p proxyapp + logfile := os.Stdout + var snifflog *os.File + var err error + p.sniff = conf.Sniff + p.body = conf.Body + p.json = conf.JSON + if conf.LogFilePath != "" { + f, err := os.OpenFile(conf.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + log.Fatalf("Failed to open log file: %v", err) + } + logfile = f } - return randColor()(colors.WrapBrackets(id.String())).String() -} - -func (p *proxyapp) colorizeStatus(code int, status string, bg bool) string { - if bg { - if code < 300 { - status = colors.GreenBg(status).String() - } else if code < 400 { - status = colors.YellowBg(status).String() - } else { - status = colors.RedBgDark(status).String() + if conf.SniffLogFile != "" && conf.SniffLogFile != conf.LogFilePath { + f, err := os.OpenFile(conf.SniffLogFile, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + log.Fatalf("Failed to open sniff log file: %v", err) } + snifflog = f } else { - if code < 300 { - status = colors.Green(status).String() - } else if code < 400 { - status = colors.Yellow(status).String() - } else { - status = colors.Red(status).String() - } - } - return status -} - -func (p *proxyapp) colorizeHTTP( - req *http.Request, - resp *http.Response, - reqBodySaved, respBodySaved *[]byte, - id string, - ts bool, -) string { - var sb strings.Builder - if ts { - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) + snifflog = logfile } - if p.nocolor { - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" %s %s %s ", req.Method, req.URL, req.Proto)) - if req.UserAgent() != "" { - sb.WriteString(colors.WrapBrackets(req.UserAgent())) - } - if req.ContentLength > 0 { - sb.WriteString(fmt.Sprintf(" Len: %d", req.ContentLength)) - } - sb.WriteString(" → ") - sb.WriteString(fmt.Sprintf("%s %s ", resp.Proto, resp.Status)) - if resp.ContentLength > 0 { - sb.WriteString(fmt.Sprintf("Len: %d", resp.ContentLength)) - } - if p.body && len(*reqBodySaved) > 0 { - b := p.colorizeBody(reqBodySaved) - if b != "" { - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" req_body: %s", b)) - } - } - if p.body && len(*respBodySaved) > 0 { - b := p.colorizeBody(respBodySaved) - if b != "" { - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" resp_body: %s", b)) - } - } + p.nocolor = conf.JSON || conf.NoColor + if conf.JSON { + log.SetFlags(0) + jsonWriter := jsonLogWriter{file: logfile} + log.SetOutput(jsonWriter) + logger = zerolog.New(logfile).With().Timestamp().Logger() + snifflogger = zerolog.New(snifflog).With().Timestamp().Logger() } else { - sb.WriteString(id) - sb.WriteString(colors.Gray(fmt.Sprintf(" %s ", req.Method)).String()) - sb.WriteString(colors.YellowBg(fmt.Sprintf("%s ", req.URL)).String()) - sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", req.Proto)).String()) - if req.UserAgent() != "" { - sb.WriteString(colors.Gray(colors.WrapBrackets(req.UserAgent())).String()) - } - if req.ContentLength > 0 { - sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", req.ContentLength)).String()) - } - sb.WriteString(colors.MagentaBg(" → ").String()) - sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", resp.Proto)).String()) - sb.WriteString(p.colorizeStatus(resp.StatusCode, fmt.Sprintf("%s ", resp.Status), true)) - if resp.ContentLength > 0 { - sb.WriteString(colors.BeigeBg(fmt.Sprintf("Len: %d", resp.ContentLength)).String()) - } - if p.body && len(*reqBodySaved) > 0 { - b := p.colorizeBody(reqBodySaved) - if b != "" { - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(colors.RedBgDark(" req_body: ").String()) - sb.WriteString(b) - } + log.SetFlags(0) + logWriter := logWriter{file: logfile} + log.SetOutput(logWriter) + output := zerolog.ConsoleWriter{Out: logfile, NoColor: p.nocolor} + + output.FormatTimestamp = func(i any) string { + ts, _ := time.Parse(time.RFC3339, i.(string)) + return colorizeTimestamp(ts, p.nocolor) } - if p.body && len(*respBodySaved) > 0 { - b := p.colorizeBody(respBodySaved) - if b != "" { - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(colors.RedBgDark(" resp_body: ").String()) - sb.WriteString(b) + output.FormatMessage = func(i any) string { + if i == nil || i == "" { + return "" } + return colorizeLogMessage(i.(string), p.nocolor) } - } - return sb.String() -} -func (p *proxyapp) colorizeTLS(req *layers.TLSClientHello, resp *layers.TLSServerHello, id string) string { - var sb strings.Builder - if p.nocolor { - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" %s ", req.TypeDesc)) - if req.Length > 0 { - sb.WriteString(fmt.Sprintf(" Len: %d", req.Length)) - } - if req.ServerName != nil && req.ServerName.SNName != "" { - sb.WriteString(fmt.Sprintf(" SNI: %s", req.ServerName.SNName)) - } - if req.Version != nil && req.Version.Desc != "" { - sb.WriteString(fmt.Sprintf(" Ver: %s", req.Version.Desc)) - } - if req.ALPN != nil { - sb.WriteString(fmt.Sprintf(" ALPN: %v", req.ALPN)) + output.FormatErrFieldName = func(i any) string { + return fmt.Sprintf("%s", i) } - sb.WriteString(" → ") - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" %s ", resp.TypeDesc)) - if resp.Length > 0 { - sb.WriteString(fmt.Sprintf(" Len: %d", resp.Length)) + + output.FormatErrFieldValue = func(i any) string { + s := i.(string) + return colorizeErrMessage(s, p.nocolor) } - if resp.SessionID != "" { - sb.WriteString(fmt.Sprintf(" SID: %s", resp.SessionID)) + logger = zerolog.New(output).With().Timestamp().Logger() + sniffoutput := zerolog.ConsoleWriter{Out: snifflog, TimeFormat: time.RFC3339, NoColor: p.nocolor, PartsExclude: []string{"level"}} + sniffoutput.FormatTimestamp = func(i any) string { + ts, _ := time.Parse(time.RFC3339, i.(string)) + return colorizeTimestamp(ts, p.nocolor) } - if resp.CipherSuite != nil && resp.CipherSuite.Desc != "" { - sb.WriteString(fmt.Sprintf(" CS: %s", resp.CipherSuite.Desc)) + sniffoutput.FormatMessage = func(i any) string { + if i == nil || i == "" { + return "" + } + return fmt.Sprintf("%s", i) } - if resp.SupportedVersion != nil && resp.SupportedVersion.Desc != "" { - sb.WriteString(fmt.Sprintf(" Ver: %s", resp.SupportedVersion.Desc)) + sniffoutput.FormatErrFieldName = func(i any) string { + return fmt.Sprintf("%s", i) } - if resp.ExtensionLength > 0 { - sb.WriteString(fmt.Sprintf(" ExtLen: %d", resp.ExtensionLength)) + + sniffoutput.FormatErrFieldValue = func(i any) string { + return colorizeErrMessage(i.(string), p.nocolor) } + snifflogger = zerolog.New(sniffoutput).With().Timestamp().Logger() + } + zerolog.SetGlobalLevel(zerolog.DebugLevel) + lvl := zerolog.InfoLevel + if conf.Debug { + lvl = zerolog.DebugLevel + } + p.debug = conf.Debug + // the only way I found to make debug level independent between loggers + l := logger.Level(lvl) + sl := snifflogger.Level(lvl) + p.logger = &l + p.snifflogger = &sl + if runtime.GOOS == "linux" && conf.TProxy != "" && conf.TProxyOnly != "" { + p.logger.Fatal().Msg("Cannot specify TPRoxy and TProxyOnly at the same time") + } else if runtime.GOOS == "linux" && conf.TProxyMode != "" && !slices.Contains(SupportedTProxyModes, conf.TProxyMode) { + p.logger.Fatal().Msg("Incorrect TProxyMode provided") + } else if runtime.GOOS != "linux" && (conf.TProxy != "" || conf.TProxyOnly != "" || conf.TProxyMode != "") { + conf.TProxy = "" + conf.TProxyOnly = "" + conf.TProxyMode = "" + p.logger.Warn().Msgf("[%s] functionality only available on linux systems", conf.TProxyMode) + } + p.tproxyMode = conf.TProxyMode + tproxyonly := conf.TProxyOnly != "" + var tAddr string + if tproxyonly { + tAddr = conf.TProxyOnly } else { - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(colors.Magenta(fmt.Sprintf(" %s ", req.TypeDesc)).Bold()) - if req.Length > 0 { - sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", req.Length)).String()) - } - if req.ServerName != nil && req.ServerName.SNName != "" { - sb.WriteString(colors.YellowBg(fmt.Sprintf(" SNI: %s", req.ServerName.SNName)).String()) - } - if req.Version != nil && req.Version.Desc != "" { - sb.WriteString(colors.GreenBg(fmt.Sprintf(" Ver: %s", req.Version.Desc)).String()) - } - if req.ALPN != nil { - sb.WriteString(colors.BlueBg(fmt.Sprintf(" ALPN: %v", req.ALPN)).String()) - } - sb.WriteString(colors.MagentaBg(" → ").String()) - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(colors.LightBlue(fmt.Sprintf(" %s ", resp.TypeDesc)).Bold()) - if resp.Length > 0 { - sb.WriteString(colors.BeigeBg(fmt.Sprintf(" Len: %d", resp.Length)).String()) - } - if resp.SessionID != "" { - sb.WriteString(colors.Gray(fmt.Sprintf(" SID: %s", resp.SessionID)).String()) - } - if resp.CipherSuite != nil && resp.CipherSuite.Desc != "" { - sb.WriteString(colors.Yellow(fmt.Sprintf(" CS: %s", resp.CipherSuite.Desc)).Bold()) - } - if resp.SupportedVersion != nil && resp.SupportedVersion.Desc != "" { - sb.WriteString(colors.GreenBg(fmt.Sprintf(" Ver: %s", resp.SupportedVersion.Desc)).String()) + tAddr = conf.TProxy + } + if p.tproxyMode != "" { + p.tproxyAddr, err = getFullAddress(tAddr, "", true) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } - if resp.ExtensionLength > 0 { - sb.WriteString(colors.BeigeBg(fmt.Sprintf(" ExtLen: %d", resp.ExtensionLength)).String()) + } else { + p.tproxyAddr, err = getFullAddress(tAddr, "", false) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } } - return sb.String() -} - -func (p *proxyapp) highlightPatterns(line string) (string, bool) { - matched := false - - // TODO: make this configurable - // line, matched = p.replace(line, ipPortPattern, colors.YellowBg, matched) - // line, matched = p.replace(line, domainPattern, colors.YellowBg, matched) - line, matched = p.replace(line, jwtPattern, colors.Magenta, matched) - line, matched = p.replace(line, authPattern, colors.Magenta, matched) - line, matched = p.replace(line, credsPattern, colors.GreenBg, matched) - - return line, matched -} - -func (p *proxyapp) replace(line string, re *regexp.Regexp, color func(string) *colors.Color, matched bool) (string, bool) { - if re.MatchString(line) { - matched = true - if !p.nocolor { - line = re.ReplaceAllStringFunc(line, func(s string) string { - return color(s).String() - }) - } + p.auto = conf.Auto + if p.auto && runtime.GOOS != "linux" { + p.logger.Fatal().Msg("Auto setup is available only on linux systems") } - return line, matched -} - -func (p *proxyapp) colorizeBody(b *[]byte) string { - matches := make([]string, 0, 3) - scanner := bufio.NewScanner(bytes.NewReader(*b)) - for scanner.Scan() { - line := scanner.Text() - if highlighted, ok := p.highlightPatterns(line); ok { - matches = append(matches, strings.Trim(highlighted, "\r\n\t ")) - } + p.mark = conf.Mark + if p.mark > 0 && runtime.GOOS != "linux" { + p.logger.Fatal().Msg("SO_MARK is available only on linux systems") } - return strings.Join(matches, "\n") -} - -func (p *proxyapp) colorizeTimestamp() string { - ts := time.Now() - if p.nocolor { - return colors.WrapBrackets(ts.Format(time.TimeOnly)) + if p.mark > 0xFFFFFFFF { + p.logger.Fatal().Msg("SO_MARK is out of range") } - return colors.Gray(colors.WrapBrackets(ts.Format(time.TimeOnly))).String() -} - -func (p *proxyapp) colorizeTunnel(req, resp layers.Layer, sniffheader *[]string, id string) error { - switch reqt := req.(type) { - case *layers.HTTPMessage: - var reqBodySaved, respBodySaved []byte - rest := resp.(*layers.HTTPMessage) - if p.body { - reqBodySaved, _ = io.ReadAll(reqt.Request.Body) - respBodySaved, _ = io.ReadAll(rest.Response.Body) - reqBodySaved = bytes.Trim(reqBodySaved, "\r\n\t ") - respBodySaved = bytes.Trim(respBodySaved, "\r\n\t ") + if p.mark == 0 && p.tproxyMode == "tproxy" { + p.mark = 100 + } + var addrHTTP, addrSOCKS, certFile, keyFile string + if conf.ServerConfPath != "" { + var sconf serverConfig + yamlFile, err := os.ReadFile(expandPath(conf.ServerConfPath)) + if err != nil { + p.logger.Fatal().Err(err).Msg("[yaml config] Parsing failed") } - if p.json { - j1, err := json.Marshal(reqt) - if err != nil { - return err + err = yaml.Unmarshal(yamlFile, &sconf) + if err != nil { + p.logger.Fatal().Err(err).Msg("[yaml config] Parsing failed") + } + if !tproxyonly { + if sconf.Server.Address == "" { + p.logger.Fatal().Err(err).Msg("[yaml config] Server address is empty") } - j2, err := json.Marshal(rest) - if err != nil { - return err + if sconf.Server.Interface != "" { + p.iface, err = net.InterfaceByName(sconf.Server.Interface) + if err != nil { + if ifIdx, err := strconv.Atoi(sconf.Server.Interface); err == nil { + p.iface, err = net.InterfaceByIndex(ifIdx) + if err != nil { + p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) + } + } else { + p.logger.Warn().Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) + } + } } - *sniffheader = append(*sniffheader, string(j1), string(j2)) - if p.body && len(reqBodySaved) > 0 { - *sniffheader = append(*sniffheader, fmt.Sprintf("{\"req_body\":%s}", reqBodySaved)) + iAddr, err := getAddressFromInterface(p.iface) + if err != nil { + p.iface = nil + p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) } - if p.body && len(respBodySaved) > 0 { - *sniffheader = append(*sniffheader, fmt.Sprintf("{\"resp_body\":%s}", respBodySaved)) + addrHTTP, err = getFullAddress(sconf.Server.Address, iAddr, false) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } - } else { - *sniffheader = append(*sniffheader, p.colorizeHTTP(reqt.Request, rest.Response, &reqBodySaved, &respBodySaved, id, true)) + p.httpServerAddr = addrHTTP + certFile = expandPath(sconf.Server.CertFile) + keyFile = expandPath(sconf.Server.KeyFile) + p.user = sconf.Server.Username + p.pass = sconf.Server.Password } - case *layers.TLSMessage: - var chs *layers.TLSClientHello - var shs *layers.TLSServerHello - hsrec := reqt.Records[0] // len(Records) > 0 after dispatch - if hsrec.ContentType == layers.HandshakeTLSVal { // TODO: add more cases, parse all records - switch parser := layers.HSTLSParserByType(hsrec.Data[0]).(type) { - case *layers.TLSClientHello: - err := parser.ParseHS(hsrec.Data) - if err != nil { - return err - } - chs = parser - } + p.proxychain = sconf.Chain + p.proxylist = sconf.ProxyList + p.availProxyList = make([]proxyEntry, 0, len(p.proxylist)) + if len(p.proxylist) == 0 { + p.logger.Fatal().Msg("[yaml config] Proxy list is empty") } - rest := resp.(*layers.TLSMessage) - hsrec = rest.Records[0] - if hsrec.ContentType == layers.HandshakeTLSVal { - switch parser := layers.HSTLSParserByType(hsrec.Data[0]).(type) { - case *layers.TLSServerHello: - err := parser.ParseHS(hsrec.Data) - if err != nil { - return err - } - shs = parser + seen := make(map[string]struct{}) + for idx, pr := range p.proxylist { + addr, err := getFullAddress(pr.Address, "", false) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } - } - if chs != nil && shs != nil { - if p.json { - j1, err := json.Marshal(chs) - if err != nil { - return err - } - j2, err := json.Marshal(shs) - if err != nil { - return err - } - *sniffheader = append(*sniffheader, string(j1), string(j2)) + if _, ok := seen[addr]; !ok { + seen[addr] = struct{}{} + p.proxylist[idx].Address = addr } else { - *sniffheader = append(*sniffheader, p.colorizeTLS(chs, shs, id)) + p.logger.Fatal().Msgf("[yaml config] Duplicate entry `%s`", addr) } } - } - return nil -} - -// https://stackoverflow.com/a/1094933/1333724 -func prettifyBytes(b int64) string { - bf := float64(b) - for _, unit := range []string{"", "K", "M", "G", "T", "P", "E", "Z"} { - if bf < 1000.0 { - return fmt.Sprintf("%3.1f%sB", bf, unit) - } - bf /= 1000.0 - } - return fmt.Sprintf("%.1fYB", bf) -} - -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) + addrSOCKS = p.printProxyChain(p.proxylist) + chainType := p.proxychain.Type + if !slices.Contains(supportedChainTypes, chainType) { + p.logger.Fatal().Msgf("[yaml config] Chain type `%s` is not supported", chainType) } - } -} - -func delHopHeaders(header http.Header) { - for _, h := range hopHeaders { - header.Del(h) - } -} - -// delConnectionHeaders removes hop-by-hop headers listed in the "Connection" header -// https://datatracker.ietf.org/doc/html/rfc7230#section-6.1 -func delConnectionHeaders(h http.Header) { - for _, f := range h["Connection"] { - for sf := range strings.SplitSeq(f, ",") { - if sf = strings.TrimSpace(sf); sf != "" { - h.Del(sf) + p.rrIndexReset = rrIndexMax + } else { + if !tproxyonly { + if conf.Interface != "" { + p.iface, err = net.InterfaceByName(conf.Interface) + if err != nil { + if ifIdx, err := strconv.Atoi(conf.Interface); err == nil { + p.iface, err = net.InterfaceByIndex(ifIdx) + if err != nil { + p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", conf.Interface) + } + } else { + p.logger.Warn().Msgf("Failed binding to %s, using default interface", conf.Interface) + } + } + } + iAddr, err := getAddressFromInterface(p.iface) + if err != nil { + p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", conf.Interface) + p.iface = nil + } + addrHTTP, err = getFullAddress(conf.AddrHTTP, iAddr, false) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } + p.httpServerAddr = addrHTTP + certFile = expandPath(conf.CertFile) + keyFile = expandPath(conf.KeyFile) + p.user = conf.ServerUser + p.pass = conf.ServerPass } - } -} - -func appendHostToXForwardHeader(header http.Header, host string) { - if prior, ok := header["X-Forwarded-For"]; ok { - host = strings.Join(prior, ", ") + ", " + host - } - header.Set("X-Forwarded-For", host) -} - -func isLocalAddress(addr string) bool { - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - ip := net.ParseIP(host) - if ip != nil { - return ip.IsLoopback() - } - host = strings.ToLower(host) - return strings.HasSuffix(host, ".local") || host == "localhost" -} - -func (p *proxyapp) printProxyChain(pc []proxyEntry) string { - var sb strings.Builder - sb.WriteString("client → ") - if p.httpServerAddr != "" { - sb.WriteString(p.httpServerAddr) - if p.tproxyAddr != "" { - sb.WriteString(" | ") - sb.WriteString(p.tproxyAddr) - sb.WriteString(fmt.Sprintf(" (%s)", p.tproxyMode)) + addrSOCKS, err = getFullAddress(conf.AddrSOCKS, "", false) + if err != nil { + p.logger.Fatal().Err(err).Msg("") } - } else if p.tproxyAddr != "" { - sb.WriteString(p.tproxyAddr) - sb.WriteString(fmt.Sprintf(" (%s)", p.tproxyMode)) - } - sb.WriteString(" → ") - for _, pe := range pc { - sb.WriteString(pe.String()) - sb.WriteString(" → ") - } - sb.WriteString("target") - return sb.String() -} - -func (p *proxyapp) updateSocksList() { - p.mu.Lock() - defer p.mu.Unlock() - p.availProxyList = p.availProxyList[:0] - var base proxy.Dialer = getBaseDialer(timeout, p.mark) - var dialer proxy.Dialer - var err error - failed := 0 - chainType := p.proxychain.Type - var ctl string - if p.nocolor { - ctl = colors.WrapBrackets(chainType) - } else { - ctl = colors.WrapBrackets(colors.LightBlueBg(chainType).String()) - } - for _, pr := range p.proxylist { auth := proxy.Auth{ - User: pr.Username, - Password: pr.Password, + User: conf.User, + Password: conf.Pass, } - dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, base) + dialer, err := proxy.SOCKS5("tcp", addrSOCKS, &auth, getBaseDialer(timeout, p.mark)) if err != nil { - p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) - failed++ - continue + p.logger.Fatal().Err(err).Msg("Unable to create SOCKS5 dialer") } - ctx, cancel := context.WithTimeout(context.Background(), hopTimeout) - defer cancel() - conn, err := dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", pr.Address) - if err != nil && !errors.Is(err, io.EOF) { // check for EOF to include localhost SOCKS5 in the chain - p.logger.Error().Err(err).Msgf("%s Unable to connect to %s", ctl, pr.Address) - failed++ - if conn != nil { - conn.Close() - } - continue - } else { - p.availProxyList = append(p.availProxyList, proxyEntry{Address: pr.Address, Username: pr.Username, Password: pr.Password}) - if conn != nil { - conn.Close() + p.sockDialer = dialer + if !tproxyonly { + p.sockClient = &http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } - break } } - if failed == len(p.proxylist) { - p.logger.Error().Err(err).Msgf("%s No SOCKS5 Proxy available", ctl) - return + if !tproxyonly { + hs := &http.Server{ + Addr: addrHTTP, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + MaxHeaderBytes: 1 << 20, + Protocols: new(http.Protocols), + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + }, + }, + } + hs.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + hs.Protocols.SetHTTP1(true) + p.httpServer = hs + p.httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DialContext: getBaseDialer(timeout, p.mark).DialContext, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: timeout, + } } - currentDialer := dialer - for _, pr := range p.proxylist[failed+1:] { - auth := proxy.Auth{ - User: pr.Username, - Password: pr.Password, + if conf.ARPSpoof != "" { + if runtime.GOOS != "linux" { + p.logger.Fatal().Msg("ARP spoof setup is available only on linux systems") } - dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, currentDialer) + if !p.auto { + p.logger.Warn().Msg("ARP spoof setup requires iptables configuration") + } + asc, err := arpspoof.NewARPSpoofConfig(conf.ARPSpoof, p.logger) if err != nil { - p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) - continue + p.logger.Fatal().Err(err).Msg("Failed creating arp spoofer") } - // https://github.com/golang/go/issues/37549#issuecomment-1178745487 - ctx, cancel := context.WithTimeout(context.Background(), hopTimeout) - defer cancel() - conn, err := dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", pr.Address) + asc.Interface = "" + asc.Gateway = nil + if p.iface != nil { + asc.Interface = p.iface.Name + } + p.arpspoofer, err = arpspoof.NewARPSpoofer(asc) if err != nil { - p.logger.Error().Err(err).Msgf("%s Unable to connect to %s", ctl, pr.Address) - if conn != nil { - conn.Close() - } - continue + p.logger.Fatal().Err(err).Msg("Failed creating arp spoofer") } - conn.Close() - currentDialer = dialer - p.availProxyList = append(p.availProxyList, proxyEntry{Address: pr.Address, Username: pr.Username, Password: pr.Password}) } - p.logger.Debug().Msgf("%s Available SOCKS5 Proxy [%d/%d]: %s", ctl, - len(p.availProxyList), len(p.proxylist), p.printProxyChain(p.availProxyList)) -} - -// https://www.calhoun.io/how-to-shuffle-arrays-and-slices-in-go/ -func shuffle(vals []proxyEntry) { - r := rand.New(rand.NewSource(time.Now().Unix())) - for len(vals) > 0 { - n := len(vals) - randIndex := r.Intn(n) - vals[n-1], vals[randIndex] = vals[randIndex], vals[n-1] - vals = vals[:n-1] + if conf.ServerConfPath != "" { + p.logger.Info().Msgf("SOCKS5 Proxy [%s] chain: %s", p.proxychain.Type, addrSOCKS) + } else { + p.logger.Info().Msgf("SOCKS5 Proxy: %s", addrSOCKS) } -} - -func (p *proxyapp) getSocks() (proxy.Dialer, *http.Client, error) { - if p.proxylist == nil { - return p.sockDialer, p.sockClient, nil + if !tproxyonly { + if certFile != "" && keyFile != "" { + p.certFile = certFile + p.keyFile = keyFile + p.logger.Info().Msgf("HTTPS Proxy: %s", p.httpServerAddr) + } else { + p.logger.Info().Msgf("HTTP Proxy: %s", p.httpServerAddr) + } } - p.mu.RLock() - defer p.mu.RUnlock() - chainType := p.proxychain.Type - var ctl string - if p.nocolor { - ctl = colors.WrapBrackets(chainType) - } else { - ctl = colors.WrapBrackets(colors.LightBlueBg(chainType).String()) + if p.tproxyAddr != "" { + if p.tproxyMode == "tproxy" { + p.logger.Info().Msgf("TPROXY: %s", p.tproxyAddr) + } else { + p.logger.Info().Msgf("REDIRECT: %s", p.tproxyAddr) + } } - if len(p.availProxyList) == 0 { - p.logger.Error().Msgf("%s No SOCKS5 Proxy available", ctl) - return nil, nil, fmt.Errorf("no socks5 proxy available") + return &p +} + +func (p *proxyapp) Run() { + done := make(chan bool) + quit := make(chan os.Signal, 1) + p.closeConn = make(chan bool) + signal.Notify(quit, os.Interrupt) + if p.arpspoofer != nil { + go p.arpspoofer.Start() } - var chainLength int - if p.proxychain.Length > len(p.availProxyList) || p.proxychain.Length <= 0 { - chainLength = len(p.availProxyList) - } else { - chainLength = p.proxychain.Length + var tproxyServer *tproxyServer + var output map[string]string + if p.tproxyAddr != "" { + tproxyServer = newTproxyServer(p) + if p.auto { + output = tproxyServer.applyRedirectRules() + } } - copyProxyList := make([]proxyEntry, 0, len(p.availProxyList)) - switch chainType { - case "strict", "dynamic": - copyProxyList = p.availProxyList - case "random": - copyProxyList = append(copyProxyList, p.availProxyList...) - shuffle(copyProxyList) - copyProxyList = copyProxyList[:chainLength] - case "round_robin": - var start uint32 - for { - start = atomic.LoadUint32(&p.rrIndex) - next := start + 1 - if start >= p.rrIndexReset { - p.logger.Debug().Msg("Resetting round robin index") - next = 0 + if p.proxylist != nil { + chainType := p.proxychain.Type + ctl := colorizeChainType(chainType, p.nocolor) + go func() { + for { + p.logger.Debug().Msgf("%s Updating available proxy", ctl) + p.updateSocksList() + time.Sleep(availProxyUpdateInterval) } - if atomic.CompareAndSwapUint32(&p.rrIndex, start, next) { - break + }() + } + if p.httpServer != nil { + go func() { + <-quit + if p.arpspoofer != nil { + err := p.arpspoofer.Stop() + if err != nil { + p.logger.Error().Err(err).Msg("Failed stopping arp spoofer") + } } + close(p.closeConn) + if tproxyServer != nil { + if p.auto { + err := tproxyServer.clearRedirectRules(output) + if err != nil { + p.logger.Error().Err(err).Msg("Failed clearing iptables rules") + } + } + p.logger.Info().Msgf("[%s] Server is shutting down...", p.tproxyMode) + tproxyServer.Shutdown() + } + p.logger.Info().Msg("Server is shutting down...") + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + + defer cancel() + p.httpServer.SetKeepAlivesEnabled(false) + if err := p.httpServer.Shutdown(ctx); err != nil { + p.logger.Fatal().Err(err).Msg("Could not gracefully shutdown the server") + } + close(done) + }() + if tproxyServer != nil { + go tproxyServer.ListenAndServe() } - startIdx := int(start % uint32(len(p.availProxyList))) - for i := range chainLength { - idx := (startIdx + i) % len(p.availProxyList) - copyProxyList = append(copyProxyList, p.availProxyList[idx]) - } - default: - p.logger.Fatal().Msg("Unreachable") - } - if len(copyProxyList) == 0 { - p.logger.Error().Msgf("%s No SOCKS5 Proxy available", ctl) - return nil, nil, fmt.Errorf("no socks5 proxy available") - } - if p.proxychain.Type == "strict" && len(copyProxyList) != len(p.proxylist) { - p.logger.Error().Msgf("%s Not all SOCKS5 Proxy available", ctl) - return nil, nil, fmt.Errorf("not all socks5 proxy available") - } - var dialer proxy.Dialer = getBaseDialer(timeout, p.mark) - var err error - for _, pr := range copyProxyList { - auth := proxy.Auth{ - User: pr.Username, - Password: pr.Password, + if p.user != "" && p.pass != "" { + p.httpServer.Handler = p.proxyAuth(p.handler()) + } else { + p.httpServer.Handler = p.handler() } - dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, dialer) - if err != nil { - p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) - return nil, nil, err + if p.certFile != "" && p.keyFile != "" { + if err := p.httpServer.ListenAndServeTLS(p.certFile, p.keyFile); err != nil && err != http.ErrServerClosed { + p.logger.Fatal().Err(err).Msg("Unable to start HTTPS server") + } + } else { + if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + p.logger.Fatal().Err(err).Msg("Unable to start HTTP server") + } } + p.logger.Info().Msg("Server stopped") + } else { + go func() { + <-quit + if p.arpspoofer != nil { + err := p.arpspoofer.Stop() + if err != nil { + p.logger.Error().Err(err).Msg("Failed stopping arp spoofer") + } + } + if p.auto { + err := tproxyServer.clearRedirectRules(output) + if err != nil { + p.logger.Error().Err(err).Msg("Failed clearing iptables rules") + } + } + close(p.closeConn) + p.logger.Info().Msgf("[%s] Server is shutting down...", p.tproxyMode) + tproxyServer.Shutdown() + close(done) + }() + tproxyServer.ListenAndServe() } - socks := &http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - p.logger.Debug().Msgf("%s Request chain: %s", ctl, p.printProxyChain(copyProxyList)) - return dialer, socks, nil + <-done } -func (p *proxyapp) doReq(w http.ResponseWriter, r *http.Request, sock *http.Client) *http.Response { - var ( - resp *http.Response - err error - msg string - client *http.Client - ) - if sock != nil { - client = sock - msg = "Connection to SOCKS5 server failed" - } else { - client = p.httpClient - msg = "Connection failed" - } - resp, err = client.Do(r) - if err != nil { - p.logger.Error().Err(err).Msg(msg) - w.WriteHeader(http.StatusServiceUnavailable) - return nil - } - if resp == nil { - p.logger.Error().Msg(msg) - w.WriteHeader(http.StatusServiceUnavailable) - return nil +func (p *proxyapp) handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + p.handleTunnel(w, r) + } else { + p.handleForward(w, r) + } } - return resp } func (p *proxyapp) handleForward(w http.ResponseWriter, r *http.Request) { @@ -786,7 +620,7 @@ func (p *proxyapp) handleForward(w http.ResponseWriter, r *http.Request) { var chunked bool var respBodySaved []byte p.httpClient.Timeout = timeout - if isLocalAddress(r.Host) { + if network.IsLocalAddress(r.Host) { resp = p.doReq(w, req, nil) } else { _, sockClient, err := p.getSocks() @@ -822,25 +656,25 @@ func (p *proxyapp) handleForward(w http.ResponseWriter, r *http.Request) { respBodySaved = bytes.Trim(respBodySaved, "\r\n\t ") } if p.json { - sniffheader := make([]string, 0, 4) + sniffdata := make([]string, 0, 4) j, err := json.Marshal(&layers.HTTPMessage{Request: r}) if err == nil { - sniffheader = append(sniffheader, string(j)) + sniffdata = append(sniffdata, string(j)) } j, err = json.Marshal(&layers.HTTPMessage{Response: resp}) if err == nil { - sniffheader = append(sniffheader, string(j)) + sniffdata = append(sniffdata, string(j)) } if p.body && len(reqBodySaved) > 0 { - sniffheader = append(sniffheader, fmt.Sprintf("{\"req_body\":%s}", reqBodySaved)) + sniffdata = append(sniffdata, fmt.Sprintf("{\"req_body\":%s}", reqBodySaved)) } if p.body && len(respBodySaved) > 0 { - sniffheader = append(sniffheader, fmt.Sprintf("{\"resp_body\":%s}", respBodySaved)) + sniffdata = append(sniffdata, fmt.Sprintf("{\"resp_body\":%s}", respBodySaved)) } - p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(sniffheader, ","))) + p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(sniffdata, ","))) } else { - id := p.getID() - p.snifflogger.Log().Msg(p.colorizeHTTP(req, resp, &reqBodySaved, &respBodySaved, id, false)) + id := getID(p.nocolor) + p.snifflogger.Log().Msg(colorizeHTTP(req, resp, &reqBodySaved, &respBodySaved, id, false, p.body, p.nocolor)) } } defer resp.Body.Close() @@ -896,7 +730,7 @@ func (p *proxyapp) handleForward(w http.ResponseWriter, r *http.Request) { } status := resp.Status if !p.nocolor { - status = p.colorizeStatus(resp.StatusCode, status, false) + status = colorizeStatus(resp.StatusCode, status, false) } p.logger.Debug().Msgf("%s - %s - %s - %s - %s", r.Proto, r.Method, r.Host, status, written) if len(resp.Trailer) == announcedTrailers { @@ -914,7 +748,7 @@ func (p *proxyapp) handleForward(w http.ResponseWriter, r *http.Request) { func (p *proxyapp) handleTunnel(w http.ResponseWriter, r *http.Request) { var dstConn net.Conn var err error - if isLocalAddress(r.Host) { + if network.IsLocalAddress(r.Host) { dstConn, err = getBaseDialer(timeout, p.mark).Dial("tcp", r.Host) if err != nil { p.logger.Error().Err(err).Msgf("Failed connecting to %s", r.Host) @@ -967,1125 +801,516 @@ func (p *proxyapp) handleTunnel(w http.ResponseWriter, r *http.Request) { go p.transfer(&wg, srcConn, dstConn, srcConnStr, dstConnStr, respChan) if p.sniff { wg.Add(1) - sniffheader := make([]string, 0, 6) - id := p.getID() + sniffdata := make([]string, 0, 6) + id := getID(p.nocolor) if p.json { - sniffheader = append( - sniffheader, + sniffdata = append( + sniffdata, fmt.Sprintf("{\"connection\":{\"src_remote\":%s,\"src_local\":%s,\"dst_local\":%s,\"dst_remote\":%s}}", srcConn.RemoteAddr(), srcConn.LocalAddr(), dstConn.LocalAddr(), dstConn.RemoteAddr()), ) j, err := json.Marshal(&layers.HTTPMessage{Request: r}) if err == nil { - sniffheader = append(sniffheader, string(j)) + sniffdata = append(sniffdata, string(j)) } } else { - var sb strings.Builder - if p.nocolor { - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" Src: %s→ %s → Dst: %s→ %s", srcConn.RemoteAddr(), srcConn.LocalAddr(), dstConn.LocalAddr(), dstConn.RemoteAddr())) - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" %s %s %s ", r.Method, r.Host, r.Proto)) - } else { - sb.WriteString(id) - sb.WriteString(colors.Green(fmt.Sprintf(" Src: %s→ %s", srcConn.RemoteAddr(), srcConn.LocalAddr())).String()) - sb.WriteString(colors.Magenta(" → ").String()) - sb.WriteString(colors.Blue(fmt.Sprintf("Dst: %s→ %s", dstConn.LocalAddr(), dstConn.RemoteAddr())).String()) - sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("%s ", p.colorizeTimestamp())) - sb.WriteString(id) - sb.WriteString(colors.Gray(fmt.Sprintf(" %s ", r.Method)).String()) - sb.WriteString(colors.YellowBg(fmt.Sprintf("%s ", r.Host)).String()) - sb.WriteString(colors.BlueBg(fmt.Sprintf("%s ", r.Proto)).String()) - } - sniffheader = append(sniffheader, sb.String()) + connections := colorizeConnections(srcConn.RemoteAddr(), srcConn.LocalAddr(), dstConn.RemoteAddr(), dstConn.LocalAddr(), id, r, p.nocolor) + sniffdata = append(sniffdata, connections) } - go p.sniffreporter(&wg, &sniffheader, reqChan, respChan, id) + go p.sniffreporter(&wg, &sniffdata, reqChan, respChan, id) } wg.Wait() } -func (p *proxyapp) sniffreporter(wg *sync.WaitGroup, sniffheader *[]string, reqChan, respChan <-chan layers.Layer, id string) { - defer wg.Done() - sniffheaderlen := len(*sniffheader) - var reqTLSQueue, respTLSQueue, reqHTTPQueue, respHTTPQueue []layers.Layer - for { - select { - case req, ok := <-reqChan: - if !ok { - return - } else { - switch req.(type) { - case *layers.TLSMessage: - reqTLSQueue = append(reqTLSQueue, req) - case *layers.HTTPMessage: - reqHTTPQueue = append(reqHTTPQueue, req) - } - } - case resp, ok := <-respChan: - if !ok { - return - } else { - switch resp.(type) { - case *layers.TLSMessage: - // request comes first or response arrived first - if len(reqTLSQueue) > 0 || len(respTLSQueue) == 0 { - respTLSQueue = append(respTLSQueue, resp) - // remove unmatched response if still no requests - } else if len(reqTLSQueue) == 0 && len(respTLSQueue) == 1 { - respTLSQueue = respTLSQueue[1:] - } - case *layers.HTTPMessage: - if len(reqHTTPQueue) > 0 || len(respHTTPQueue) == 0 { - respHTTPQueue = append(respHTTPQueue, resp) - } else if len(reqHTTPQueue) == 0 && len(respHTTPQueue) == 1 { - respHTTPQueue = respHTTPQueue[1:] - } - } - } - } - if len(reqHTTPQueue) > 0 && len(respHTTPQueue) > 0 { - req := reqHTTPQueue[0] - resp := respHTTPQueue[0] - reqHTTPQueue = reqHTTPQueue[1:] - respHTTPQueue = respHTTPQueue[1:] - - err := p.colorizeTunnel(req, resp, sniffheader, id) - if err == nil && len(*sniffheader) > sniffheaderlen { - if p.json { - p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(*sniffheader, ","))) - } else { - p.snifflogger.Log().Msg(strings.Join(*sniffheader, "\n")) - } - } - *sniffheader = (*sniffheader)[:sniffheaderlen] +func (p *proxyapp) printProxyChain(pc []proxyEntry) string { + var sb strings.Builder + sb.WriteString("client → ") + if p.httpServerAddr != "" { + sb.WriteString(p.httpServerAddr) + if p.tproxyAddr != "" { + sb.WriteString(" | ") + sb.WriteString(p.tproxyAddr) + sb.WriteString(fmt.Sprintf(" (%s)", p.tproxyMode)) } - if len(reqTLSQueue) > 0 && len(respTLSQueue) > 0 { - req := reqTLSQueue[0] - resp := respTLSQueue[0] - reqTLSQueue = reqTLSQueue[1:] - respTLSQueue = respTLSQueue[1:] - - err := p.colorizeTunnel(req, resp, sniffheader, id) - if err == nil && len(*sniffheader) > sniffheaderlen { - if p.json { - p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(*sniffheader, ","))) - } else { - p.snifflogger.Log().Msg(strings.Join(*sniffheader, "\n")) - } - } - *sniffheader = (*sniffheader)[:sniffheaderlen] - } - } -} - -func dispatch(data []byte) (layers.Layer, error) { - // TODO: check if it is http or tls beforehand - h := &layers.HTTPMessage{} - if err := h.Parse(data); err == nil && !h.IsEmpty() { - return h, nil - } - m := &layers.TLSMessage{} - if err := m.Parse(data); err == nil && len(m.Records) > 0 { - return m, nil - } - return nil, fmt.Errorf("failed sniffing traffic") -} - -func (p *proxyapp) copyWithTimeout(dst net.Conn, src net.Conn, msgChan chan<- layers.Layer) (written int64, err error) { - buf := make([]byte, 32*1024) -readLoop: - for { - select { - case <-p.closeConn: - break readLoop - default: - er := src.SetReadDeadline(time.Now().Add(readTimeout)) - if er != nil { - if errors.Is(er, net.ErrClosed) { - break readLoop - } - err = er - break readLoop - } - nr, er := src.Read(buf) - if nr > 0 { - er := dst.SetWriteDeadline(time.Now().Add(writeTimeout)) - if er != nil { - if errors.Is(er, net.ErrClosed) { - break readLoop - } - err = er - break readLoop - } - if p.sniff { - l, err := dispatch(buf[0:nr]) - if err == nil { - msgChan <- l - } - } - nw, ew := dst.Write(buf[0:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errInvalidWrite - } - } - written += int64(nw) - if ew != nil { - if ne, ok := ew.(net.Error); ok && ne.Timeout() { - break readLoop - } - if errors.Is(ew, net.ErrClosed) { - break readLoop - } - } - if nr != nw { - err = io.ErrShortWrite - break readLoop - } - } - if er != nil { - if ne, ok := er.(net.Error); ok && ne.Timeout() { - continue // support long-lived connections (SSE, WebSockets, etc) - } - if errors.Is(er, net.ErrClosed) { - break readLoop - } - if er == io.EOF { - break readLoop - } - err = er - break readLoop - } - } - } - return written, err -} - -func (p *proxyapp) transfer( - wg *sync.WaitGroup, - dst net.Conn, - src net.Conn, - destName, srcName string, - msgChan chan<- layers.Layer, -) { - defer func() { - wg.Done() - close(msgChan) - }() - n, err := p.copyWithTimeout(dst, src, msgChan) - if err != nil { - p.logger.Error().Err(err).Msgf("Error during copy from %s to %s: %v", srcName, destName, err) - } - if n > 0 { - p.logger.Debug().Msgf("copied %s from %s to %s", prettifyBytes(n), srcName, destName) - } - src.Close() -} - -func parseProxyAuth(auth string) (username, password string, ok bool) { - if auth == "" { - return "", "", false - } - const prefix = "Basic " - if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) { - return "", "", false - } - c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) - if err != nil { - return "", "", false + } else if p.tproxyAddr != "" { + sb.WriteString(p.tproxyAddr) + sb.WriteString(fmt.Sprintf(" (%s)", p.tproxyMode)) } - cs := string(c) - username, password, ok = strings.Cut(cs, ":") - if !ok { - return "", "", false + sb.WriteString(" → ") + for _, pe := range pc { + sb.WriteString(pe.String()) + sb.WriteString(" → ") } - return username, password, true -} - -func (p *proxyapp) proxyAuth(next http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Proxy-Authorization") - r.Header.Del("Proxy-Authorization") - username, password, ok := parseProxyAuth(auth) - if ok { - usernameHash := sha256.Sum256([]byte(username)) - passwordHash := sha256.Sum256([]byte(password)) - expectedUsernameHash := sha256.Sum256([]byte(p.user)) - expectedPasswordHash := sha256.Sum256([]byte(p.pass)) - - usernameMatch := (subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) == 1) - passwordMatch := (subtle.ConstantTimeCompare(passwordHash[:], expectedPasswordHash[:]) == 1) - - if usernameMatch && passwordMatch { - next.ServeHTTP(w, r) - return - } - } - w.Header().Set("Proxy-Authenticate", `Basic realm="restricted", charset="UTF-8"`) - http.Error(w, "Proxy Authentication Required", http.StatusProxyAuthRequired) - }) + sb.WriteString("target") + return sb.String() } -func (p *proxyapp) handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodConnect { - p.handleTunnel(w, r) - } else { - p.handleForward(w, r) +func (p *proxyapp) updateSocksList() { + p.mu.Lock() + defer p.mu.Unlock() + p.availProxyList = p.availProxyList[:0] + var base proxy.Dialer = getBaseDialer(timeout, p.mark) + var dialer proxy.Dialer + var err error + failed := 0 + chainType := p.proxychain.Type + ctl := colorizeChainType(chainType, p.nocolor) + for _, pr := range p.proxylist { + auth := proxy.Auth{ + User: pr.Username, + Password: pr.Password, } - } -} - -func (p *proxyapp) applyRedirectRules() string { - _, tproxyPort, _ := net.SplitHostPort(p.tproxyAddr) - switch p.tproxyMode { - case "redirect": - cmdClear := exec.Command("bash", "-c", ` - set -ex - iptables -t nat -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true - iptables -t nat -D OUTPUT -p tcp -j GOHPTS 2>/dev/null || true - iptables -t nat -F GOHPTS 2>/dev/null || true - iptables -t nat -X GOHPTS 2>/dev/null || true - `) - cmdClear.Stdout = os.Stdout - cmdClear.Stderr = os.Stderr - if err := cmdClear.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - cmdInit := exec.Command("bash", "-c", ` - set -ex - iptables -t nat -N GOHPTS 2>/dev/null - iptables -t nat -F GOHPTS - - iptables -t nat -A GOHPTS -d 127.0.0.0/8 -j RETURN - iptables -t nat -A GOHPTS -p tcp --dport 22 -j RETURN - `) - cmdInit.Stdout = os.Stdout - cmdInit.Stderr = os.Stderr - if err := cmdInit.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - if p.httpServerAddr != "" { - _, httpPort, _ := net.SplitHostPort(p.httpServerAddr) - cmdHTTP := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN - `, httpPort)) - cmdHTTP.Stdout = os.Stdout - cmdHTTP.Stderr = os.Stderr - if err := cmdHTTP.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } + dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, base) + if err != nil { + p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) + failed++ + continue } - if p.mark > 0 { - cmdMark := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t nat -A GOHPTS -p tcp -m mark --mark %d -j RETURN - `, p.mark)) - cmdMark.Stdout = os.Stdout - cmdMark.Stderr = os.Stderr - if err := cmdMark.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + ctx, cancel := context.WithTimeout(context.Background(), hopTimeout) + defer cancel() + conn, err := dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", pr.Address) + if err != nil && !errors.Is(err, io.EOF) { // check for EOF to include localhost SOCKS5 in the chain + p.logger.Error().Err(err).Msgf("%s Unable to connect to %s", ctl, pr.Address) + failed++ + if conn != nil { + conn.Close() } + continue } else { - cmd0 := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN - `, tproxyPort)) - cmd0.Stdout = os.Stdout - cmd0.Stderr = os.Stderr - if err := cmd0.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - if len(p.proxylist) > 0 { - for _, pr := range p.proxylist { - _, port, _ := net.SplitHostPort(pr.Address) - cmd1 := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN - `, port)) - cmd1.Stdout = os.Stdout - cmd1.Stderr = os.Stderr - if err := cmd1.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - if p.proxychain.Type == "strict" { - break - } - } + p.availProxyList = append(p.availProxyList, proxyEntry{Address: pr.Address, Username: pr.Username, Password: pr.Password}) + if conn != nil { + conn.Close() } + break } - cmdDocker := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - if command -v docker >/dev/null 2>&1 - then - for subnet in $(docker network inspect $(docker network ls -q) --format '{{range .IPAM.Config}}{{.Subnet}}{{end}}'); do - iptables -t nat -A GOHPTS -d "$subnet" -j RETURN - done - fi - - iptables -t nat -A GOHPTS -p tcp -j REDIRECT --to-ports %s - - iptables -t nat -C PREROUTING -p tcp -j GOHPTS 2>/dev/null || \ - iptables -t nat -A PREROUTING -p tcp -j GOHPTS - - iptables -t nat -C OUTPUT -p tcp -j GOHPTS 2>/dev/null || \ - iptables -t nat -A OUTPUT -p tcp -j GOHPTS - `, tproxyPort)) - cmdDocker.Stdout = os.Stdout - cmdDocker.Stderr = os.Stderr - if err := cmdDocker.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - case "tproxy": - cmdClear := exec.Command("bash", "-c", ` - set -ex - iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT 2>/dev/null || true - iptables -t mangle -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true - iptables -t mangle -F DIVERT 2>/dev/null || true - iptables -t mangle -F GOHPTS 2>/dev/null || true - iptables -t mangle -X DIVERT 2>/dev/null || true - iptables -t mangle -X GOHPTS 2>/dev/null || true - - ip rule del fwmark 1 lookup 100 2>/dev/null || true - ip route flush table 100 2>/dev/null || true - `) - cmdClear.Stdout = os.Stdout - cmdClear.Stderr = os.Stderr - if err := cmdClear.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - cmdInit0 := exec.Command("bash", "-c", ` - set -ex - ip rule add fwmark 1 lookup 100 2>/dev/null || true - ip route add local 0.0.0.0/0 dev lo table 100 2>/dev/null || true - - iptables -t mangle -N DIVERT 2>/dev/null || true - iptables -t mangle -F DIVERT - iptables -t mangle -A DIVERT -j MARK --set-mark 1 - iptables -t mangle -A DIVERT -j ACCEPT - - iptables -t mangle -N GOHPTS 2>/dev/null || true - iptables -t mangle -F GOHPTS - iptables -t mangle -A GOHPTS -d 127.0.0.0/8 -j RETURN - iptables -t mangle -A GOHPTS -d 224.0.0.0/4 -j RETURN - iptables -t mangle -A GOHPTS -d 255.255.255.255/32 -j RETURN - `) - cmdInit0.Stdout = os.Stdout - cmdInit0.Stderr = os.Stderr - if err := cmdInit0.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - cmdDocker := exec.Command("bash", "-c", ` - set -ex - if command -v docker >/dev/null 2>&1 - then - for subnet in $(docker network inspect $(docker network ls -q) --format '{{range .IPAM.Config}}{{.Subnet}}{{end}}'); do - iptables -t mangle -A GOHPTS -d "$subnet" -j RETURN - done - fi`) - cmdDocker.Stdout = os.Stdout - cmdDocker.Stderr = os.Stderr - if err := cmdDocker.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - cmdInit := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t mangle -A GOHPTS -p tcp -m mark --mark %d -j RETURN - iptables -t mangle -A GOHPTS -p tcp -j TPROXY --on-port %s --tproxy-mark 1 - - iptables -t mangle -A PREROUTING -p tcp -m socket -j DIVERT - iptables -t mangle -A PREROUTING -p tcp -j GOHPTS - `, p.mark, tproxyPort)) - cmdInit.Stdout = os.Stdout - cmdInit.Stderr = os.Stderr - if err := cmdInit.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - default: - p.logger.Fatal().Msgf("Unreachable, unknown mode: %s", p.tproxyMode) - } - cmdCat := exec.Command("bash", "-c", ` - cat /proc/sys/net/ipv4/ip_forward - `) - output, err := cmdCat.CombinedOutput() - if err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") } - cmdForward := exec.Command("bash", "-c", ` - set -ex - sysctl -w net.ipv4.ip_forward=1 - `) - cmdForward.Stdout = os.Stdout - cmdForward.Stderr = os.Stderr - _ = cmdForward.Run() - cmdClearForward := exec.Command("bash", "-c", ` - set -ex - iptables -t filter -F GOHPTS 2>/dev/null || true - iptables -t filter -D FORWARD -j GOHPTS 2>/dev/null || true - iptables -t filter -X GOHPTS 2>/dev/null || true - `) - cmdClearForward.Stdout = os.Stdout - cmdClearForward.Stderr = os.Stderr - if err := cmdClearForward.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + if failed == len(p.proxylist) { + p.logger.Error().Err(err).Msgf("%s No SOCKS5 Proxy available", ctl) + return } - var iface *net.Interface - if p.iface != nil { - iface = p.iface - } else { - iface, err = getDefaultInterface() + currentDialer := dialer + for _, pr := range p.proxylist[failed+1:] { + auth := proxy.Auth{ + User: pr.Username, + Password: pr.Password, + } + dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, currentDialer) + if err != nil { + p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) + continue + } + // https://github.com/golang/go/issues/37549#issuecomment-1178745487 + ctx, cancel := context.WithTimeout(context.Background(), hopTimeout) + defer cancel() + conn, err := dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", pr.Address) if err != nil { - p.logger.Fatal().Err(err).Msg("failed getting default network interface") + p.logger.Error().Err(err).Msgf("%s Unable to connect to %s", ctl, pr.Address) + if conn != nil { + conn.Close() + } + continue } + conn.Close() + currentDialer = dialer + p.availProxyList = append(p.availProxyList, proxyEntry{Address: pr.Address, Username: pr.Username, Password: pr.Password}) } - cmdForwardFilter := exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t filter -N GOHPTS 2>/dev/null - iptables -t filter -F GOHPTS - iptables -t filter -A FORWARD -j GOHPTS - iptables -t filter -A GOHPTS -i %s -j ACCEPT - iptables -t filter -A GOHPTS -o %s -j ACCEPT - `, iface.Name, iface.Name)) - cmdForwardFilter.Stdout = os.Stdout - cmdForwardFilter.Stderr = os.Stderr - if err := cmdForwardFilter.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - return string(output) + p.logger.Debug().Msgf("%s Available SOCKS5 Proxy [%d/%d]: %s", ctl, + len(p.availProxyList), len(p.proxylist), p.printProxyChain(p.availProxyList)) } -func (p *proxyapp) clearRedirectRules(output string) error { - cmdClear := exec.Command("bash", "-c", ` - set -ex - iptables -t filter -F GOHPTS 2>/dev/null || true - iptables -t filter -D FORWARD -j GOHPTS 2>/dev/null || true - iptables -t filter -X GOHPTS 2>/dev/null || true - `) - cmdClear.Stdout = os.Stdout - cmdClear.Stderr = os.Stderr - if err := cmdClear.Run(); err != nil { - p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") - } - var cmd *exec.Cmd - switch p.tproxyMode { - case "redirect": - cmd = exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t nat -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true - iptables -t nat -D OUTPUT -p tcp -j GOHPTS 2>/dev/null || true - iptables -t nat -F GOHPTS 2>/dev/null || true - iptables -t nat -X GOHPTS 2>/dev/null || true - sysctl -w net.ipv4.ip_forward=%s - `, output)) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - case "tproxy": - cmd = exec.Command("bash", "-c", fmt.Sprintf(` - set -ex - iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT 2>/dev/null || true - iptables -t mangle -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true - iptables -t mangle -F DIVERT 2>/dev/null || true - iptables -t mangle -F GOHPTS 2>/dev/null || true - iptables -t mangle -X DIVERT 2>/dev/null || true - iptables -t mangle -X GOHPTS 2>/dev/null || true - - ip rule del fwmark 1 lookup 100 2>/dev/null || true - ip route flush table 100 2>/dev/null || true - sysctl -w net.ipv4.ip_forward=%s - `, output)) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr +// https://www.calhoun.io/how-to-shuffle-arrays-and-slices-in-go/ +func shuffle(vals []proxyEntry) { + r := rand.New(rand.NewSource(time.Now().Unix())) + for len(vals) > 0 { + n := len(vals) + randIndex := r.Intn(n) + vals[n-1], vals[randIndex] = vals[randIndex], vals[n-1] + vals = vals[:n-1] } - return cmd.Run() } -func (p *proxyapp) Run() { - done := make(chan bool) - quit := make(chan os.Signal, 1) - p.closeConn = make(chan bool) - signal.Notify(quit, os.Interrupt) - if p.arpspoofer != nil { - go p.arpspoofer.Start() - } - var tproxyServer *tproxyServer - if p.tproxyAddr != "" { - tproxyServer = newTproxyServer(p) +func (p *proxyapp) getSocks() (proxy.Dialer, *http.Client, error) { + if p.proxylist == nil { + return p.sockDialer, p.sockClient, nil } - var output string - if p.auto { - output = p.applyRedirectRules() + p.mu.RLock() + defer p.mu.RUnlock() + chainType := p.proxychain.Type + ctl := colorizeChainType(chainType, p.nocolor) + if len(p.availProxyList) == 0 { + p.logger.Error().Msgf("%s No SOCKS5 Proxy available", ctl) + return nil, nil, fmt.Errorf("no socks5 proxy available") } - if p.proxylist != nil { - chainType := p.proxychain.Type - var ctl string - if p.nocolor { - ctl = colors.WrapBrackets(chainType) - } else { - ctl = colors.WrapBrackets(colors.LightBlueBg(chainType).String()) - } - go func() { - for { - p.logger.Debug().Msgf("%s Updating available proxy", ctl) - p.updateSocksList() - time.Sleep(availProxyUpdateInterval) - } - }() + var chainLength int + if p.proxychain.Length > len(p.availProxyList) || p.proxychain.Length <= 0 { + chainLength = len(p.availProxyList) + } else { + chainLength = p.proxychain.Length } - if p.httpServer != nil { - go func() { - <-quit - if p.arpspoofer != nil { - err := p.arpspoofer.Stop() - if err != nil { - p.logger.Error().Err(err).Msg("Failed stopping arp spoofer") - } - } - if p.auto { - err := p.clearRedirectRules(output) - if err != nil { - p.logger.Error().Err(err).Msg("Failed clearing iptables rules") - } - } - close(p.closeConn) - if tproxyServer != nil { - p.logger.Info().Msgf("[%s] Server is shutting down...", p.tproxyMode) - tproxyServer.Shutdown() + copyProxyList := make([]proxyEntry, 0, len(p.availProxyList)) + switch chainType { + case "strict", "dynamic": + copyProxyList = p.availProxyList + case "random": + copyProxyList = append(copyProxyList, p.availProxyList...) + shuffle(copyProxyList) + copyProxyList = copyProxyList[:chainLength] + case "round_robin": + var start uint32 + for { + start = atomic.LoadUint32(&p.rrIndex) + next := start + 1 + if start >= p.rrIndexReset { + p.logger.Debug().Msg("Resetting round robin index") + next = 0 } - p.logger.Info().Msg("Server is shutting down...") - ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) - - defer cancel() - p.httpServer.SetKeepAlivesEnabled(false) - if err := p.httpServer.Shutdown(ctx); err != nil { - p.logger.Fatal().Err(err).Msg("Could not gracefully shutdown the server") + if atomic.CompareAndSwapUint32(&p.rrIndex, start, next) { + break } - close(done) - }() - if tproxyServer != nil { - go tproxyServer.ListenAndServe() - } - if p.user != "" && p.pass != "" { - p.httpServer.Handler = p.proxyAuth(p.handler()) - } else { - p.httpServer.Handler = p.handler() } - if p.certFile != "" && p.keyFile != "" { - if err := p.httpServer.ListenAndServeTLS(p.certFile, p.keyFile); err != nil && err != http.ErrServerClosed { - p.logger.Fatal().Err(err).Msg("Unable to start HTTPS server") - } - } else { - if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - p.logger.Fatal().Err(err).Msg("Unable to start HTTP server") - } + startIdx := int(start % uint32(len(p.availProxyList))) + for i := range chainLength { + idx := (startIdx + i) % len(p.availProxyList) + copyProxyList = append(copyProxyList, p.availProxyList[idx]) } - p.logger.Info().Msg("Server stopped") - } else { - go func() { - <-quit - if p.arpspoofer != nil { - err := p.arpspoofer.Stop() - if err != nil { - p.logger.Error().Err(err).Msg("Failed stopping arp spoofer") - } - } - if p.auto { - err := p.clearRedirectRules(output) - if err != nil { - p.logger.Error().Err(err).Msg("Failed clearing iptables rules") - } - } - close(p.closeConn) - p.logger.Info().Msgf("[%s] Server is shutting down...", p.tproxyMode) - tproxyServer.Shutdown() - close(done) - }() - tproxyServer.ListenAndServe() + default: + p.logger.Fatal().Msg("Unreachable") } - <-done -} - -type logWriter struct { - file *os.File -} - -func (writer logWriter) Write(bytes []byte) (int, error) { - return fmt.Fprintf(writer.file, "%s ERR %s", time.Now().Format(time.RFC3339), string(bytes)) -} - -type jsonLogWriter struct { - file *os.File -} - -func (writer jsonLogWriter) Write(bytes []byte) (int, error) { - return fmt.Fprintf(writer.file, "{\"level\":\"error\",\"time\":\"%s\",\"message\":\"%s\"}\n", - time.Now().Format(time.RFC3339), strings.TrimRight(string(bytes), "\n")) -} - -type proxyEntry struct { - Address string `yaml:"address"` - Username string `yaml:"username,omitempty"` - Password string `yaml:"password,omitempty"` -} - -func (pe proxyEntry) String() string { - return pe.Address -} - -type server struct { - Address string `yaml:"address"` - Interface string `yaml:"interface,omitempty"` - Username string `yaml:"username,omitempty"` - Password string `yaml:"password,omitempty"` - CertFile string `yaml:"cert_file,omitempty"` - KeyFile string `yaml:"key_file,omitempty"` -} -type chain struct { - Type string `yaml:"type"` - Length int `yaml:"length"` -} - -type serverConfig struct { - Chain chain `yaml:"chain"` - ProxyList []proxyEntry `yaml:"proxy_list"` - Server server `yaml:"server"` -} - -func getFullAddress(v, ip string, all bool) (string, error) { - if v == "" { - return "", nil + if len(copyProxyList) == 0 { + p.logger.Error().Msgf("%s No SOCKS5 Proxy available", ctl) + return nil, nil, fmt.Errorf("no socks5 proxy available") } - ipAddr := "127.0.0.1" - if all { - ipAddr = "0.0.0.0" + if p.proxychain.Type == "strict" && len(copyProxyList) != len(p.proxylist) { + p.logger.Error().Msgf("%s Not all SOCKS5 Proxy available", ctl) + return nil, nil, fmt.Errorf("not all socks5 proxy available") } - if port, err := strconv.Atoi(v); err == nil { - if ip != "" { - return fmt.Sprintf("%s:%d", ip, port), nil - } else { - return fmt.Sprintf("%s:%d", ipAddr, port), nil + var dialer proxy.Dialer = getBaseDialer(timeout, p.mark) + var err error + for _, pr := range copyProxyList { + auth := proxy.Auth{ + User: pr.Username, + Password: pr.Password, + } + dialer, err = proxy.SOCKS5("tcp", pr.Address, &auth, dialer) + if err != nil { + p.logger.Error().Err(err).Msgf("%s Unable to create SOCKS5 dialer %s", ctl, pr.Address) + return nil, nil, err } } - host, port, err := net.SplitHostPort(v) - if err != nil { - return "", err - } - if port == "" { - return "", fmt.Errorf("port is missing") - } - if ip != "" { - return fmt.Sprintf("%s:%s", ip, port), nil - } else if host == "" { - return fmt.Sprintf("%s:%s", ipAddr, port), nil + socks := &http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } - return fmt.Sprintf("%s:%s", host, port), nil + p.logger.Debug().Msgf("%s Request chain: %s", ctl, p.printProxyChain(copyProxyList)) + return dialer, socks, nil } -func expandPath(p string) string { - p = os.ExpandEnv(p) - if strings.HasPrefix(p, "~") { - if home, err := os.UserHomeDir(); err == nil { - return strings.Replace(p, "~", home, 1) - } +func (p *proxyapp) doReq(w http.ResponseWriter, r *http.Request, sock *http.Client) *http.Response { + var ( + resp *http.Response + err error + msg string + client *http.Client + ) + if sock != nil { + client = sock + msg = "Connection to SOCKS5 server failed" + } else { + client = p.httpClient + msg = "Connection failed" + } + resp, err = client.Do(r) + if err != nil { + p.logger.Error().Err(err).Msg(msg) + w.WriteHeader(http.StatusServiceUnavailable) + return nil + } + if resp == nil { + p.logger.Error().Msg(msg) + w.WriteHeader(http.StatusServiceUnavailable) + return nil } - return p + return resp } -func getAddressFromInterface(iface *net.Interface) (string, error) { - if iface == nil { - return "", nil - } - prefix, err := network.GetIPv4PrefixFromInterface(iface) +func (p *proxyapp) transfer( + wg *sync.WaitGroup, + dst net.Conn, + src net.Conn, + destName, srcName string, + msgChan chan<- layers.Layer, +) { + defer func() { + wg.Done() + close(msgChan) + }() + n, err := p.copyWithTimeout(dst, src, msgChan) if err != nil { - return "", err + p.logger.Error().Err(err).Msgf("Error during copy from %s to %s: %v", srcName, destName, err) + } + if n > 0 { + p.logger.Debug().Msgf("copied %s from %s to %s", prettifyBytes(n), srcName, destName) } - return prefix.Addr().String(), nil + src.Close() } -func New(conf *Config) *proxyapp { - var logger, snifflogger zerolog.Logger - var p proxyapp - logfile := os.Stdout - var snifflog *os.File - var err error - p.sniff = conf.Sniff - p.body = conf.Body - p.json = conf.JSON - if conf.LogFilePath != "" { - f, err := os.OpenFile(conf.LogFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) - if err != nil { - log.Fatalf("Failed to open log file: %v", err) - } - logfile = f - } - if conf.SniffLogFile != "" && conf.SniffLogFile != conf.LogFilePath { - f, err := os.OpenFile(conf.SniffLogFile, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) - if err != nil { - log.Fatalf("Failed to open sniff log file: %v", err) - } - snifflog = f - } else { - snifflog = logfile - } - p.nocolor = conf.JSON || conf.NoColor - if conf.JSON { - log.SetFlags(0) - jsonWriter := jsonLogWriter{file: logfile} - log.SetOutput(jsonWriter) - logger = zerolog.New(logfile).With().Timestamp().Logger() - snifflogger = zerolog.New(snifflog).With().Timestamp().Logger() - } else { - log.SetFlags(0) - logWriter := logWriter{file: logfile} - log.SetOutput(logWriter) - output := zerolog.ConsoleWriter{Out: logfile, NoColor: p.nocolor} - - output.FormatTimestamp = func(i any) string { - ts, _ := time.Parse(time.RFC3339, i.(string)) - if p.nocolor { - return colors.WrapBrackets(ts.Format(time.TimeOnly)) - } - return colors.Gray(colors.WrapBrackets(ts.Format(time.TimeOnly))).String() +func (p *proxyapp) gatherSniffData(req, resp layers.Layer, sniffdata *[]string, id string) error { + switch reqt := req.(type) { + case *layers.HTTPMessage: + var reqBodySaved, respBodySaved []byte + rest := resp.(*layers.HTTPMessage) + if p.body { + reqBodySaved, _ = io.ReadAll(reqt.Request.Body) + respBodySaved, _ = io.ReadAll(rest.Response.Body) + reqBodySaved = bytes.Trim(reqBodySaved, "\r\n\t ") + respBodySaved = bytes.Trim(respBodySaved, "\r\n\t ") } - output.FormatMessage = func(i any) string { - if i == nil || i == "" { - return "" + if p.json { + j1, err := json.Marshal(reqt) + if err != nil { + return err } - s := i.(string) - if p.nocolor { - return s + j2, err := json.Marshal(rest) + if err != nil { + return err } - result := ipPortPattern.ReplaceAllStringFunc(s, func(match string) string { - return colors.Gray(match).String() - }) - result = domainPattern.ReplaceAllStringFunc(result, func(match string) string { - return colors.Yellow(match).String() - }) - result = macPattern.ReplaceAllStringFunc(result, func(match string) string { - return colors.Yellow(match).String() - }) - return result - } - - output.FormatErrFieldName = func(i any) string { - return fmt.Sprintf("%s", i) - } - - output.FormatErrFieldValue = func(i any) string { - s := i.(string) - if p.nocolor { - return s + *sniffdata = append(*sniffdata, string(j1), string(j2)) + if p.body && len(reqBodySaved) > 0 { + *sniffdata = append(*sniffdata, fmt.Sprintf("{\"req_body\":%s}", reqBodySaved)) } - result := ipPortPattern.ReplaceAllStringFunc(s, func(match string) string { - return colors.Red(match).String() - }) - result = domainPattern.ReplaceAllStringFunc(result, func(match string) string { - return colors.Red(match).String() - }) - result = strings.ReplaceAll(result, "->", "→ ") - return result - } - logger = zerolog.New(output).With().Timestamp().Logger() - sniffoutput := zerolog.ConsoleWriter{Out: snifflog, TimeFormat: time.RFC3339, NoColor: p.nocolor, PartsExclude: []string{"level"}} - sniffoutput.FormatTimestamp = func(i any) string { - ts, _ := time.Parse(time.RFC3339, i.(string)) - if p.nocolor { - return colors.WrapBrackets(ts.Format(time.TimeOnly)) + if p.body && len(respBodySaved) > 0 { + *sniffdata = append(*sniffdata, fmt.Sprintf("{\"resp_body\":%s}", respBodySaved)) } - return colors.Gray(colors.WrapBrackets(ts.Format(time.TimeOnly))).String() + } else { + *sniffdata = append(*sniffdata, colorizeHTTP(reqt.Request, rest.Response, &reqBodySaved, &respBodySaved, id, true, p.body, p.nocolor)) } - sniffoutput.FormatMessage = func(i any) string { - if i == nil || i == "" { - return "" + case *layers.TLSMessage: + var chs *layers.TLSClientHello + var shs *layers.TLSServerHello + hsrec := reqt.Records[0] // len(Records) > 0 after dispatch + if hsrec.ContentType == layers.HandshakeTLSVal { // TODO: add more cases, parse all records + switch parser := layers.HSTLSParserByType(hsrec.Data[0]).(type) { + case *layers.TLSClientHello: + err := parser.ParseHS(hsrec.Data) + if err != nil { + return err + } + chs = parser } - return fmt.Sprintf("%s", i) } - sniffoutput.FormatErrFieldName = func(i any) string { - return fmt.Sprintf("%s", i) - } - - sniffoutput.FormatErrFieldValue = func(i any) string { - s := i.(string) - if p.nocolor { - return s + rest := resp.(*layers.TLSMessage) + hsrec = rest.Records[0] + if hsrec.ContentType == layers.HandshakeTLSVal { + switch parser := layers.HSTLSParserByType(hsrec.Data[0]).(type) { + case *layers.TLSServerHello: + err := parser.ParseHS(hsrec.Data) + if err != nil { + return err + } + shs = parser } - result := ipPortPattern.ReplaceAllStringFunc(s, func(match string) string { - return colors.Red(match).String() - }) - result = domainPattern.ReplaceAllStringFunc(result, func(match string) string { - return colors.Red(match).String() - }) - result = strings.ReplaceAll(result, "->", "→ ") - return result - } - snifflogger = zerolog.New(sniffoutput).With().Timestamp().Logger() - } - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if conf.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - p.logger = &logger - p.snifflogger = &snifflogger - if runtime.GOOS == "linux" && conf.TProxy != "" && conf.TProxyOnly != "" { - p.logger.Fatal().Msg("Cannot specify TPRoxy and TProxyOnly at the same time") - } else if runtime.GOOS == "linux" && conf.TProxyMode != "" && !slices.Contains(SupportedTProxyModes, conf.TProxyMode) { - p.logger.Fatal().Msg("Incorrect TProxyMode provided") - } else if runtime.GOOS != "linux" && (conf.TProxy != "" || conf.TProxyOnly != "" || conf.TProxyMode != "") { - conf.TProxy = "" - conf.TProxyOnly = "" - conf.TProxyMode = "" - p.logger.Warn().Msgf("[%s] functionality only available on linux systems", conf.TProxyMode) - } - p.tproxyMode = conf.TProxyMode - tproxyonly := conf.TProxyOnly != "" - var tAddr string - if tproxyonly { - tAddr = conf.TProxyOnly - } else { - tAddr = conf.TProxy - } - if p.tproxyMode != "" { - p.tproxyAddr, err = getFullAddress(tAddr, "", true) - if err != nil { - p.logger.Fatal().Err(err).Msg("") - } - } else { - p.tproxyAddr, err = getFullAddress(tAddr, "", false) - if err != nil { - p.logger.Fatal().Err(err).Msg("") - } - } - p.auto = conf.Auto - if p.auto && runtime.GOOS != "linux" { - p.logger.Fatal().Msg("Auto setup is available only on linux systems") - } - p.mark = conf.Mark - if p.mark > 0 && runtime.GOOS != "linux" { - p.logger.Fatal().Msg("SO_MARK is available only on linux systems") - } - if p.mark > 0xFFFFFFFF { - p.logger.Fatal().Msg("SO_MARK is out of range") - } - if p.mark == 0 && p.tproxyMode == "tproxy" { - p.mark = 100 - } - var addrHTTP, addrSOCKS, certFile, keyFile string - if conf.ServerConfPath != "" { - var sconf serverConfig - yamlFile, err := os.ReadFile(expandPath(conf.ServerConfPath)) - if err != nil { - p.logger.Fatal().Err(err).Msg("[yaml config] Parsing failed") } - err = yaml.Unmarshal(yamlFile, &sconf) - if err != nil { - p.logger.Fatal().Err(err).Msg("[yaml config] Parsing failed") - } - if !tproxyonly { - if sconf.Server.Address == "" { - p.logger.Fatal().Err(err).Msg("[yaml config] Server address is empty") - } - if sconf.Server.Interface != "" && sconf.Server.Interface != "any" && conf.Interface != "0" { - p.iface, err = net.InterfaceByName(sconf.Server.Interface) + if chs != nil && shs != nil { + if p.json { + j1, err := json.Marshal(chs) if err != nil { - if ifIdx, err := strconv.Atoi(sconf.Server.Interface); err == nil { - p.iface, err = net.InterfaceByIndex(ifIdx) - if err != nil { - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) - } - } else { - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) - } + return err } + j2, err := json.Marshal(shs) + if err != nil { + return err + } + *sniffdata = append(*sniffdata, string(j1), string(j2)) + } else { + *sniffdata = append(*sniffdata, colorizeTLS(chs, shs, id, p.nocolor)) } - iAddr, err := getAddressFromInterface(p.iface) - if err != nil { - p.iface = nil - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", sconf.Server.Interface) - } - addrHTTP, err = getFullAddress(sconf.Server.Address, iAddr, false) - if err != nil { - p.logger.Fatal().Err(err).Msg("") - } - p.httpServerAddr = addrHTTP - certFile = expandPath(sconf.Server.CertFile) - keyFile = expandPath(sconf.Server.KeyFile) - p.user = sconf.Server.Username - p.pass = sconf.Server.Password - } - p.proxychain = sconf.Chain - p.proxylist = sconf.ProxyList - p.availProxyList = make([]proxyEntry, 0, len(p.proxylist)) - if len(p.proxylist) == 0 { - p.logger.Fatal().Msg("[yaml config] Proxy list is empty") } - seen := make(map[string]struct{}) - for idx, pr := range p.proxylist { - addr, err := getFullAddress(pr.Address, "", false) - if err != nil { - p.logger.Fatal().Err(err).Msg("") - } - if _, ok := seen[addr]; !ok { - seen[addr] = struct{}{} - p.proxylist[idx].Address = addr + } + return nil +} + +func (p *proxyapp) sniffreporter(wg *sync.WaitGroup, sniffdata *[]string, reqChan, respChan <-chan layers.Layer, id string) { + defer wg.Done() + sniffdatalen := len(*sniffdata) + var reqTLSQueue, respTLSQueue, reqHTTPQueue, respHTTPQueue []layers.Layer + for { + select { + case req, ok := <-reqChan: + if !ok { + return } else { - p.logger.Fatal().Msgf("[yaml config] Duplicate entry `%s`", addr) + switch req.(type) { + case *layers.TLSMessage: + reqTLSQueue = append(reqTLSQueue, req) + case *layers.HTTPMessage: + reqHTTPQueue = append(reqHTTPQueue, req) + } } - } - addrSOCKS = p.printProxyChain(p.proxylist) - chainType := p.proxychain.Type - if !slices.Contains(supportedChainTypes, chainType) { - p.logger.Fatal().Msgf("[yaml config] Chain type `%s` is not supported", chainType) - } - p.rrIndexReset = rrIndexMax - } else { - if !tproxyonly { - if conf.Interface != "" && conf.Interface != "any" && conf.Interface != "0" { - p.iface, err = net.InterfaceByName(conf.Interface) - if err != nil { - if ifIdx, err := strconv.Atoi(conf.Interface); err == nil { - p.iface, err = net.InterfaceByIndex(ifIdx) - if err != nil { - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", conf.Interface) - } - } else { - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", conf.Interface) + case resp, ok := <-respChan: + if !ok { + return + } else { + switch resp.(type) { + case *layers.TLSMessage: + // request comes first or response arrived first + if len(reqTLSQueue) > 0 || len(respTLSQueue) == 0 { + respTLSQueue = append(respTLSQueue, resp) + // remove unmatched response if still no requests + } else if len(reqTLSQueue) == 0 && len(respTLSQueue) == 1 { + respTLSQueue = respTLSQueue[1:] + } + case *layers.HTTPMessage: + if len(reqHTTPQueue) > 0 || len(respHTTPQueue) == 0 { + respHTTPQueue = append(respHTTPQueue, resp) + } else if len(reqHTTPQueue) == 0 && len(respHTTPQueue) == 1 { + respHTTPQueue = respHTTPQueue[1:] } } } - iAddr, err := getAddressFromInterface(p.iface) - if err != nil { - p.logger.Warn().Err(err).Msgf("Failed binding to %s, using default interface", conf.Interface) - p.iface = nil - } - addrHTTP, err = getFullAddress(conf.AddrHTTP, iAddr, false) - if err != nil { - p.logger.Fatal().Err(err).Msg("") - } - p.httpServerAddr = addrHTTP - certFile = expandPath(conf.CertFile) - keyFile = expandPath(conf.KeyFile) - p.user = conf.ServerUser - p.pass = conf.ServerPass - } - addrSOCKS, err = getFullAddress(conf.AddrSOCKS, "", false) - if err != nil { - p.logger.Fatal().Err(err).Msg("") } - auth := proxy.Auth{ - User: conf.User, - Password: conf.Pass, - } - dialer, err := proxy.SOCKS5("tcp", addrSOCKS, &auth, getBaseDialer(timeout, p.mark)) - if err != nil { - p.logger.Fatal().Err(err).Msg("Unable to create SOCKS5 dialer") + if len(reqHTTPQueue) > 0 && len(respHTTPQueue) > 0 { + req := reqHTTPQueue[0] + resp := respHTTPQueue[0] + reqHTTPQueue = reqHTTPQueue[1:] + respHTTPQueue = respHTTPQueue[1:] + + err := p.gatherSniffData(req, resp, sniffdata, id) + if err == nil && len(*sniffdata) > sniffdatalen { + if p.json { + p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(*sniffdata, ","))) + } else { + p.snifflogger.Log().Msg(strings.Join(*sniffdata, "\n")) + } + } + *sniffdata = (*sniffdata)[:sniffdatalen] } - p.sockDialer = dialer - if !tproxyonly { - p.sockClient = &http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, + if len(reqTLSQueue) > 0 && len(respTLSQueue) > 0 { + req := reqTLSQueue[0] + resp := respTLSQueue[0] + reqTLSQueue = reqTLSQueue[1:] + respTLSQueue = respTLSQueue[1:] + + err := p.gatherSniffData(req, resp, sniffdata, id) + if err == nil && len(*sniffdata) > sniffdatalen { + if p.json { + p.snifflogger.Log().Msg(fmt.Sprintf("[%s]", strings.Join(*sniffdata, ","))) + } else { + p.snifflogger.Log().Msg(strings.Join(*sniffdata, "\n")) + } } + *sniffdata = (*sniffdata)[:sniffdatalen] } } - if !tproxyonly { - hs := &http.Server{ - Addr: addrHTTP, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - MaxHeaderBytes: 1 << 20, - Protocols: new(http.Protocols), - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - }, - }, - } - hs.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - hs.Protocols.SetHTTP1(true) - p.httpServer = hs - p.httpClient = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - DialContext: getBaseDialer(timeout, p.mark).DialContext, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: timeout, - } +} + +func dispatch(data []byte) (layers.Layer, error) { + // TODO: check if it is http or tls beforehand + h := &layers.HTTPMessage{} + if err := h.Parse(data); err == nil && !h.IsEmpty() { + return h, nil } - if conf.ARPSpoof != "" { - if runtime.GOOS != "linux" { - p.logger.Fatal().Msg("ARP spoof setup is available only on linux systems") - } - if !p.auto { - p.logger.Warn().Msg("ARP spoof setup requires iptables configuration") - } - asc := &arpspoof.ARPSpoofConfig{Logger: p.logger} - errMsg := `Failed parsing arp options. Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true"` - for opt := range strings.SplitSeq(strings.ToLower(conf.ARPSpoof), ";") { - keyval := strings.SplitN(strings.Trim(opt, " "), " ", 2) - if len(keyval) < 2 { - p.logger.Fatal().Msg(errMsg) + m := &layers.TLSMessage{} + if err := m.Parse(data); err == nil && len(m.Records) > 0 { + return m, nil + } + return nil, fmt.Errorf("failed sniffing traffic") +} + +func (p *proxyapp) copyWithTimeout(dst net.Conn, src net.Conn, msgChan chan<- layers.Layer) (written int64, err error) { + buf := make([]byte, 32*1024) +readLoop: + for { + select { + case <-p.closeConn: + break readLoop + default: + er := src.SetReadDeadline(time.Now().Add(readTimeout)) + if er != nil { + if errors.Is(er, net.ErrClosed) { + break readLoop + } + err = er + break readLoop } - key := keyval[0] - val := keyval[1] - switch key { - case "targets": - asc.Targets = val - case "fullduplex": - if val == "true" { - asc.FullDuplex = true + nr, er := src.Read(buf) + if nr > 0 { + er := dst.SetWriteDeadline(time.Now().Add(writeTimeout)) + if er != nil { + if errors.Is(er, net.ErrClosed) { + break readLoop + } + err = er + break readLoop } - case "debug": - if val == "true" { - asc.Debug = true + if p.sniff { + l, err := dispatch(buf[0:nr]) + if err == nil { + msgChan <- l + } + } + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errInvalidWrite + } + } + written += int64(nw) + if ew != nil { + if ne, ok := ew.(net.Error); ok && ne.Timeout() { + break readLoop + } + if errors.Is(ew, net.ErrClosed) { + break readLoop + } + } + if nr != nw { + err = io.ErrShortWrite + break readLoop } - default: - p.logger.Fatal().Msg(errMsg) } - } - if p.iface != nil { - asc.Interface = p.iface.Name - } - p.arpspoofer, err = arpspoof.NewARPSpoofer(asc) - if err != nil { - p.logger.Fatal().Err(err).Msg("Failed creating arp spoofer") - } - } - if conf.ServerConfPath != "" { - p.logger.Info().Msgf("SOCKS5 Proxy [%s] chain: %s", p.proxychain.Type, addrSOCKS) - } else { - p.logger.Info().Msgf("SOCKS5 Proxy: %s", addrSOCKS) - } - if !tproxyonly { - if certFile != "" && keyFile != "" { - p.certFile = certFile - p.keyFile = keyFile - p.logger.Info().Msgf("HTTPS Proxy: %s", p.httpServerAddr) - } else { - p.logger.Info().Msgf("HTTP Proxy: %s", p.httpServerAddr) + if er != nil { + if ne, ok := er.(net.Error); ok && ne.Timeout() { + continue // support long-lived connections (SSE, WebSockets, etc) + } + if errors.Is(er, net.ErrClosed) { + break readLoop + } + if er == io.EOF { + break readLoop + } + err = er + break readLoop + } } } - if p.tproxyAddr != "" { - if p.tproxyMode == "tproxy" { - p.logger.Info().Msgf("TPROXY: %s", p.tproxyAddr) - } else { - p.logger.Info().Msgf("REDIRECT: %s", p.tproxyAddr) + return written, err +} + +func (p *proxyapp) proxyAuth(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Proxy-Authorization") + r.Header.Del("Proxy-Authorization") + username, password, ok := parseProxyAuth(auth) + if ok { + usernameHash := sha256.Sum256([]byte(username)) + passwordHash := sha256.Sum256([]byte(password)) + expectedUsernameHash := sha256.Sum256([]byte(p.user)) + expectedPasswordHash := sha256.Sum256([]byte(p.pass)) + + usernameMatch := (subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) == 1) + passwordMatch := (subtle.ConstantTimeCompare(passwordHash[:], expectedPasswordHash[:]) == 1) + + if usernameMatch && passwordMatch { + next.ServeHTTP(w, r) + return + } } - } - return &p + w.Header().Set("Proxy-Authenticate", `Basic realm="restricted", charset="UTF-8"`) + http.Error(w, "Proxy Authentication Required", http.StatusProxyAuthRequired) + }) } diff --git a/helpers.go b/helpers.go new file mode 100644 index 0000000..6fbcaae --- /dev/null +++ b/helpers.go @@ -0,0 +1,131 @@ +package gohpts + +import ( + "encoding/base64" + "fmt" + "net" + "net/http" + "os" + "strconv" + "strings" + + "github.com/shadowy-pycoder/mshark/network" +) + +// Hop-by-hop headers +// https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1 +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "TE", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func delHopHeaders(header http.Header) { + for _, h := range hopHeaders { + header.Del(h) + } +} + +// delConnectionHeaders removes hop-by-hop headers listed in the "Connection" header +// https://datatracker.ietf.org/doc/html/rfc7230#section-6.1 +func delConnectionHeaders(h http.Header) { + for _, f := range h["Connection"] { + for sf := range strings.SplitSeq(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + h.Del(sf) + } + } + } +} + +func appendHostToXForwardHeader(header http.Header, host string) { + if prior, ok := header["X-Forwarded-For"]; ok { + host = strings.Join(prior, ", ") + ", " + host + } + header.Set("X-Forwarded-For", host) +} + +func getFullAddress(v, ip string, all bool) (string, error) { + if v == "" { + return "", nil + } + ipAddr := "127.0.0.1" + if all { + ipAddr = "0.0.0.0" + } + if port, err := strconv.Atoi(v); err == nil { + if ip != "" { + return fmt.Sprintf("%s:%d", ip, port), nil + } else { + return fmt.Sprintf("%s:%d", ipAddr, port), nil + } + } + host, port, err := net.SplitHostPort(v) + if err != nil { + return "", err + } + if port == "" { + return "", fmt.Errorf("port is missing") + } + if ip != "" { + return fmt.Sprintf("%s:%s", ip, port), nil + } else if host == "" { + return fmt.Sprintf("%s:%s", ipAddr, port), nil + } + return fmt.Sprintf("%s:%s", host, port), nil +} + +func expandPath(p string) string { + p = os.ExpandEnv(p) + if strings.HasPrefix(p, "~") { + if home, err := os.UserHomeDir(); err == nil { + return strings.Replace(p, "~", home, 1) + } + } + return p +} + +func getAddressFromInterface(iface *net.Interface) (string, error) { + if iface == nil { + return "", nil + } + prefix, err := network.GetIPv4PrefixFromInterface(iface) + if err != nil { + return "", err + } + return prefix.Addr().String(), nil +} + +func parseProxyAuth(auth string) (username, password string, ok bool) { + if auth == "" { + return "", "", false + } + const prefix = "Basic " + if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) { + return "", "", false + } + c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + if err != nil { + return "", "", false + } + cs := string(c) + username, password, ok = strings.Cut(cs, ":") + if !ok { + return "", "", false + } + return username, password, true +} diff --git a/tproxy_linux.go b/tproxy_linux.go index 8aa92cc..17ee869 100644 --- a/tproxy_linux.go +++ b/tproxy_linux.go @@ -7,15 +7,18 @@ import ( "context" "errors" "fmt" + "maps" "net" "net/netip" + "os" + "os/exec" + "slices" "strings" "sync" "syscall" "time" "unsafe" - "github.com/shadowy-pycoder/colors" "github.com/shadowy-pycoder/mshark/layers" "github.com/shadowy-pycoder/mshark/network" "golang.org/x/net/proxy" @@ -26,13 +29,13 @@ type tproxyServer struct { listener net.Listener quit chan struct{} wg sync.WaitGroup - pa *proxyapp + p *proxyapp } -func newTproxyServer(pa *proxyapp) *tproxyServer { +func newTproxyServer(p *proxyapp) *tproxyServer { ts := &tproxyServer{ quit: make(chan struct{}), - pa: pa, + p: p, } // https://iximiuz.com/en/posts/go-net-http-setsockopt-example/ lc := net.ListenConfig{ @@ -41,7 +44,7 @@ func newTproxyServer(pa *proxyapp) *tproxyServer { if err := conn.Control(func(fd uintptr) { operr = unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int(timeout.Milliseconds())) operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) - if ts.pa.tproxyMode == "tproxy" { + if ts.p.tproxyMode == "tproxy" { operr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1) } }); err != nil { @@ -51,13 +54,13 @@ func newTproxyServer(pa *proxyapp) *tproxyServer { }, } - ln, err := lc.Listen(context.Background(), "tcp4", ts.pa.tproxyAddr) + ln, err := lc.Listen(context.Background(), "tcp4", ts.p.tproxyAddr) if err != nil { var msg string if errors.Is(err, unix.EPERM) { msg = "try `sudo setcap 'cap_net_admin+ep` for the binary or run with sudo:" } - ts.pa.logger.Fatal().Err(err).Msg(msg) + ts.p.logger.Fatal().Err(err).Msg(msg) } ts.listener = ln return ts @@ -78,13 +81,13 @@ func (ts *tproxyServer) serve() { case <-ts.quit: return default: - ts.pa.logger.Error().Err(err).Msg("Failed accepting connection") + ts.p.logger.Error().Err(err).Msg("Failed accepting connection") } } else { ts.wg.Add(1) err := conn.SetDeadline(time.Now().Add(timeout)) if err != nil { - ts.pa.logger.Error().Err(err).Msg("") + ts.p.logger.Error().Err(err).Msg("") } go func() { ts.handleConnection(conn) @@ -116,11 +119,11 @@ func (ts *tproxyServer) getOriginalDst(rawConn syscall.RawConn) (string, error) optlen := uint32(unsafe.Sizeof(originalDst)) err := getsockopt(int(fd), unix.SOL_IP, unix.SO_ORIGINAL_DST, unsafe.Pointer(&originalDst), &optlen) if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] getsockopt SO_ORIGINAL_DST failed", ts.pa.tproxyMode) + ts.p.logger.Error().Err(err).Msgf("[%s] getsockopt SO_ORIGINAL_DST failed", ts.p.tproxyMode) } }) if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed invoking control connection", ts.pa.tproxyMode) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed invoking control connection", ts.p.tproxyMode) return "", err } dstHost := netip.AddrFrom4(originalDst.Addr) @@ -135,42 +138,42 @@ func (ts *tproxyServer) handleConnection(srcConn net.Conn) { err error ) defer srcConn.Close() - switch ts.pa.tproxyMode { + switch ts.p.tproxyMode { case "redirect": rawConn, err := srcConn.(*net.TCPConn).SyscallConn() if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed to get raw connection", ts.pa.tproxyMode) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed to get raw connection", ts.p.tproxyMode) return } dst, err = ts.getOriginalDst(rawConn) if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed to get destination address", ts.pa.tproxyMode) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed to get destination address", ts.p.tproxyMode) return } - ts.pa.logger.Debug().Msgf("[%s] getsockopt SO_ORIGINAL_DST %s", ts.pa.tproxyMode, dst) + ts.p.logger.Debug().Msgf("[%s] getsockopt SO_ORIGINAL_DST %s", ts.p.tproxyMode, dst) case "tproxy": dst = srcConn.LocalAddr().String() - ts.pa.logger.Debug().Msgf("[%s] IP_TRANSPARENT %s", ts.pa.tproxyMode, dst) + ts.p.logger.Debug().Msgf("[%s] IP_TRANSPARENT %s", ts.p.tproxyMode, dst) default: - ts.pa.logger.Fatal().Msg("Unknown tproxyMode") + ts.p.logger.Fatal().Msg("Unknown tproxyMode") } - if isLocalAddress(dst) { - dstConn, err = getBaseDialer(timeout, ts.pa.mark).Dial("tcp", dst) + if network.IsLocalAddress(dst) { + dstConn, err = getBaseDialer(timeout, ts.p.mark).Dial("tcp", dst) if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed connecting to %s", ts.pa.tproxyMode, dst) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed connecting to %s", ts.p.tproxyMode, dst) return } } else { - sockDialer, _, err := ts.pa.getSocks() + sockDialer, _, err := ts.p.getSocks() if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed getting SOCKS5 client", ts.pa.tproxyMode) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed getting SOCKS5 client", ts.p.tproxyMode) return } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() dstConn, err = sockDialer.(proxy.ContextDialer).DialContext(ctx, "tcp", dst) if err != nil { - ts.pa.logger.Error().Err(err).Msgf("[%s] Failed connecting to %s", ts.pa.tproxyMode, dst) + ts.p.logger.Error().Err(err).Msgf("[%s] Failed connecting to %s", ts.p.tproxyMode, dst) return } } @@ -179,24 +182,24 @@ func (ts *tproxyServer) handleConnection(srcConn net.Conn) { dstConnStr := fmt.Sprintf("%s→ %s→ %s", dstConn.LocalAddr().String(), dstConn.RemoteAddr().String(), dst) srcConnStr := fmt.Sprintf("%s→ %s", srcConn.RemoteAddr().String(), srcConn.LocalAddr().String()) - ts.pa.logger.Debug().Msgf("[%s] src: %s - dst: %s", ts.pa.tproxyMode, srcConnStr, dstConnStr) + ts.p.logger.Debug().Msgf("[%s] src: %s - dst: %s", ts.p.tproxyMode, srcConnStr, dstConnStr) reqChan := make(chan layers.Layer) respChan := make(chan layers.Layer) var wg sync.WaitGroup wg.Add(2) - go ts.pa.transfer(&wg, dstConn, srcConn, dstConnStr, srcConnStr, reqChan) - go ts.pa.transfer(&wg, srcConn, dstConn, srcConnStr, dstConnStr, respChan) - if ts.pa.sniff { + go ts.p.transfer(&wg, dstConn, srcConn, dstConnStr, srcConnStr, reqChan) + go ts.p.transfer(&wg, srcConn, dstConn, srcConnStr, dstConnStr, respChan) + if ts.p.sniff { wg.Add(1) sniffheader := make([]string, 0, 6) - id := ts.pa.getID() - if ts.pa.json { + id := getID(ts.p.nocolor) + if ts.p.json { sniffheader = append( sniffheader, fmt.Sprintf( "{\"connection\":{\"tproxy_mode\":%s,\"src_remote\":%s,\"src_local\":%s,\"dst_local\":%s,\"dst_remote\":%s,\"original_dst\":%s}}", - ts.pa.tproxyMode, + ts.p.tproxyMode, srcConn.RemoteAddr(), srcConn.LocalAddr(), dstConn.LocalAddr(), @@ -205,20 +208,15 @@ func (ts *tproxyServer) handleConnection(srcConn net.Conn) { ), ) } else { - var sb strings.Builder - if ts.pa.nocolor { - sb.WriteString(id) - sb.WriteString(fmt.Sprintf(" Src: %s→ %s → Dst: %s→ %s Orig: %s", srcConn.RemoteAddr(), srcConn.LocalAddr(), dstConn.LocalAddr(), dstConn.RemoteAddr(), dst)) - } else { - sb.WriteString(id) - sb.WriteString(colors.Green(fmt.Sprintf(" Src: %s→ %s", srcConn.RemoteAddr(), srcConn.LocalAddr())).String()) - sb.WriteString(colors.Magenta(" → ").String()) - sb.WriteString(colors.Blue(fmt.Sprintf("Dst: %s→ %s ", dstConn.LocalAddr(), dstConn.RemoteAddr())).String()) - sb.WriteString(colors.BeigeBg(fmt.Sprintf("Orig Dst: %s", dst)).String()) - } - sniffheader = append(sniffheader, sb.String()) + connections := colorizeConnectionsTransparent( + srcConn.RemoteAddr(), + srcConn.LocalAddr(), + dstConn.RemoteAddr(), + dstConn.LocalAddr(), + dst, id, ts.p.nocolor) + sniffheader = append(sniffheader, connections) } - go ts.pa.sniffreporter(&wg, &sniffheader, reqChan, respChan, id) + go ts.p.sniffreporter(&wg, &sniffheader, reqChan, respChan, id) } wg.Wait() } @@ -234,10 +232,10 @@ func (ts *tproxyServer) Shutdown() { select { case <-done: - ts.pa.logger.Info().Msgf("[%s] Server gracefully shutdown", ts.pa.tproxyMode) + ts.p.logger.Info().Msgf("[%s] Server gracefully shutdown", ts.p.tproxyMode) return case <-time.After(shutdownTimeout): - ts.pa.logger.Error().Msgf("[%s] Server timed out waiting for connections to finish", ts.pa.tproxyMode) + ts.p.logger.Error().Msgf("[%s] Server timed out waiting for connections to finish", ts.p.tproxyMode) return } } @@ -259,6 +257,313 @@ func getBaseDialer(timeout time.Duration, mark uint) *net.Dialer { return dialer } -func getDefaultInterface() (*net.Interface, error) { - return network.GetDefaultInterface() +func (ts *tproxyServer) createSysctlOptCmd(opt, value, setex string, opts map[string]string) *exec.Cmd { + cmdCat := exec.Command("bash", "-c", fmt.Sprintf(` + cat /proc/sys/%s + `, strings.ReplaceAll(opt, ".", "/"))) + output, err := cmdCat.CombinedOutput() + if err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + opts[opt] = string(output) + cmd := exec.Command("bash", "-c", fmt.Sprintf(` + %s + sysctl -w %s=%s + `, setex, opt, value)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if !ts.p.debug { + cmd.Stdout = nil + } + return cmd +} + +func (ts *tproxyServer) applyRedirectRules() map[string]string { + _, tproxyPort, _ := net.SplitHostPort(ts.p.tproxyAddr) + var setex string + if ts.p.debug { + setex = "set -ex" + } + ipv4Settings := make(map[string]string, 5) + switch ts.p.tproxyMode { + case "redirect": + cmdClear := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true + iptables -t nat -D OUTPUT -p tcp -j GOHPTS 2>/dev/null || true + iptables -t nat -F GOHPTS 2>/dev/null || true + iptables -t nat -X GOHPTS 2>/dev/null || true + `, setex)) + cmdClear.Stdout = os.Stdout + cmdClear.Stderr = os.Stderr + if err := cmdClear.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + cmdInit := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -N GOHPTS 2>/dev/null + iptables -t nat -F GOHPTS + + iptables -t nat -A GOHPTS -d 127.0.0.0/8 -j RETURN + iptables -t nat -A GOHPTS -p tcp --dport 22 -j RETURN + `, setex)) + cmdInit.Stdout = os.Stdout + cmdInit.Stderr = os.Stderr + if err := cmdInit.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + if ts.p.httpServerAddr != "" { + _, httpPort, _ := net.SplitHostPort(ts.p.httpServerAddr) + cmdHTTP := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN + `, setex, httpPort)) + cmdHTTP.Stdout = os.Stdout + cmdHTTP.Stderr = os.Stderr + if err := cmdHTTP.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + } + if ts.p.mark > 0 { + cmdMark := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -A GOHPTS -p tcp -m mark --mark %d -j RETURN + `, setex, ts.p.mark)) + cmdMark.Stdout = os.Stdout + cmdMark.Stderr = os.Stderr + if err := cmdMark.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + } else { + cmd0 := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN + `, setex, tproxyPort)) + cmd0.Stdout = os.Stdout + cmd0.Stderr = os.Stderr + if err := cmd0.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + if len(ts.p.proxylist) > 0 { + for _, pr := range ts.p.proxylist { + _, port, _ := net.SplitHostPort(pr.Address) + cmd1 := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -A GOHPTS -p tcp --dport %s -j RETURN + `, setex, port)) + cmd1.Stdout = os.Stdout + cmd1.Stderr = os.Stderr + if err := cmd1.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + if ts.p.proxychain.Type == "strict" { + break + } + } + } + } + cmdDocker := exec.Command("bash", "-c", fmt.Sprintf(` + %s + if command -v docker >/dev/null 2>&1 + then + for subnet in $(docker network inspect $(docker network ls -q) --format '{{range .IPAM.Config}}{{.Subnet}}{{end}}'); do + iptables -t nat -A GOHPTS -d "$subnet" -j RETURN + done + fi + + iptables -t nat -A GOHPTS -p tcp -j REDIRECT --to-ports %s + + iptables -t nat -C PREROUTING -p tcp -j GOHPTS 2>/dev/null || \ + iptables -t nat -A PREROUTING -p tcp -j GOHPTS + + iptables -t nat -C OUTPUT -p tcp -j GOHPTS 2>/dev/null || \ + iptables -t nat -A OUTPUT -p tcp -j GOHPTS + `, setex, tproxyPort)) + cmdDocker.Stdout = os.Stdout + cmdDocker.Stderr = os.Stderr + if err := cmdDocker.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + case "tproxy": + cmdClear := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT 2>/dev/null || true + iptables -t mangle -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true + iptables -t mangle -F DIVERT 2>/dev/null || true + iptables -t mangle -F GOHPTS 2>/dev/null || true + iptables -t mangle -X DIVERT 2>/dev/null || true + iptables -t mangle -X GOHPTS 2>/dev/null || true + + ip rule del fwmark 1 lookup 100 2>/dev/null || true + ip route flush table 100 2>/dev/null || true + `, setex)) + cmdClear.Stdout = os.Stdout + cmdClear.Stderr = os.Stderr + if err := cmdClear.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + cmdInit0 := exec.Command("bash", "-c", fmt.Sprintf(` + %s + ip rule add fwmark 1 lookup 100 2>/dev/null || true + ip route add local 0.0.0.0/0 dev lo table 100 2>/dev/null || true + + iptables -t mangle -N DIVERT 2>/dev/null || true + iptables -t mangle -F DIVERT + iptables -t mangle -A DIVERT -j MARK --set-mark 1 + iptables -t mangle -A DIVERT -j ACCEPT + + iptables -t mangle -N GOHPTS 2>/dev/null || true + iptables -t mangle -F GOHPTS + iptables -t mangle -A GOHPTS -d 127.0.0.0/8 -j RETURN + iptables -t mangle -A GOHPTS -d 224.0.0.0/4 -j RETURN + iptables -t mangle -A GOHPTS -d 255.255.255.255/32 -j RETURN + `, setex)) + cmdInit0.Stdout = os.Stdout + cmdInit0.Stderr = os.Stderr + if err := cmdInit0.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + cmdDocker := exec.Command("bash", "-c", fmt.Sprintf(` + %s + if command -v docker >/dev/null 2>&1 + then + for subnet in $(docker network inspect $(docker network ls -q) --format '{{range .IPAM.Config}}{{.Subnet}}{{end}}'); do + iptables -t mangle -A GOHPTS -d "$subnet" -j RETURN + done + fi`, setex)) + cmdDocker.Stdout = os.Stdout + cmdDocker.Stderr = os.Stderr + if err := cmdDocker.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + cmdInit := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t mangle -A GOHPTS -p tcp -m mark --mark %d -j RETURN + iptables -t mangle -A GOHPTS -p tcp -j TPROXY --on-port %s --tproxy-mark 1 + + iptables -t mangle -A PREROUTING -p tcp -m socket -j DIVERT + iptables -t mangle -A PREROUTING -p tcp -j GOHPTS + `, setex, ts.p.mark, tproxyPort)) + cmdInit.Stdout = os.Stdout + cmdInit.Stderr = os.Stderr + if err := cmdInit.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + default: + ts.p.logger.Fatal().Msgf("Unreachable, unknown mode: %s", ts.p.tproxyMode) + } + _ = ts.createSysctlOptCmd("net.ipv4.ip_forward", "1", setex, ipv4Settings).Run() + cmdCheckBBR := exec.Command("bash", "-c", fmt.Sprintf(` + %s + lsmod | grep -q '^tcp_bbr' || modprobe tcp_bbr + `, setex)) + cmdCheckBBR.Stdout = os.Stdout + cmdCheckBBR.Stderr = os.Stderr + if !ts.p.debug { + cmdCheckBBR.Stdout = nil + } + _ = cmdCheckBBR.Run() + _ = ts.createSysctlOptCmd("net.ipv4.tcp_congestion_control", "bbr", setex, ipv4Settings).Run() + _ = ts.createSysctlOptCmd("net.core.default_qdisc", "fq", setex, ipv4Settings).Run() + _ = ts.createSysctlOptCmd("net.ipv4.tcp_tw_reuse", "1", setex, ipv4Settings).Run() + _ = ts.createSysctlOptCmd("net.ipv4.tcp_fin_timeout", "15", setex, ipv4Settings).Run() + cmdClearForward := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t filter -F GOHPTS 2>/dev/null || true + iptables -t filter -D FORWARD -j GOHPTS 2>/dev/null || true + iptables -t filter -X GOHPTS 2>/dev/null || true + `, setex)) + cmdClearForward.Stdout = os.Stdout + cmdClearForward.Stderr = os.Stderr + if err := cmdClearForward.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + var iface *net.Interface + var err error + if ts.p.iface != nil { + iface = ts.p.iface + } else { + iface, err = network.GetDefaultInterface() + if err != nil { + ts.p.logger.Fatal().Err(err).Msg("failed getting default network interface") + } + } + cmdForwardFilter := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t filter -N GOHPTS 2>/dev/null + iptables -t filter -F GOHPTS + iptables -t filter -A FORWARD -j GOHPTS + iptables -t filter -A GOHPTS -i %s -j ACCEPT + iptables -t filter -A GOHPTS -o %s -j ACCEPT + `, setex, iface.Name, iface.Name)) + cmdForwardFilter.Stdout = os.Stdout + cmdForwardFilter.Stderr = os.Stderr + if err := cmdForwardFilter.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + return ipv4Settings +} + +func (ts *tproxyServer) clearRedirectRules(opts map[string]string) error { + var setex string + if ts.p.debug { + setex = "set -ex" + } + cmdClear := exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t filter -F GOHPTS 2>/dev/null || true + iptables -t filter -D FORWARD -j GOHPTS 2>/dev/null || true + iptables -t filter -X GOHPTS 2>/dev/null || true + `, setex)) + cmdClear.Stdout = os.Stdout + cmdClear.Stderr = os.Stderr + if err := cmdClear.Run(); err != nil { + ts.p.logger.Fatal().Err(err).Msg("Failed while configuring iptables. Are you root?") + } + cmds := make([]string, 0, len(opts)) + for _, cmd := range slices.Sorted(maps.Keys(opts)) { + cmds = append(cmds, fmt.Sprintf("sysctl -w %s=%s", cmd, opts[cmd])) + } + cmdRestoreOpts := exec.Command("bash", "-c", fmt.Sprintf(` + %s + %s + `, setex, strings.Join(cmds, "\n"))) + cmdRestoreOpts.Stdout = os.Stdout + cmdRestoreOpts.Stderr = os.Stderr + if !ts.p.debug { + cmdRestoreOpts.Stdout = nil + } + _ = cmdRestoreOpts.Run() + var cmd *exec.Cmd + switch ts.p.tproxyMode { + case "redirect": + cmd = exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t nat -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true + iptables -t nat -D OUTPUT -p tcp -j GOHPTS 2>/dev/null || true + iptables -t nat -F GOHPTS 2>/dev/null || true + iptables -t nat -X GOHPTS 2>/dev/null || true + `, setex)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + case "tproxy": + cmd = exec.Command("bash", "-c", fmt.Sprintf(` + %s + iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT 2>/dev/null || true + iptables -t mangle -D PREROUTING -p tcp -j GOHPTS 2>/dev/null || true + iptables -t mangle -F DIVERT 2>/dev/null || true + iptables -t mangle -F GOHPTS 2>/dev/null || true + iptables -t mangle -X DIVERT 2>/dev/null || true + iptables -t mangle -X GOHPTS 2>/dev/null || true + + ip rule del fwmark 1 lookup 100 2>/dev/null || true + ip route flush table 100 2>/dev/null || true + `, setex)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if !ts.p.debug { + cmd.Stdout = nil + } + } + return cmd.Run() } diff --git a/tproxy_nonlinux.go b/tproxy_nonlinux.go index 76cbaa2..32e08d3 100644 --- a/tproxy_nonlinux.go +++ b/tproxy_nonlinux.go @@ -4,8 +4,8 @@ package gohpts import ( - "fmt" "net" + "os/exec" "sync" "syscall" "time" @@ -15,11 +15,11 @@ type tproxyServer struct { listener net.Listener quit chan struct{} wg sync.WaitGroup - pa *proxyapp + p *proxyapp } -func newTproxyServer(pa *proxyapp) *tproxyServer { - _ = pa +func newTproxyServer(p *proxyapp) *tproxyServer { + _ = p return nil } @@ -48,6 +48,20 @@ func getBaseDialer(timeout time.Duration, mark uint) *net.Dialer { return &net.Dialer{Timeout: timeout} } -func getDefaultInterface() (*net.Interface, error) { - return nil, fmt.Errorf("not implemented") +func (ts *tproxyServer) createSysctlOptCmd(opt, value, setex string, opts map[string]string) *exec.Cmd { + _ = opt + _ = value + _ = setex + _ = opts + return nil +} + +func (ts *tproxyServer) applyRedirectRules() map[string]string { + _ = ts.createSysctlOptCmd("", "", "", nil) + return nil +} + +func (ts *tproxyServer) clearRedirectRules(opts map[string]string) error { + _ = opts + return nil } diff --git a/version.go b/version.go index ab305fa..a9462f7 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package gohpts -const Version string = "gohpts v1.9.3" +const Version string = "gohpts v1.9.4"