diff --git a/device/device.go b/device/device.go index f2c1797..2184092 100644 --- a/device/device.go +++ b/device/device.go @@ -17,8 +17,8 @@ import ( "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" + "github.com/amnezia-vpn/amneziawg-go/tls/pipe" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/leninalive/udptlspipe/pipe" "github.com/tevino/abool/v2" ) diff --git a/go.mod b/go.mod index 1a409ca..d194233 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.21.6 toolchain go1.21.8 require ( - github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f + github.com/gobwas/ws v1.3.2 + github.com/refraction-networking/utls v1.6.2 github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.19.0 golang.org/x/net v0.21.0 @@ -15,15 +16,12 @@ require ( ) require ( - github.com/AdguardTeam/golibs v0.20.0 // indirect github.com/andybalholm/brotli v1.0.6 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gobwas/ws v1.3.2 // indirect github.com/google/btree v1.0.1 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/quic-go/quic-go v0.40.1 // indirect - github.com/refraction-networking/utls v1.6.2 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index 00f7acc..6e84baf 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -github.com/AdguardTeam/golibs v0.20.0 h1:A9FIdYq7wUKhFYy3z+YZ/Aw5oFUYgW+xgaVAJ0pnnPY= -github.com/AdguardTeam/golibs v0.20.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w= -github.com/ameshkov/udptlspipe v1.3.1 h1:e+eC2Yb+04KPzH9b/Uktwn6W6lw5CgbFdHnGfAaofx8= -github.com/ameshkov/udptlspipe v1.3.1/go.mod h1:UnpDx2J//7WS/RRe5hb2UVZpwJzHga95ArLkPS9aRBk= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= @@ -26,8 +22,6 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= -github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f h1:VR2M22cXDtgp78N1mkCmxiXj1zYIP9ScUXS8gMHi6Vs= -github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f/go.mod h1:U3O6PfEGIxmmxAkOucn8Ty1akGF/1N1lDPeHPLCz3Cg= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= diff --git a/tls/pipe/server.go b/tls/pipe/server.go new file mode 100644 index 0000000..71f4278 --- /dev/null +++ b/tls/pipe/server.go @@ -0,0 +1,710 @@ +// Package pipe implements the pipe logic, i.e. listening for TLS or UDP +// connections and proxying data to the target destination. +package pipe + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/gobwas/ws" + "github.com/amnezia-vpn/amneziawg-go/tls/tunnel" + "github.com/amnezia-vpn/amneziawg-go/tls/udp" + tls "github.com/refraction-networking/utls" + "golang.org/x/net/proxy" +) + +// defaultSNI is the default server name that will be used in both the client +// TLS ClientHello and the server's certificate when no TLS configuration is +// configured. +const defaultSNI = "example.org" + +// upgradeTimeout is the read timeout for the first auth packet. +const upgradeTimeout = time.Second * 60 + +// Server represents an udptlspipe pipe. Depending on whether it is created in +// server- or client- mode, it listens to TLS or UDP connections and pipes the +// data to the destination. +type Server struct { + listenAddr string + destinationAddr string + dialer proxy.Dialer + serverMode bool + + probeReverseProxyURL string + probeReverseProxyListen net.Listener + + // tlsConfig to use for TLS connections. In server mode it also has the + // certificate that will be used. + tlsConfig *tls.Config + + // password is a string that the server will search for in the first bytes. + // If not found, the server will return a stub web page. + password string + + // listen is the TLS listener for incoming connections + listen net.Listener + + // srcConns is a set that is used to track active incoming TCP connections. + srcConns map[net.Conn]struct{} + srcConnsMu *sync.Mutex + + // dstConns is a set that is used to track active connections to the proxy + // destination. + dstConns map[net.Conn]struct{} + dstConnsMu *sync.Mutex + + // Shutdown handling + // -- + + // lock protects started, tcpListener and udpListener. + lock sync.RWMutex + started bool + // wg tracks active workers. Stop won't finish until there is at least + // won't finish until there's at least one active worker. + wg sync.WaitGroup +} + +// Config represents the server configuration. +type Config struct { + // ListenAddr is the address (ip:port) where the server will be listening + // to. Depending on the mode the server uses, it will either listen for TLS + // or UDP connections. + ListenAddr string + + // DestinationAddr is the address (host:port) to where the server will try + // to connect. Depending on the mode the server uses, it will either + // connect to a TLS endpoint (the pipe server) or not. + DestinationAddr string + + // Password enables authentication of the pipe clients. If set, it also + // enables active probing protection. + Password string + + // ServerMode controls the way the pipe operates. When it's true, the pipe + // server operates in server mode, i.e. it accepts incoming TLS connections + // and proxies the data to the destination address over UDP. When it works + // in client mode, it is the other way around: accepts UDP traffic and + // proxies it to the destination pipe server over TLS. + ServerMode bool + + // URL of a proxy server that can be used for proxying traffic to the + // destination. + ProxyURL string + + // VerifyCertificate enables server certificate verification in client mode. + // If enabled, the client will verify the server certificate using the + // system root certs store. + VerifyCertificate bool + + // TLSServerName configures the server name to send in TLS ClientHello when + // operating in client mode and the server name that will be used when + // generating a stub certificate. If not set, the default domain name will + // be used for these purposes. + TLSServerName string + + // TLSCertificate is an optional field that allows to configure the TLS + // certificate to use when running in server mode. This option makes sense + // only for server mode. If not configured, the server will generate a stub + // self-signed certificate automatically. + TLSCertificate *tls.Certificate + + // ProbeReverseProxyURL is the URL that will be used by the reverse HTTP + // proxy to respond to unauthorized or proxy requests. If not specified, + // it will respond with a stub page 403 Forbidden. + ProbeReverseProxyURL string +} + +// createTLSConfig creates a TLS configuration as per the server configuration. +func createTLSConfig(config *Config) (tlsConfig *tls.Config, err error) { + serverName := config.TLSServerName + if serverName == "" { + log.Info("TLS server name is not configured, using %s by default", defaultSNI) + serverName = defaultSNI + } + + if config.ServerMode { + tlsCert := config.TLSCertificate + if tlsCert == nil { + log.Info("Generating a stub certificate for %s", serverName) + tlsCert, err = createStubCertificate(serverName) + if err != nil { + return nil, fmt.Errorf("failed to generate a stub certificate: %w", err) + } + } else { + log.Info("Using the supplied TLS certificate") + } + + tlsConfig = &tls.Config{ + ServerName: serverName, + Certificates: []tls.Certificate{*tlsCert}, + MinVersion: tls.VersionTLS12, + } + } else { + tlsConfig = &tls.Config{ + InsecureSkipVerify: !config.VerifyCertificate, + ServerName: serverName, + } + } + + return tlsConfig, nil +} + +// NewServer creates a new instance of a *Server. +func NewServer(config *Config) (s *Server, err error) { + s = &Server{ + listenAddr: config.ListenAddr, + destinationAddr: config.DestinationAddr, + password: config.Password, + probeReverseProxyURL: config.ProbeReverseProxyURL, + dialer: proxy.Direct, + serverMode: config.ServerMode, + srcConns: map[net.Conn]struct{}{}, + srcConnsMu: &sync.Mutex{}, + dstConns: map[net.Conn]struct{}{}, + dstConnsMu: &sync.Mutex{}, + } + + s.tlsConfig, err = createTLSConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to prepare TLS configuration: %w", err) + } + + if config.ProxyURL != "" { + var u *url.URL + u, err = url.Parse(config.ProxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + + s.dialer, err = proxy.FromURL(u, s.dialer) + if err != nil { + return nil, fmt.Errorf("failed to initialize proxy dialer: %w", err) + } + } + + return s, nil +} + +// Addr returns the address the pipe listens to if it is started or nil. +func (s *Server) Addr() (addr net.Addr) { + if s.listen == nil { + return nil + } + + return s.listen.Addr() +} + +// Start starts the pipe, exits immediately if it failed to start +// listening. Start returns once all servers are considered up. +func (s *Server) Start() (err error) { + log.Info("Starting the server %s", s) + + s.lock.Lock() + defer s.lock.Unlock() + + if s.started { + return errors.New("Server is already started") + } + + s.listen, err = s.createListener() + if err != nil { + return fmt.Errorf("failed to start pipe: %w", err) + } + + if s.probeReverseProxyURL != "" { + err = s.startProbeReverseProxy() + if err != nil { + return fmt.Errorf("failed to start probe reverse proxy: %w", err) + } + } + + s.wg.Add(1) + go s.serve() + + s.started = true + log.Info("Server has been started") + + return nil +} + +// createListener creates a TLS listener in server mode and UDP listener in +// client mode. +func (s *Server) createListener() (l net.Listener, err error) { + if s.serverMode { + l, err = tls.Listen("tcp", s.listenAddr, s.tlsConfig) + if err != nil { + return nil, err + } + } else { + l, err = udp.Listen("udp", s.listenAddr) + if err != nil { + return nil, err + } + } + + return l, nil +} + +// startProbeReverseProxy starts a reverse HTTP proxy that will be used for +// answering unauthorized and probe requests. Returns the listener of that +// proxy. Original request URI will be appended to proxyURL. +func (s *Server) startProbeReverseProxy() (err error) { + proxyURL := s.probeReverseProxyURL + + if _, err = url.Parse(proxyURL); err != nil { + return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err) + } + + targetURL, err := url.Parse(s.probeReverseProxyURL) + if err != nil { + return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err) + } + + handler := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetURL) + r.Out.Host = targetURL.Host + }, + } + + srv := &http.Server{ + ReadHeaderTimeout: upgradeTimeout, + Handler: handler, + } + + s.probeReverseProxyListen, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return fmt.Errorf("failed to start probe reverse proxy: %w", err) + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + + log.Info("Starting probe reverse proxy") + sErr := srv.Serve(s.probeReverseProxyListen) + log.Info("Probe reverse proxy has been stopped due to: %v", sErr) + }() + + return nil +} + +// Shutdown stops the pipe and waits for all active connections to close. +func (s *Server) Shutdown(ctx context.Context) (err error) { + log.Info("Stopping the server %s", s) + + s.stopServeLoop() + + // Closing the udpConn thread. + log.OnCloserError(s.listen, log.DEBUG) + + if s.probeReverseProxyListen != nil { + log.OnCloserError(s.probeReverseProxyListen, log.DEBUG) + } + + // Closing active TCP connections. + s.closeConnections(s.srcConnsMu, s.srcConns) + + // Closing active UDP connections. + s.closeConnections(s.dstConnsMu, s.dstConns) + + // Wait until all worker threads finish working + err = s.waitShutdown(ctx) + + log.Info("Server has been stopped") + + return err +} + +// closeConnections closes all active connections. +func (s *Server) closeConnections(mu *sync.Mutex, conns map[net.Conn]struct{}) { + mu.Lock() + defer mu.Unlock() + + for c := range conns { + _ = c.SetReadDeadline(time.Unix(1, 0)) + + log.OnCloserError(c, log.DEBUG) + } +} + +// stopServeLoop sets the started flag to false thus stopping the serving loop. +func (s *Server) stopServeLoop() { + s.lock.Lock() + defer s.lock.Unlock() + + s.started = false +} + +// type check +var _ fmt.Stringer = (*Server)(nil) + +// String implements the fmt.Stringer interface for *Server. +func (s *Server) String() (str string) { + switch s.serverMode { + case true: + return fmt.Sprintf("tls://%s <-> udp://%s", s.listenAddr, s.destinationAddr) + default: + return fmt.Sprintf("udp://%s <-> tls://%s", s.listenAddr, s.destinationAddr) + } +} + +// serve implements the pipe logic, i.e. accepts new connections and tunnels +// data to the destination. +func (s *Server) serve() { + defer s.wg.Done() + defer log.OnPanicAndExit("serve", 1) + + defer log.OnCloserError(s.listen, log.DEBUG) + + for s.isStarted() { + err := s.acceptConn() + if err != nil { + if !s.isStarted() { + return + } + + log.Error("exit serve loop due to: %v", err) + + return + } + } +} + +// acceptConn accepts new incoming and tracks active connections. +func (s *Server) acceptConn() (err error) { + conn, err := s.listen.Accept() + if err != nil { + // This type of errors should not lead to stopping the server. + if errors.Is(os.ErrDeadlineExceeded, err) { + return nil + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return nil + } + + return err + } + + log.Debug("Accepted new connection from %s", conn.RemoteAddr()) + + s.saveSrcConn(conn) + + s.wg.Add(1) + go s.serveConn(conn) + + return nil +} + +// saveSrcConn tracks the connection to allow unblocking reads on shutdown. +func (s *Server) saveSrcConn(conn net.Conn) { + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.srcConns[conn] = struct{}{} +} + +// closeSrcConn closes the source connection and cleans up after it. +func (s *Server) closeSrcConn(conn net.Conn) { + log.OnCloserError(conn, log.DEBUG) + + s.srcConnsMu.Lock() + defer s.srcConnsMu.Unlock() + + delete(s.srcConns, conn) +} + +// saveDstConn tracks the connection to allow unblocking reads on shutdown. +func (s *Server) saveDstConn(conn net.Conn) { + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + // Track the connection to allow unblocking reads on shutdown. + s.dstConns[conn] = struct{}{} +} + +// closeDstConn closes the destination connection and cleans up after it. +func (s *Server) closeDstConn(conn net.Conn) { + // No destination connection opened yet, do nothing. + if conn != nil { + return + } + + log.OnCloserError(conn, log.DEBUG) + + s.dstConnsMu.Lock() + defer s.dstConnsMu.Unlock() + + delete(s.dstConns, conn) +} + +// readWriteCloser is a helper object that's used for replacing +// io.ReadWriteCloser when the server peeked into the connection. +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} + +// upgradeClientConn +func (s *Server) upgradeClientConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) { + log.Debug("Upgrading connection to %s", conn.RemoteAddr()) + + // Give up to 60 seconds on the upgrade and authentication. + _ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout)) + defer func() { + // Remove the deadline when it's not required any more. + _ = conn.SetReadDeadline(time.Time{}) + }() + + u, err := url.Parse(fmt.Sprintf("wss://%s/?password=%s", s.tlsConfig.ServerName, s.password)) + if err != nil { + return nil, err + } + + var br *bufio.Reader + br, _, err = ws.DefaultDialer.Upgrade(conn, u) + if err != nil { + return nil, fmt.Errorf("failed to upgrade: %w", err) + } + + if br != nil && br.Buffered() > 0 { + // If Upgrade returned a non-empty reader, then probably the server + // immediately sent some data. This is not the expected behavior so + // raise an error here. + return nil, fmt.Errorf("received initial data len=%d from the server", br.Buffered()) + } + + return newWsConn( + &readWriteCloser{ + Reader: conn, + Writer: conn, + Closer: conn, + }, + conn.RemoteAddr(), + ws.StateClientSide, + ), nil +} + +// respondToProbe writes a dummy response to the client if it's not authorized +// or if it's a probe. +func (s *Server) respondToProbe(rwc io.ReadWriteCloser, req *http.Request) { + if s.probeReverseProxyListen == nil { + log.Debug("No probe reverse proxy configured, respond with a dummy 403 page") + + response := fmt.Sprintf("%s 403 Forbidden\r\n", req.Proto) + + "Server: nginx\r\n" + + fmt.Sprintf("Date: %s\r\n", time.Now().Format(http.TimeFormat)) + + "Content-Type: text/html\r\n" + + "Connection: close\r\n" + + "\r\n" + + "\r\n" + + "403 Forbidden\r\n" + + "

403 Forbidden

\r\n" + + "
nginx
\r\n" + + "\r\n" + + "\r\n" + + _, _ = rwc.Write([]byte(response)) + + return + } + + log.Debug("Probe reverse proxy is configured, tunnel data to it") + + proxyConn, err := net.Dial("tcp", s.probeReverseProxyListen.Addr().String()) + if err != nil { + log.Error("Failed to connect to the probe reverse proxy: %v", err) + + return + } + + s.saveDstConn(proxyConn) + + tunnel.Tunnel("probeReverseProxy", rwc, proxyConn) +} + +// upgradeServerConn attempts to upgrade the server connection and returns a +// rwc that wraps the original connection and can be used for tunneling data. +func (s *Server) upgradeServerConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) { + log.Debug("Upgrading connection from %s", conn.RemoteAddr()) + + // Give up to 60 seconds on the upgrade and authentication. + _ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout)) + defer func() { + // Remove the deadline when it's not required any more. + _ = conn.SetReadDeadline(time.Time{}) + }() + + // bufio.Reader may read more than requested, so it's crucial to use + // TeeReader so that we could restore the bytes that has been read. + var buf bytes.Buffer + r := bufio.NewReader(io.TeeReader(conn, &buf)) + + req, err := http.ReadRequest(r) + if err != nil { + return nil, fmt.Errorf("cannot read HTTP request: %w", err) + } + + // Now that authentication check has been done restore the peeked up data + // so that it could be used further. + originalRwc := &readWriteCloser{ + Reader: io.MultiReader(bytes.NewReader(buf.Bytes()), conn), + Writer: conn, + Closer: conn, + } + + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + s.respondToProbe(originalRwc, req) + + return nil, fmt.Errorf("not a websocket") + } + + clientPassword := req.URL.Query().Get("password") + if s.password != "" && clientPassword != s.password { + s.respondToProbe(originalRwc, req) + + return nil, fmt.Errorf("wrong password: %s", clientPassword) + } + + _, err = ws.Upgrade(originalRwc) + if err != nil { + return nil, fmt.Errorf("failed to upgrade WebSocket: %w", err) + } + + return newWsConn(originalRwc, conn.RemoteAddr(), ws.StateServerSide), nil +} + +// serveConn processes incoming connection, authenticates it and proxies the +// data from it to the destination address. +func (s *Server) serveConn(conn net.Conn) { + defer func() { + s.wg.Done() + + s.closeSrcConn(conn) + }() + + var rwc io.ReadWriteCloser = conn + + if s.serverMode { + var err error + rwc, err = s.upgradeServerConn(conn) + if err != nil { + log.Error("failed to accept server conn: %v", err) + + return + } + } + + s.processConn(rwc) +} + +// processConn processes the prepared server connection that is passed as rwc. +func (s *Server) processConn(rwc io.ReadWriteCloser) { + var dstConn net.Conn + + defer s.closeDstConn(dstConn) + + dstConn, err := s.dialDst() + if err != nil { + log.Error("failed to connect to %s: %v", s.destinationAddr, err) + + return + } + + s.saveDstConn(dstConn) + + var dstRwc io.ReadWriteCloser = dstConn + if !s.serverMode { + dstRwc, err = s.upgradeClientConn(dstConn) + if err != nil { + log.Error("failed to upgrade: %v", err) + + return + } + } + + // Prepare ReadWriter objects for tunneling. + var srcRw, dstRw io.ReadWriter + srcRw = rwc + dstRw = dstRwc + + // When the client communicates with the server it uses encoded messages so + // connection between them needs to be wrapped. In server mode it is the + // source connection, in client mode it is the destination connection. + if s.serverMode { + srcRw = tunnel.NewMsgReadWriter(srcRw) + } else { + dstRw = tunnel.NewMsgReadWriter(dstRw) + } + + tunnel.Tunnel(s.String(), srcRw, dstRw) +} + +// dialDst creates a connection to the destination. Depending on the mode the +// server operates in, it is either a TLS connection or a UDP connection. +func (s *Server) dialDst() (conn net.Conn, err error) { + if s.serverMode { + return s.dialer.Dial("udp", s.destinationAddr) + } + + tcpConn, err := s.dialer.Dial("tcp", s.destinationAddr) + if err != nil { + return nil, fmt.Errorf("failed to open connection to %s: %w", s.destinationAddr, err) + } + + tlsConn := tls.UClient(tcpConn, s.tlsConfig, tls.HelloAndroid_11_OkHttp) + + err = tlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("cannot establish connection to %s: %w", s.destinationAddr, err) + } + + return tlsConn, nil +} + +// isStarted safely checks whether the pipe is started or not. +func (s *Server) isStarted() (started bool) { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.started +} + +// waitShutdown waits either until context deadline OR Server.wg. +func (s *Server) waitShutdown(ctx context.Context) (err error) { + // Using this channel to wait until all goroutines finish their work. + closed := make(chan struct{}) + go func() { + defer log.OnPanic("waitShutdown") + + // Wait until all active workers finished its work. + s.wg.Wait() + close(closed) + }() + + var ctxErr error + select { + case <-closed: + // Do nothing here. + case <-ctx.Done(): + ctxErr = ctx.Err() + } + + return ctxErr +} diff --git a/tls/pipe/tlsconfig.go b/tls/pipe/tlsconfig.go new file mode 100644 index 0000000..a321821 --- /dev/null +++ b/tls/pipe/tlsconfig.go @@ -0,0 +1,73 @@ +package pipe + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" + + tls "github.com/refraction-networking/utls" +) + +// createStubCertificate creates a stub TLS certificate for the pipe server. +// This stub cert is generated when the user does not specify any certificate. +func createStubCertificate(tlsServerName string) (cert *tls.Certificate, err error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(5 * 365 * time.Hour * 24) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{tlsServerName}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + template.DNSNames = append(template.DNSNames, tlsServerName) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + if err != nil { + return nil, err + } + + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + tlsCert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + return nil, err + } + + return &tlsCert, nil +} + +func publicKey(priv any) (pub any) { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} diff --git a/tls/pipe/websocket.go b/tls/pipe/websocket.go new file mode 100644 index 0000000..7d7ecd7 --- /dev/null +++ b/tls/pipe/websocket.go @@ -0,0 +1,85 @@ +package pipe + +import ( + "io" + "net" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" +) + +// wsConn represents a WebSocket connection that's been already initialized. +type wsConn struct { + rwc io.ReadWriteCloser + remoteAddr net.Addr + r *wsutil.Reader + w *wsutil.Writer +} + +// newWsConn creates a wrapper over the existing network connection that is +// able to send/read messages using WebSocket protocol. +func newWsConn(rwc io.ReadWriteCloser, remoteAddr net.Addr, state ws.State) (c *wsConn) { + r := wsutil.NewReader(rwc, state) + w := wsutil.NewWriter(rwc, state, ws.OpBinary) + + return &wsConn{ + rwc: rwc, + remoteAddr: remoteAddr, + r: r, + w: w, + } +} + +// type check +var _ io.ReadWriteCloser = (*wsConn)(nil) + +// Read implements the io.ReadWriteCloser interface for *wsConn. +func (w *wsConn) Read(b []byte) (n int, err error) { + n, err = w.r.Read(b) + if err == wsutil.ErrNoFrameAdvance { + log.Debug("Reading the next WebSocket frame from %v", w.remoteAddr) + + hdr, fErr := w.r.NextFrame() + if fErr != nil { + return 0, io.EOF + } + + log.Debug( + "Received WebSocket frame with opcode=%d len=%d fin=%v from %v", + hdr.OpCode, + hdr.Length, + hdr.Fin, + w.remoteAddr, + ) + + // Reading again after the frame has been read. + n, err = w.r.Read(b) + + // EOF in the case of wsutil.Reader does not mean that the connection is + // closed, it only means that the current frame is finished. + if err == io.EOF { + err = nil + } + } + + return n, err +} + +// Write implements the io.ReadWriteCloser interface for *wsConn. +func (w *wsConn) Write(b []byte) (n int, err error) { + log.Debug("Writing data len=%d to the WebSocket %v", len(b), w.remoteAddr) + + n, err = w.w.Write(b) + if err != nil { + return 0, err + } + + err = w.w.Flush() + + return n, err +} + +// Close implements the io.ReadWriteCloser interface for *wsConn. +func (w *wsConn) Close() (err error) { + return w.rwc.Close() +} diff --git a/tls/tunnel/msgreadwriter.go b/tls/tunnel/msgreadwriter.go new file mode 100644 index 0000000..a17f058 --- /dev/null +++ b/tls/tunnel/msgreadwriter.go @@ -0,0 +1,149 @@ +package tunnel + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "io" +) + +// MaxMessageLength is the maximum length that is safe to use. +// TODO(leninalive): Make it configurable. +const MaxMessageLength = 1320 + +// MinMessageLength is the minimum message size. If the message is smaller, it +// will be padded with random bytes. +const MinMessageLength = 100 + +// MaxPaddingLength is the maximum size of a random padding that's added to +// every message. +const MaxPaddingLength = 256 + +// MsgReadWriter is a wrapper over io.ReadWriter that encodes messages written +// to and read from the base writer. +type MsgReadWriter struct { + base io.ReadWriter +} + +// NewMsgReadWriter creates a new instance of *MsgReadWriter. +func NewMsgReadWriter(base io.ReadWriter) (rw *MsgReadWriter) { + return &MsgReadWriter{base: base} +} + +// type check +var _ io.ReadWriter = (*MsgReadWriter)(nil) + +// Read implements the io.ReadWriter interface for *MsgReadWriter. +func (rw *MsgReadWriter) Read(b []byte) (n int, err error) { + // Read the main message (always goes first). + msg, err := readPrefixed(rw.base) + if err != nil { + return 0, err + } + + // Skip padding. + _, err = readPrefixed(rw.base) + if err != nil { + return 0, err + } + + if len(b) < len(msg) { + return 0, fmt.Errorf("message length %d is greater than the buffer size %d", len(msg), len(b)) + } + + copy(b[:len(msg)], msg) + + return len(msg), nil +} + +// Write implements the io.ReadWriter interface for *MsgReadWriter. +func (rw *MsgReadWriter) Write(b []byte) (n int, err error) { + // Create random padding to make it harder to understand what's inside + // the tunnel. + minLength := MinMessageLength - len(b) + if minLength <= 0 { + minLength = 1 + } + maxLength := MaxPaddingLength + if maxLength <= minLength { + maxLength = minLength + 1 + } + padding := createRandomPadding(minLength, maxLength) + + // Pack the message before sending it. + msg := pack(b, padding) + + _, err = rw.base.Write(msg) + + if err != nil { + return 0, err + } + + return len(b), nil +} + +// pack packs the message to be sent over the tunnel. +// Message looks like this: +// +// <2 bytes>: body length +// body +// <2 bytes>: padding length +// padding +func pack(b, padding []byte) (msg []byte) { + msg = make([]byte, len(b)+len(padding)+4) + + binary.BigEndian.PutUint16(msg[:2], uint16(len(b))) + copy(msg[2:], b) + binary.BigEndian.PutUint16(msg[len(b)+2:len(b)+4], uint16(len(padding))) + copy(msg[len(b)+4:], padding) + + return msg +} + +// readPrefixed reads a 2-byte prefixed byte array from the reader. +func readPrefixed(r io.Reader) (b []byte, err error) { + var length uint16 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return nil, err + } + + if length > MaxMessageLength { + // Warn the user that this may not work correctly. + log.Error( + "Warning: received message of length %d larger than %d, considering reducing the MTU", + length, + MaxMessageLength, + ) + } + + b = make([]byte, length) + _, err = io.ReadFull(r, b) + + return b, err +} + +// createRandomPadding creates a random padding array. +func createRandomPadding(minLength int, maxLength int) (b []byte) { + // Generate a random length for the slice between minLength and maxLength. + lengthBuf := make([]byte, 1) + _, err := rand.Read(lengthBuf) + if err != nil { + log.Fatalf("Failed to use crypto/rand: %v", err) + } + length := int(lengthBuf[0]) + + // Ensure the length is within our desired range. + length = (length % (maxLength - minLength)) + minLength + + // Create a slice of the random length. + b = make([]byte, length) + + // Fill the slice with random bytes. + _, err = rand.Read(b) + if err != nil { + log.Fatalf("Failed to use crypto/rand: %v", err) + } + + return b +} diff --git a/tls/tunnel/tunnel.go b/tls/tunnel/tunnel.go new file mode 100644 index 0000000..697ee9f --- /dev/null +++ b/tls/tunnel/tunnel.go @@ -0,0 +1,52 @@ +// Package tunnel implements the tunneling logic for copying data between two +// network connections both sides. +package tunnel + +import ( + "fmt" + "io" + "sync" +) + +// Tunnel passes data between two connections. +func Tunnel(pipeName string, left io.ReadWriter, right io.ReadWriter) { + wg := &sync.WaitGroup{} + wg.Add(2) + + go pipe(fmt.Sprintf("%s left->right", pipeName), left, right, wg) + go pipe(fmt.Sprintf("%s left<-right", pipeName), right, left, wg) + + wg.Wait() +} + +// pipe copies data from reader r to writer w. +func pipe(pipeName string, r io.Reader, w io.Writer, wg *sync.WaitGroup) { + defer wg.Done() + + buf := make([]byte, 65536) + var n int + var err error + + for { + n, err = r.Read(buf) + + if err != nil { + log.Debug("failed to read: %v", err) + + return + } + + if n == 0 { + continue + } + + log.Debug("%s: copying %d bytes", pipeName, n) + + _, err = w.Write(buf[:n]) + if err != nil { + log.Debug("failed to write: %v", err) + + return + } + } +} diff --git a/tls/udp/listener.go b/tls/udp/listener.go new file mode 100644 index 0000000..88d0897 --- /dev/null +++ b/tls/udp/listener.go @@ -0,0 +1,281 @@ +// Package udp implements helper structures for working with UDP. +package udp + +import ( + "errors" + "net" + "sync" + "time" +) + +// Listener is a struct that implements net.Listener interface for working +// with UDP. This is achieved by maintaining an internal "nat-like" table +// with destinations. +type Listener struct { + conn *net.UDPConn + + // natTable is a table which maps peer addresses to udpConn structs. + // Whenever a new packet is received, Listener looks up if there's + // already a udpConn for the peer address and either creates a new one + // or adds the packet to the existing one. + natTable map[string]*udpConn + natTableMu sync.Mutex + + chanAccept chan *udpConn + chanClosed chan struct{} + + closed bool + closedMu sync.Mutex +} + +// Listen creates a new *Listener and is supposed to be a function similar +// to net.Listen, but for UDP only. +func Listen(network, addr string) (l *Listener, err error) { + listenAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + + l = &Listener{ + natTable: map[string]*udpConn{}, + chanAccept: make(chan *udpConn, 256), + chanClosed: make(chan struct{}, 1), + } + + l.conn, err = net.ListenUDP(network, listenAddr) + if err != nil { + return nil, err + } + + go l.readLoop() + + return l, nil +} + +// type check. +var _ net.Listener = (*Listener)(nil) + +// Accept implements the net.Listener interface for *Listener. +func (l *Listener) Accept() (conn net.Conn, err error) { + if l.isClosed() { + return nil, net.ErrClosed + } + + select { + case conn = <-l.chanAccept: + return conn, nil + case <-l.chanClosed: + return nil, net.ErrClosed + } +} + +// Close implements the net.Listener interface for *Listener. +func (l *Listener) Close() (err error) { + if l.isClosed() { + return nil + } + + l.closedMu.Lock() + l.closed = true + l.closedMu.Unlock() + + close(l.chanClosed) + + l.natTableMu.Lock() + for _, c := range l.natTable { + log.OnCloserError(c, log.DEBUG) + } + l.natTableMu.Unlock() + + return l.conn.Close() +} + +// Addr implements the net.Listener interface for *Listener. +func (l *Listener) Addr() (addr net.Addr) { + return l.conn.LocalAddr() +} + +// isClosed returns true if the listener is already closed. +func (l *Listener) isClosed() (ok bool) { + l.closedMu.Lock() + defer l.closedMu.Unlock() + + return l.closed +} + +// readLoop implements the listener logic, it reads incoming data and passes it +// to the corresponding udpConn. When a new udpConn is created, it is written +// to the chanAccept channel. +func (l *Listener) readLoop() { + buf := make([]byte, 65536) + + for !l.isClosed() { + n, addr, err := l.conn.ReadFromUDP(buf) + + if err != nil || n == 0 { + if errors.Is(err, net.ErrClosed) { + return + } + + // TODO(leninalive): Handle errors better here. + + continue + } + + msg := make([]byte, n) + copy(msg, buf[:n]) + l.acceptMsg(addr, msg) + } +} + +// acceptMsg passes the message to the corresponding udpConn. +func (l *Listener) acceptMsg(addr *net.UDPAddr, msg []byte) { + l.natTableMu.Lock() + defer l.natTableMu.Unlock() + + key := addr.String() + conn, _ := l.natTable[key] + if conn == nil || conn.isClosed() { + conn = newUDPConn(addr, l.conn) + l.natTable[key] = conn + + l.chanAccept <- conn + } + + conn.addMsg(msg) +} + +// udpConn represents a connection with a single peer. +type udpConn struct { + peerAddr *net.UDPAddr + conn *net.UDPConn + + remaining []byte + + closed bool + closedMu sync.Mutex + + chanMsg chan []byte + chanClosed chan struct{} +} + +// newUDPConn creates a new *udpConn for the specified peer. +func newUDPConn(peerAddr *net.UDPAddr, baseConn *net.UDPConn) (conn *udpConn) { + return &udpConn{ + peerAddr: peerAddr, + conn: baseConn, + chanMsg: make(chan []byte, 256), + chanClosed: make(chan struct{}, 1), + } +} + +// addMsg adds a new byte array that can be then read from this connection. +func (c *udpConn) addMsg(b []byte) { + c.chanMsg <- b +} + +// isClosed returns true if the connection is closed. +func (c *udpConn) isClosed() (ok bool) { + c.closedMu.Lock() + defer c.closedMu.Unlock() + + return c.closed +} + +// type check +var _ net.Conn = (*udpConn)(nil) + +// Read implements the net.Conn interface for *udpConn. +func (c *udpConn) Read(b []byte) (n int, err error) { + n = c.readRemaining(b) + if n > 0 { + return n, nil + } + + select { + case buf := <-c.chanMsg: + c.remaining = buf + n = c.readRemaining(b) + + return n, nil + case <-c.chanClosed: + return 0, net.ErrClosed + } +} + +// readRemaining reads remaining bytes that were not yet read. +func (c *udpConn) readRemaining(b []byte) (n int) { + if c.remaining == nil || len(c.remaining) == 0 { + return 0 + } + + if len(c.remaining) >= len(b) { + n = len(b) + + copy(b, c.remaining[:n]) + c.remaining = c.remaining[n:] + + return n + } + + n = len(c.remaining) + + copy(b[:n], c.remaining) + c.remaining = nil + + return n +} + +// Write implements the net.Conn interface for *udpConn. +func (c *udpConn) Write(b []byte) (n int, err error) { + return c.conn.WriteToUDP(b, c.peerAddr) +} + +// Close implements the net.Conn interface for *udpConn. +func (c *udpConn) Close() (err error) { + c.closedMu.Lock() + defer c.closedMu.Unlock() + + if c.closed { + return nil + } + + c.closed = true + close(c.chanClosed) + + // Do not close the underlying UDP connection as it's shared with other + // udpConn objects. + + return nil +} + +// LocalAddr implements the net.Conn interface for *udpConn. +func (c *udpConn) LocalAddr() (addr net.Addr) { + return c.conn.LocalAddr() +} + +// RemoteAddr implements the net.Conn interface for *udpConn. +func (c *udpConn) RemoteAddr() (addr net.Addr) { + return c.peerAddr +} + +// SetDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetDeadline(_ time.Time) (err error) { + // TODO(leninalive): Implement it. + + return nil +} + +// SetReadDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetReadDeadline(_ time.Time) (err error) { + // TODO(leninalive): Implement it. + + return nil +} + +// SetWriteDeadline implements the net.Conn interface for *udpConn. +func (c *udpConn) SetWriteDeadline(_ time.Time) (err error) { + // TODO(leninalive): Implement it. + + return nil +}