diff --git a/gohpts.go b/gohpts.go index 0355991..ab67a37 100644 --- a/gohpts.go +++ b/gohpts.go @@ -14,24 +14,60 @@ import ( "golang.org/x/net/proxy" ) +// Hop-by-hop headers +// https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1 +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "TE", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func delHopHeaders(header http.Header) { + for _, h := range hopHeaders { + header.Del(h) + } +} + +func appendHostToXForwardHeader(header http.Header, host string) { + if prior, ok := header["X-Forwarded-For"]; ok { + host = strings.Join(prior, ", ") + ", " + host + } + header.Set("X-Forwarded-For", host) +} + type app struct { hs *http.Server sc *http.Client + dialer proxy.Dialer logger *zerolog.Logger } -func (app *app) handleSOCKS(w http.ResponseWriter, r *http.Request) { +func (app *app) handleForward(w http.ResponseWriter, r *http.Request) { + req, err := http.NewRequest(r.Method, r.URL.String(), r.Body) if err != nil { app.logger.Error().Err(err).Msgf("Error during NewRequest() %s: %s", r.URL.String(), err) w.WriteHeader(http.StatusInternalServerError) return } - - for key, values := range r.Header { - for _, value := range values { - req.Header.Add(key, value) - } + delHopHeaders(r.Header) + copyHeader(req.Header, r.Header) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + appendHostToXForwardHeader(req.Header, clientIP) } resp, err := app.sc.Do(req) if err != nil { @@ -46,6 +82,9 @@ func (app *app) handleSOCKS(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() + delHopHeaders(resp.Header) + copyHeader(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) written, err := io.Copy(w, resp.Body) if err != nil { app.logger.Error().Err(err).Msgf("Error during Copy() %s: %s", r.URL.String(), err) @@ -56,13 +95,12 @@ func (app *app) handleSOCKS(w http.ResponseWriter, r *http.Request) { } func (app *app) handleTunnel(w http.ResponseWriter, r *http.Request) { - dstConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + dstConn, err := app.dialer.Dial("tcp", r.Host) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } defer dstConn.Close() - w.WriteHeader(http.StatusOK) hj, ok := w.(http.Hijacker) @@ -77,7 +115,7 @@ func (app *app) handleTunnel(w http.ResponseWriter, r *http.Request) { } defer srcConn.Close() - dstConnStr := fmt.Sprintf("%s->%s", dstConn.LocalAddr().String(), dstConn.RemoteAddr().String()) + dstConnStr := fmt.Sprintf("%s->%s->%s", dstConn.LocalAddr().String(), dstConn.RemoteAddr().String(), r.Host) srcConnStr := fmt.Sprintf("%s->%s", srcConn.LocalAddr().String(), srcConn.RemoteAddr().String()) app.logger.Debug().Msgf("%s - %s - %s", r.Proto, r.Method, r.Host) @@ -104,7 +142,7 @@ func (app *app) handler() http.HandlerFunc { if r.Method == http.MethodConnect { app.handleTunnel(w, r) } else { - app.handleSOCKS(w, r) + app.handleForward(w, r) } } } @@ -147,6 +185,9 @@ func New(conf *Config) *app { Transport: &http.Transport{ Dial: dialer.Dial, }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } hs := &http.Server{ Addr: conf.AddrHTTP, @@ -156,5 +197,5 @@ func New(conf *Config) *app { } logger.Info().Msgf("SOCKS5 Proxy: %s", conf.AddrSOCKS) logger.Info().Msgf("HTTP Proxy: %s", conf.AddrHTTP) - return &app{hs: hs, sc: socks, logger: &logger} + return &app{hs: hs, sc: socks, dialer: dialer, logger: &logger} }