diff --git a/device/device.go b/device/device.go index f2c1797..2184092 100644 --- a/device/device.go +++ b/device/device.go @@ -17,8 +17,8 @@ import ( "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" + "github.com/amnezia-vpn/amneziawg-go/tls/pipe" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/leninalive/udptlspipe/pipe" "github.com/tevino/abool/v2" ) diff --git a/go.mod b/go.mod index 1a409ca..d194233 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.21.6 toolchain go1.21.8 require ( - github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f + github.com/gobwas/ws v1.3.2 + github.com/refraction-networking/utls v1.6.2 github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.19.0 golang.org/x/net v0.21.0 @@ -15,15 +16,12 @@ require ( ) require ( - github.com/AdguardTeam/golibs v0.20.0 // indirect github.com/andybalholm/brotli v1.0.6 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gobwas/ws v1.3.2 // indirect github.com/google/btree v1.0.1 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/quic-go/quic-go v0.40.1 // indirect - github.com/refraction-networking/utls v1.6.2 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index 00f7acc..6e84baf 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -github.com/AdguardTeam/golibs v0.20.0 h1:A9FIdYq7wUKhFYy3z+YZ/Aw5oFUYgW+xgaVAJ0pnnPY= -github.com/AdguardTeam/golibs v0.20.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w= -github.com/ameshkov/udptlspipe v1.3.1 h1:e+eC2Yb+04KPzH9b/Uktwn6W6lw5CgbFdHnGfAaofx8= -github.com/ameshkov/udptlspipe v1.3.1/go.mod h1:UnpDx2J//7WS/RRe5hb2UVZpwJzHga95ArLkPS9aRBk= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= @@ -26,8 +22,6 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= -github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f h1:VR2M22cXDtgp78N1mkCmxiXj1zYIP9ScUXS8gMHi6Vs= -github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f/go.mod h1:U3O6PfEGIxmmxAkOucn8Ty1akGF/1N1lDPeHPLCz3Cg= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= diff --git a/tls/pipe/server.go b/tls/pipe/server.go new file mode 100644 index 0000000..71f4278 --- /dev/null +++ b/tls/pipe/server.go @@ -0,0 +1,710 @@ +// Package pipe implements the pipe logic, i.e. listening for TLS or UDP +// connections and proxying data to the target destination. +package pipe + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/gobwas/ws" + "github.com/amnezia-vpn/amneziawg-go/tls/tunnel" + "github.com/amnezia-vpn/amneziawg-go/tls/udp" + tls "github.com/refraction-networking/utls" + "golang.org/x/net/proxy" +) + +// defaultSNI is the default server name that will be used in both the client +// TLS ClientHello and the server's certificate when no TLS configuration is +// configured. +const defaultSNI = "example.org" + +// upgradeTimeout is the read timeout for the first auth packet. +const upgradeTimeout = time.Second * 60 + +// Server represents an udptlspipe pipe. Depending on whether it is created in +// server- or client- mode, it listens to TLS or UDP connections and pipes the +// data to the destination. +type Server struct { + listenAddr string + destinationAddr string + dialer proxy.Dialer + serverMode bool + + probeReverseProxyURL string + probeReverseProxyListen net.Listener + + // tlsConfig to use for TLS connections. In server mode it also has the + // certificate that will be used. + tlsConfig *tls.Config + + // password is a string that the server will search for in the first bytes. + // If not found, the server will return a stub web page. + password string + + // listen is the TLS listener for incoming connections + listen net.Listener + + // srcConns is a set that is used to track active incoming TCP connections. + srcConns map[net.Conn]struct{} + srcConnsMu *sync.Mutex + + // dstConns is a set that is used to track active connections to the proxy + // destination. + dstConns map[net.Conn]struct{} + dstConnsMu *sync.Mutex + + // Shutdown handling + // -- + + // lock protects started, tcpListener and udpListener. + lock sync.RWMutex + started bool + // wg tracks active workers. Stop won't finish until there is at least + // won't finish until there's at least one active worker. + wg sync.WaitGroup +} + +// Config represents the server configuration. +type Config struct { + // ListenAddr is the address (ip:port) where the server will be listening + // to. Depending on the mode the server uses, it will either listen for TLS + // or UDP connections. + ListenAddr string + + // DestinationAddr is the address (host:port) to where the server will try + // to connect. Depending on the mode the server uses, it will either + // connect to a TLS endpoint (the pipe server) or not. + DestinationAddr string + + // Password enables authentication of the pipe clients. If set, it also + // enables active probing protection. + Password string + + // ServerMode controls the way the pipe operates. When it's true, the pipe + // server operates in server mode, i.e. it accepts incoming TLS connections + // and proxies the data to the destination address over UDP. When it works + // in client mode, it is the other way around: accepts UDP traffic and + // proxies it to the destination pipe server over TLS. + ServerMode bool + + // URL of a proxy server that can be used for proxying traffic to the + // destination. + ProxyURL string + + // VerifyCertificate enables server certificate verification in client mode. + // If enabled, the client will verify the server certificate using the + // system root certs store. + VerifyCertificate bool + + // TLSServerName configures the server name to send in TLS ClientHello when + // operating in client mode and the server name that will be used when + // generating a stub certificate. If not set, the default domain name will + // be used for these purposes. + TLSServerName string + + // TLSCertificate is an optional field that allows to configure the TLS + // certificate to use when running in server mode. This option makes sense + // only for server mode. If not configured, the server will generate a stub + // self-signed certificate automatically. + TLSCertificate *tls.Certificate + + // ProbeReverseProxyURL is the URL that will be used by the reverse HTTP + // proxy to respond to unauthorized or proxy requests. If not specified, + // it will respond with a stub page 403 Forbidden. + ProbeReverseProxyURL string +} + +// createTLSConfig creates a TLS configuration as per the server configuration. +func createTLSConfig(config *Config) (tlsConfig *tls.Config, err error) { + serverName := config.TLSServerName + if serverName == "" { + log.Info("TLS server name is not configured, using %s by default", defaultSNI) + serverName = defaultSNI + } + + if config.ServerMode { + tlsCert := config.TLSCertificate + if tlsCert == nil { + log.Info("Generating a stub certificate for %s", serverName) + tlsCert, err = createStubCertificate(serverName) + if err != nil { + return nil, fmt.Errorf("failed to generate a stub certificate: %w", err) + } + } else { + log.Info("Using the supplied TLS certificate") + } + + tlsConfig = &tls.Config{ + ServerName: serverName, + Certificates: []tls.Certificate{*tlsCert}, + MinVersion: tls.VersionTLS12, + } + } else { + tlsConfig = &tls.Config{ + InsecureSkipVerify: !config.VerifyCertificate, + ServerName: serverName, + } + } + + return tlsConfig, nil +} + +// NewServer creates a new instance of a *Server. +func NewServer(config *Config) (s *Server, err error) { + s = &Server{ + listenAddr: config.ListenAddr, + destinationAddr: config.DestinationAddr, + password: config.Password, + probeReverseProxyURL: config.ProbeReverseProxyURL, + dialer: proxy.Direct, + serverMode: config.ServerMode, + srcConns: map[net.Conn]struct{}{}, + srcConnsMu: &sync.Mutex{}, + dstConns: map[net.Conn]struct{}{}, + dstConnsMu: &sync.Mutex{}, + } + + s.tlsConfig, err = createTLSConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to prepare TLS configuration: %w", err) + } + + if config.ProxyURL != "" { + var u *url.URL + u, err = url.Parse(config.ProxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + + s.dialer, err = proxy.FromURL(u, s.dialer) + if err != nil { + return nil, fmt.Errorf("failed to initialize proxy dialer: %w", err) + } + } + + return s, nil +} + +// Addr returns the address the pipe listens to if it is started or nil. +func (s *Server) Addr() (addr net.Addr) { + if s.listen == nil { + return nil + } + + return s.listen.Addr() +} + +// Start starts the pipe, exits immediately if it failed to start +// listening. Start returns once all servers are considered up. +func (s *Server) Start() (err error) { + log.Info("Starting the server %s", s) + + s.lock.Lock() + defer s.lock.Unlock() + + if s.started { + return errors.New("Server is already started") + } + + s.listen, err = s.createListener() + if err != nil { + return fmt.Errorf("failed to start pipe: %w", err) + } + + if s.probeReverseProxyURL != "" { + err = s.startProbeReverseProxy() + if err != nil { + return fmt.Errorf("failed to start probe reverse proxy: %w", err) + } + } + + s.wg.Add(1) + go s.serve() + + s.started = true + log.Info("Server has been started") + + return nil +} + +// createListener creates a TLS listener in server mode and UDP listener in +// client mode. +func (s *Server) createListener() (l net.Listener, err error) { + if s.serverMode { + l, err = tls.Listen("tcp", s.listenAddr, s.tlsConfig) + if err != nil { + return nil, err + } + } else { + l, err = udp.Listen("udp", s.listenAddr) + if err != nil { + return nil, err + } + } + + return l, nil +} + +// startProbeReverseProxy starts a reverse HTTP proxy that will be used for +// answering unauthorized and probe requests. Returns the listener of that +// proxy. Original request URI will be appended to proxyURL. +func (s *Server) startProbeReverseProxy() (err error) { + proxyURL := s.probeReverseProxyURL + + if _, err = url.Parse(proxyURL); err != nil { + return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err) + } + + targetURL, err := url.Parse(s.probeReverseProxyURL) + if err != nil { + return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err) + } + + handler := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetURL) + r.Out.Host = targetURL.Host + }, + } + + srv := &http.Server{ + ReadHeaderTimeout: upgradeTimeout, + Handler: handler, + } + + s.probeReverseProxyListen, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return fmt.Errorf("failed to start probe reverse proxy: %w", err) + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + + log.Info("Starting probe reverse proxy") + sErr := srv.Serve(s.probeReverseProxyListen) + log.Info("Probe reverse proxy has been stopped due to: %v", sErr) + }() + + return nil +} + +// Shutdown stops the pipe and waits for all active connections to close. +func (s *Server) Shutdown(ctx context.Context) (err error) { + log.Info("Stopping the server %s", s) + + s.stopServeLoop() + + // Closing the udpConn thread. + log.OnCloserError(s.listen, log.DEBUG) + + if s.probeReverseProxyListen != nil { + log.OnCloserError(s.probeReverseProxyListen, log.DEBUG) + } + + // Closing active TCP connections. + s.closeConnections(s.srcConnsMu, s.srcConns) + + // Closing active UDP connections. + s.closeConnections(s.dstConnsMu, s.dstConns) + + // Wait until all worker threads finish working + err = s.waitShutdown(ctx) + + log.Info("Server has been stopped") + + return err +} + +// closeConnections closes all active connections. +func (s *Server) closeConnections(mu *sync.Mutex, conns map[net.Conn]struct{}) { + mu.Lock() + defer mu.Unlock() + + for c := range conns { + _ = c.SetReadDeadline(time.Unix(1, 0)) + + log.OnCloserError(c, log.DEBUG) + } +} + +// stopServeLoop sets the started flag to false thus stopping the serving loop. +func (s *Server) stopServeLoop() { + s.lock.Lock() + defer s.lock.Unlock() + + s.started = false +} + +// type check +var _ fmt.Stringer = (*Server)(nil) + +// String implements the fmt.Stringer interface for *Server. +func (s *Server) String() (str string) { + switch s.serverMode { + case true: + return fmt.Sprintf("tls://%s <-> udp://%s", s.listenAddr, s.destinationAddr) + default: + return fmt.Sprintf("udp://%s <-> tls://%s", s.listenAddr, s.destinationAddr) + } +} + +// serve implements the pipe logic, i.e. accepts new connections and tunnels +// data to the destination. +func (s *Server) serve() { + defer s.wg.Done() + defer log.OnPanicAndExit("serve", 1) + + defer log.OnCloserError(s.listen, log.DEBUG) + + for s.isStarted() { + err := s.acceptConn() + if err != nil { + if !s.isStarted() { + return + } + + log.Error("exit serve loop due to: %v", err) + + return + } + } +} + +// acceptConn accepts new incoming and tracks active connections. +func (s *Server) acceptConn() (err error) { + conn, err := s.listen.Accept() + if err != nil { + // This type of errors should not lead to stopping the server. + if errors.Is(os.ErrDeadlineExceeded, err) { + return nil + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return nil + } + + return err + } + + log.Debug("Accepted new connection from %s", conn.RemoteAddr()) + + s.saveSrcConn(conn) + + s.wg.Add(1) + go s.serveConn(conn) + + return nil +} + +// saveSrcConn tracks the connection to allow unblocking reads on shutdown. +func (s *Server) saveSrcConn(conn net.Conn) { + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.srcConns[conn] = struct{}{} +} + +// closeSrcConn closes the source connection and cleans up after it. +func (s *Server) closeSrcConn(conn net.Conn) { + log.OnCloserError(conn, log.DEBUG) + + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + delete(s.srcConns, conn) +} + +// saveDstConn tracks the connection to allow unblocking reads on shutdown. +func (s *Server) saveDstConn(conn net.Conn) { + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.dstConns[conn] = struct{}{} +} + +// closeDstConn closes the destination connection and cleans up after it. +func (s *Server) closeDstConn(conn net.Conn) { + // No destination connection opened yet, do nothing. + if conn != nil { + return + } + + log.OnCloserError(conn, log.DEBUG) + + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + delete(s.dstConns, conn) +} + +// readWriteCloser is a helper object that's used for replacing +// io.ReadWriteCloser when the server peeked into the connection. +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} + +// upgradeClientConn +func (s *Server) upgradeClientConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) { + log.Debug("Upgrading connection to %s", conn.RemoteAddr()) + + // Give up to 60 seconds on the upgrade and authentication. + _ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout)) + defer func() { + // Remove the deadline when it's not required any more. + _ = conn.SetReadDeadline(time.Time{}) + }() + + u, err := url.Parse(fmt.Sprintf("wss://%s/?password=%s", s.tlsConfig.ServerName, s.password)) + if err != nil { + return nil, err + } + + var br *bufio.Reader + br, _, err = ws.DefaultDialer.Upgrade(conn, u) + if err != nil { + return nil, fmt.Errorf("failed to upgrade: %w", err) + } + + if br != nil && br.Buffered() > 0 { + // If Upgrade returned a non-empty reader, then probably the server + // immediately sent some data. This is not the expected behavior so + // raise an error here. + return nil, fmt.Errorf("received initial data len=%d from the server", br.Buffered()) + } + + return newWsConn( + &readWriteCloser{ + Reader: conn, + Writer: conn, + Closer: conn, + }, + conn.RemoteAddr(), + ws.StateClientSide, + ), nil +} + +// respondToProbe writes a dummy response to the client if it's not authorized +// or if it's a probe. +func (s *Server) respondToProbe(rwc io.ReadWriteCloser, req *http.Request) { + if s.probeReverseProxyListen == nil { + log.Debug("No probe reverse proxy configured, respond with a dummy 403 page") + + response := fmt.Sprintf("%s 403 Forbidden\r\n", req.Proto) + + "Server: nginx\r\n" + + fmt.Sprintf("Date: %s\r\n", time.Now().Format(http.TimeFormat)) + + "Content-Type: text/html\r\n" + + "Connection: close\r\n" + + "\r\n" + + "\r\n" + + "