mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-08 01:56:56 +02:00
tls refactoring
Signed-off-by: Iurii Egorov <ye@amnezia.org>
This commit is contained in:
parent
2d725d0ca7
commit
93e92cfadc
9 changed files with 1353 additions and 11 deletions
|
@ -17,8 +17,8 @@ import (
|
|||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tls/pipe"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"github.com/leninalive/udptlspipe/pipe"
|
||||
"github.com/tevino/abool/v2"
|
||||
)
|
||||
|
||||
|
|
6
go.mod
6
go.mod
|
@ -5,7 +5,8 @@ go 1.21.6
|
|||
toolchain go1.21.8
|
||||
|
||||
require (
|
||||
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f
|
||||
github.com/gobwas/ws v1.3.2
|
||||
github.com/refraction-networking/utls v1.6.2
|
||||
github.com/tevino/abool/v2 v2.1.0
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/net v0.21.0
|
||||
|
@ -15,15 +16,12 @@ require (
|
|||
)
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/golibs v0.20.0 // indirect
|
||||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/cloudflare/circl v1.3.7 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/gobwas/ws v1.3.2 // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
github.com/quic-go/quic-go v0.40.1 // indirect
|
||||
github.com/refraction-networking/utls v1.6.2 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
)
|
||||
|
|
6
go.sum
6
go.sum
|
@ -1,7 +1,3 @@
|
|||
github.com/AdguardTeam/golibs v0.20.0 h1:A9FIdYq7wUKhFYy3z+YZ/Aw5oFUYgW+xgaVAJ0pnnPY=
|
||||
github.com/AdguardTeam/golibs v0.20.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w=
|
||||
github.com/ameshkov/udptlspipe v1.3.1 h1:e+eC2Yb+04KPzH9b/Uktwn6W6lw5CgbFdHnGfAaofx8=
|
||||
github.com/ameshkov/udptlspipe v1.3.1/go.mod h1:UnpDx2J//7WS/RRe5hb2UVZpwJzHga95ArLkPS9aRBk=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU=
|
||||
|
@ -26,8 +22,6 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE
|
|||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f h1:VR2M22cXDtgp78N1mkCmxiXj1zYIP9ScUXS8gMHi6Vs=
|
||||
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f/go.mod h1:U3O6PfEGIxmmxAkOucn8Ty1akGF/1N1lDPeHPLCz3Cg=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
|
|
710
tls/pipe/server.go
Normal file
710
tls/pipe/server.go
Normal file
|
@ -0,0 +1,710 @@
|
|||
// Package pipe implements the pipe logic, i.e. listening for TLS or UDP
|
||||
// connections and proxying data to the target destination.
|
||||
package pipe
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tls/tunnel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tls/udp"
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// defaultSNI is the default server name that will be used in both the client
|
||||
// TLS ClientHello and the server's certificate when no TLS configuration is
|
||||
// configured.
|
||||
const defaultSNI = "example.org"
|
||||
|
||||
// upgradeTimeout is the read timeout for the first auth packet.
|
||||
const upgradeTimeout = time.Second * 60
|
||||
|
||||
// Server represents an udptlspipe pipe. Depending on whether it is created in
|
||||
// server- or client- mode, it listens to TLS or UDP connections and pipes the
|
||||
// data to the destination.
|
||||
type Server struct {
|
||||
listenAddr string
|
||||
destinationAddr string
|
||||
dialer proxy.Dialer
|
||||
serverMode bool
|
||||
|
||||
probeReverseProxyURL string
|
||||
probeReverseProxyListen net.Listener
|
||||
|
||||
// tlsConfig to use for TLS connections. In server mode it also has the
|
||||
// certificate that will be used.
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// password is a string that the server will search for in the first bytes.
|
||||
// If not found, the server will return a stub web page.
|
||||
password string
|
||||
|
||||
// listen is the TLS listener for incoming connections
|
||||
listen net.Listener
|
||||
|
||||
// srcConns is a set that is used to track active incoming TCP connections.
|
||||
srcConns map[net.Conn]struct{}
|
||||
srcConnsMu *sync.Mutex
|
||||
|
||||
// dstConns is a set that is used to track active connections to the proxy
|
||||
// destination.
|
||||
dstConns map[net.Conn]struct{}
|
||||
dstConnsMu *sync.Mutex
|
||||
|
||||
// Shutdown handling
|
||||
// --
|
||||
|
||||
// lock protects started, tcpListener and udpListener.
|
||||
lock sync.RWMutex
|
||||
started bool
|
||||
// wg tracks active workers. Stop won't finish until there is at least
|
||||
// won't finish until there's at least one active worker.
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config represents the server configuration.
|
||||
type Config struct {
|
||||
// ListenAddr is the address (ip:port) where the server will be listening
|
||||
// to. Depending on the mode the server uses, it will either listen for TLS
|
||||
// or UDP connections.
|
||||
ListenAddr string
|
||||
|
||||
// DestinationAddr is the address (host:port) to where the server will try
|
||||
// to connect. Depending on the mode the server uses, it will either
|
||||
// connect to a TLS endpoint (the pipe server) or not.
|
||||
DestinationAddr string
|
||||
|
||||
// Password enables authentication of the pipe clients. If set, it also
|
||||
// enables active probing protection.
|
||||
Password string
|
||||
|
||||
// ServerMode controls the way the pipe operates. When it's true, the pipe
|
||||
// server operates in server mode, i.e. it accepts incoming TLS connections
|
||||
// and proxies the data to the destination address over UDP. When it works
|
||||
// in client mode, it is the other way around: accepts UDP traffic and
|
||||
// proxies it to the destination pipe server over TLS.
|
||||
ServerMode bool
|
||||
|
||||
// URL of a proxy server that can be used for proxying traffic to the
|
||||
// destination.
|
||||
ProxyURL string
|
||||
|
||||
// VerifyCertificate enables server certificate verification in client mode.
|
||||
// If enabled, the client will verify the server certificate using the
|
||||
// system root certs store.
|
||||
VerifyCertificate bool
|
||||
|
||||
// TLSServerName configures the server name to send in TLS ClientHello when
|
||||
// operating in client mode and the server name that will be used when
|
||||
// generating a stub certificate. If not set, the default domain name will
|
||||
// be used for these purposes.
|
||||
TLSServerName string
|
||||
|
||||
// TLSCertificate is an optional field that allows to configure the TLS
|
||||
// certificate to use when running in server mode. This option makes sense
|
||||
// only for server mode. If not configured, the server will generate a stub
|
||||
// self-signed certificate automatically.
|
||||
TLSCertificate *tls.Certificate
|
||||
|
||||
// ProbeReverseProxyURL is the URL that will be used by the reverse HTTP
|
||||
// proxy to respond to unauthorized or proxy requests. If not specified,
|
||||
// it will respond with a stub page 403 Forbidden.
|
||||
ProbeReverseProxyURL string
|
||||
}
|
||||
|
||||
// createTLSConfig creates a TLS configuration as per the server configuration.
|
||||
func createTLSConfig(config *Config) (tlsConfig *tls.Config, err error) {
|
||||
serverName := config.TLSServerName
|
||||
if serverName == "" {
|
||||
log.Info("TLS server name is not configured, using %s by default", defaultSNI)
|
||||
serverName = defaultSNI
|
||||
}
|
||||
|
||||
if config.ServerMode {
|
||||
tlsCert := config.TLSCertificate
|
||||
if tlsCert == nil {
|
||||
log.Info("Generating a stub certificate for %s", serverName)
|
||||
tlsCert, err = createStubCertificate(serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate a stub certificate: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Info("Using the supplied TLS certificate")
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
ServerName: serverName,
|
||||
Certificates: []tls.Certificate{*tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
} else {
|
||||
tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: !config.VerifyCertificate,
|
||||
ServerName: serverName,
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// NewServer creates a new instance of a *Server.
|
||||
func NewServer(config *Config) (s *Server, err error) {
|
||||
s = &Server{
|
||||
listenAddr: config.ListenAddr,
|
||||
destinationAddr: config.DestinationAddr,
|
||||
password: config.Password,
|
||||
probeReverseProxyURL: config.ProbeReverseProxyURL,
|
||||
dialer: proxy.Direct,
|
||||
serverMode: config.ServerMode,
|
||||
srcConns: map[net.Conn]struct{}{},
|
||||
srcConnsMu: &sync.Mutex{},
|
||||
dstConns: map[net.Conn]struct{}{},
|
||||
dstConnsMu: &sync.Mutex{},
|
||||
}
|
||||
|
||||
s.tlsConfig, err = createTLSConfig(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare TLS configuration: %w", err)
|
||||
}
|
||||
|
||||
if config.ProxyURL != "" {
|
||||
var u *url.URL
|
||||
u, err = url.Parse(config.ProxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
s.dialer, err = proxy.FromURL(u, s.dialer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize proxy dialer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Addr returns the address the pipe listens to if it is started or nil.
|
||||
func (s *Server) Addr() (addr net.Addr) {
|
||||
if s.listen == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.listen.Addr()
|
||||
}
|
||||
|
||||
// Start starts the pipe, exits immediately if it failed to start
|
||||
// listening. Start returns once all servers are considered up.
|
||||
func (s *Server) Start() (err error) {
|
||||
log.Info("Starting the server %s", s)
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.started {
|
||||
return errors.New("Server is already started")
|
||||
}
|
||||
|
||||
s.listen, err = s.createListener()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start pipe: %w", err)
|
||||
}
|
||||
|
||||
if s.probeReverseProxyURL != "" {
|
||||
err = s.startProbeReverseProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start probe reverse proxy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.serve()
|
||||
|
||||
s.started = true
|
||||
log.Info("Server has been started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createListener creates a TLS listener in server mode and UDP listener in
|
||||
// client mode.
|
||||
func (s *Server) createListener() (l net.Listener, err error) {
|
||||
if s.serverMode {
|
||||
l, err = tls.Listen("tcp", s.listenAddr, s.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
l, err = udp.Listen("udp", s.listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// startProbeReverseProxy starts a reverse HTTP proxy that will be used for
|
||||
// answering unauthorized and probe requests. Returns the listener of that
|
||||
// proxy. Original request URI will be appended to proxyURL.
|
||||
func (s *Server) startProbeReverseProxy() (err error) {
|
||||
proxyURL := s.probeReverseProxyURL
|
||||
|
||||
if _, err = url.Parse(proxyURL); err != nil {
|
||||
return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err)
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(s.probeReverseProxyURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err)
|
||||
}
|
||||
|
||||
handler := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(targetURL)
|
||||
r.Out.Host = targetURL.Host
|
||||
},
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
ReadHeaderTimeout: upgradeTimeout,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
s.probeReverseProxyListen, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start probe reverse proxy: %w", err)
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info("Starting probe reverse proxy")
|
||||
sErr := srv.Serve(s.probeReverseProxyListen)
|
||||
log.Info("Probe reverse proxy has been stopped due to: %v", sErr)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown stops the pipe and waits for all active connections to close.
|
||||
func (s *Server) Shutdown(ctx context.Context) (err error) {
|
||||
log.Info("Stopping the server %s", s)
|
||||
|
||||
s.stopServeLoop()
|
||||
|
||||
// Closing the udpConn thread.
|
||||
log.OnCloserError(s.listen, log.DEBUG)
|
||||
|
||||
if s.probeReverseProxyListen != nil {
|
||||
log.OnCloserError(s.probeReverseProxyListen, log.DEBUG)
|
||||
}
|
||||
|
||||
// Closing active TCP connections.
|
||||
s.closeConnections(s.srcConnsMu, s.srcConns)
|
||||
|
||||
// Closing active UDP connections.
|
||||
s.closeConnections(s.dstConnsMu, s.dstConns)
|
||||
|
||||
// Wait until all worker threads finish working
|
||||
err = s.waitShutdown(ctx)
|
||||
|
||||
log.Info("Server has been stopped")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// closeConnections closes all active connections.
|
||||
func (s *Server) closeConnections(mu *sync.Mutex, conns map[net.Conn]struct{}) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for c := range conns {
|
||||
_ = c.SetReadDeadline(time.Unix(1, 0))
|
||||
|
||||
log.OnCloserError(c, log.DEBUG)
|
||||
}
|
||||
}
|
||||
|
||||
// stopServeLoop sets the started flag to false thus stopping the serving loop.
|
||||
func (s *Server) stopServeLoop() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.started = false
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fmt.Stringer = (*Server)(nil)
|
||||
|
||||
// String implements the fmt.Stringer interface for *Server.
|
||||
func (s *Server) String() (str string) {
|
||||
switch s.serverMode {
|
||||
case true:
|
||||
return fmt.Sprintf("tls://%s <-> udp://%s", s.listenAddr, s.destinationAddr)
|
||||
default:
|
||||
return fmt.Sprintf("udp://%s <-> tls://%s", s.listenAddr, s.destinationAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// serve implements the pipe logic, i.e. accepts new connections and tunnels
|
||||
// data to the destination.
|
||||
func (s *Server) serve() {
|
||||
defer s.wg.Done()
|
||||
defer log.OnPanicAndExit("serve", 1)
|
||||
|
||||
defer log.OnCloserError(s.listen, log.DEBUG)
|
||||
|
||||
for s.isStarted() {
|
||||
err := s.acceptConn()
|
||||
if err != nil {
|
||||
if !s.isStarted() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("exit serve loop due to: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// acceptConn accepts new incoming and tracks active connections.
|
||||
func (s *Server) acceptConn() (err error) {
|
||||
conn, err := s.listen.Accept()
|
||||
if err != nil {
|
||||
// This type of errors should not lead to stopping the server.
|
||||
if errors.Is(os.ErrDeadlineExceeded, err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Accepted new connection from %s", conn.RemoteAddr())
|
||||
|
||||
s.saveSrcConn(conn)
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.serveConn(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveSrcConn tracks the connection to allow unblocking reads on shutdown.
|
||||
func (s *Server) saveSrcConn(conn net.Conn) {
|
||||
s.srcConnsMu.Lock()
|
||||
defer s.srcConnsMu.Unlock()
|
||||
|
||||
// Track the connection to allow unblocking reads on shutdown.
|
||||
s.srcConns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
// closeSrcConn closes the source connection and cleans up after it.
|
||||
func (s *Server) closeSrcConn(conn net.Conn) {
|
||||
log.OnCloserError(conn, log.DEBUG)
|
||||
|
||||
s.srcConnsMu.Lock()
|
||||
defer s.srcConnsMu.Unlock()
|
||||
|
||||
delete(s.srcConns, conn)
|
||||
}
|
||||
|
||||
// saveDstConn tracks the connection to allow unblocking reads on shutdown.
|
||||
func (s *Server) saveDstConn(conn net.Conn) {
|
||||
s.dstConnsMu.Lock()
|
||||
defer s.dstConnsMu.Unlock()
|
||||
|
||||
// Track the connection to allow unblocking reads on shutdown.
|
||||
s.dstConns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
// closeDstConn closes the destination connection and cleans up after it.
|
||||
func (s *Server) closeDstConn(conn net.Conn) {
|
||||
// No destination connection opened yet, do nothing.
|
||||
if conn != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.OnCloserError(conn, log.DEBUG)
|
||||
|
||||
s.dstConnsMu.Lock()
|
||||
defer s.dstConnsMu.Unlock()
|
||||
|
||||
delete(s.dstConns, conn)
|
||||
}
|
||||
|
||||
// readWriteCloser is a helper object that's used for replacing
|
||||
// io.ReadWriteCloser when the server peeked into the connection.
|
||||
type readWriteCloser struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// upgradeClientConn
|
||||
func (s *Server) upgradeClientConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) {
|
||||
log.Debug("Upgrading connection to %s", conn.RemoteAddr())
|
||||
|
||||
// Give up to 60 seconds on the upgrade and authentication.
|
||||
_ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout))
|
||||
defer func() {
|
||||
// Remove the deadline when it's not required any more.
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("wss://%s/?password=%s", s.tlsConfig.ServerName, s.password))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var br *bufio.Reader
|
||||
br, _, err = ws.DefaultDialer.Upgrade(conn, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upgrade: %w", err)
|
||||
}
|
||||
|
||||
if br != nil && br.Buffered() > 0 {
|
||||
// If Upgrade returned a non-empty reader, then probably the server
|
||||
// immediately sent some data. This is not the expected behavior so
|
||||
// raise an error here.
|
||||
return nil, fmt.Errorf("received initial data len=%d from the server", br.Buffered())
|
||||
}
|
||||
|
||||
return newWsConn(
|
||||
&readWriteCloser{
|
||||
Reader: conn,
|
||||
Writer: conn,
|
||||
Closer: conn,
|
||||
},
|
||||
conn.RemoteAddr(),
|
||||
ws.StateClientSide,
|
||||
), nil
|
||||
}
|
||||
|
||||
// respondToProbe writes a dummy response to the client if it's not authorized
|
||||
// or if it's a probe.
|
||||
func (s *Server) respondToProbe(rwc io.ReadWriteCloser, req *http.Request) {
|
||||
if s.probeReverseProxyListen == nil {
|
||||
log.Debug("No probe reverse proxy configured, respond with a dummy 403 page")
|
||||
|
||||
response := fmt.Sprintf("%s 403 Forbidden\r\n", req.Proto) +
|
||||
"Server: nginx\r\n" +
|
||||
fmt.Sprintf("Date: %s\r\n", time.Now().Format(http.TimeFormat)) +
|
||||
"Content-Type: text/html\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n" +
|
||||
"<html>\r\n" +
|
||||
"<head><title>403 Forbidden</title></head>\r\n" +
|
||||
"<center><h1>403 Forbidden</h1></center>\r\n" +
|
||||
"<hr><center>nginx</center>\r\n" +
|
||||
"</body>\r\n" +
|
||||
"</html>\r\n"
|
||||
|
||||
_, _ = rwc.Write([]byte(response))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Probe reverse proxy is configured, tunnel data to it")
|
||||
|
||||
proxyConn, err := net.Dial("tcp", s.probeReverseProxyListen.Addr().String())
|
||||
if err != nil {
|
||||
log.Error("Failed to connect to the probe reverse proxy: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.saveDstConn(proxyConn)
|
||||
|
||||
tunnel.Tunnel("probeReverseProxy", rwc, proxyConn)
|
||||
}
|
||||
|
||||
// upgradeServerConn attempts to upgrade the server connection and returns a
|
||||
// rwc that wraps the original connection and can be used for tunneling data.
|
||||
func (s *Server) upgradeServerConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) {
|
||||
log.Debug("Upgrading connection from %s", conn.RemoteAddr())
|
||||
|
||||
// Give up to 60 seconds on the upgrade and authentication.
|
||||
_ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout))
|
||||
defer func() {
|
||||
// Remove the deadline when it's not required any more.
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
// bufio.Reader may read more than requested, so it's crucial to use
|
||||
// TeeReader so that we could restore the bytes that has been read.
|
||||
var buf bytes.Buffer
|
||||
r := bufio.NewReader(io.TeeReader(conn, &buf))
|
||||
|
||||
req, err := http.ReadRequest(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Now that authentication check has been done restore the peeked up data
|
||||
// so that it could be used further.
|
||||
originalRwc := &readWriteCloser{
|
||||
Reader: io.MultiReader(bytes.NewReader(buf.Bytes()), conn),
|
||||
Writer: conn,
|
||||
Closer: conn,
|
||||
}
|
||||
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
s.respondToProbe(originalRwc, req)
|
||||
|
||||
return nil, fmt.Errorf("not a websocket")
|
||||
}
|
||||
|
||||
clientPassword := req.URL.Query().Get("password")
|
||||
if s.password != "" && clientPassword != s.password {
|
||||
s.respondToProbe(originalRwc, req)
|
||||
|
||||
return nil, fmt.Errorf("wrong password: %s", clientPassword)
|
||||
}
|
||||
|
||||
_, err = ws.Upgrade(originalRwc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upgrade WebSocket: %w", err)
|
||||
}
|
||||
|
||||
return newWsConn(originalRwc, conn.RemoteAddr(), ws.StateServerSide), nil
|
||||
}
|
||||
|
||||
// serveConn processes incoming connection, authenticates it and proxies the
|
||||
// data from it to the destination address.
|
||||
func (s *Server) serveConn(conn net.Conn) {
|
||||
defer func() {
|
||||
s.wg.Done()
|
||||
|
||||
s.closeSrcConn(conn)
|
||||
}()
|
||||
|
||||
var rwc io.ReadWriteCloser = conn
|
||||
|
||||
if s.serverMode {
|
||||
var err error
|
||||
rwc, err = s.upgradeServerConn(conn)
|
||||
if err != nil {
|
||||
log.Error("failed to accept server conn: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.processConn(rwc)
|
||||
}
|
||||
|
||||
// processConn processes the prepared server connection that is passed as rwc.
|
||||
func (s *Server) processConn(rwc io.ReadWriteCloser) {
|
||||
var dstConn net.Conn
|
||||
|
||||
defer s.closeDstConn(dstConn)
|
||||
|
||||
dstConn, err := s.dialDst()
|
||||
if err != nil {
|
||||
log.Error("failed to connect to %s: %v", s.destinationAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.saveDstConn(dstConn)
|
||||
|
||||
var dstRwc io.ReadWriteCloser = dstConn
|
||||
if !s.serverMode {
|
||||
dstRwc, err = s.upgradeClientConn(dstConn)
|
||||
if err != nil {
|
||||
log.Error("failed to upgrade: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare ReadWriter objects for tunneling.
|
||||
var srcRw, dstRw io.ReadWriter
|
||||
srcRw = rwc
|
||||
dstRw = dstRwc
|
||||
|
||||
// When the client communicates with the server it uses encoded messages so
|
||||
// connection between them needs to be wrapped. In server mode it is the
|
||||
// source connection, in client mode it is the destination connection.
|
||||
if s.serverMode {
|
||||
srcRw = tunnel.NewMsgReadWriter(srcRw)
|
||||
} else {
|
||||
dstRw = tunnel.NewMsgReadWriter(dstRw)
|
||||
}
|
||||
|
||||
tunnel.Tunnel(s.String(), srcRw, dstRw)
|
||||
}
|
||||
|
||||
// dialDst creates a connection to the destination. Depending on the mode the
|
||||
// server operates in, it is either a TLS connection or a UDP connection.
|
||||
func (s *Server) dialDst() (conn net.Conn, err error) {
|
||||
if s.serverMode {
|
||||
return s.dialer.Dial("udp", s.destinationAddr)
|
||||
}
|
||||
|
||||
tcpConn, err := s.dialer.Dial("tcp", s.destinationAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open connection to %s: %w", s.destinationAddr, err)
|
||||
}
|
||||
|
||||
tlsConn := tls.UClient(tcpConn, s.tlsConfig, tls.HelloAndroid_11_OkHttp)
|
||||
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot establish connection to %s: %w", s.destinationAddr, err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// isStarted safely checks whether the pipe is started or not.
|
||||
func (s *Server) isStarted() (started bool) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.started
|
||||
}
|
||||
|
||||
// waitShutdown waits either until context deadline OR Server.wg.
|
||||
func (s *Server) waitShutdown(ctx context.Context) (err error) {
|
||||
// Using this channel to wait until all goroutines finish their work.
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
defer log.OnPanic("waitShutdown")
|
||||
|
||||
// Wait until all active workers finished its work.
|
||||
s.wg.Wait()
|
||||
close(closed)
|
||||
}()
|
||||
|
||||
var ctxErr error
|
||||
select {
|
||||
case <-closed:
|
||||
// Do nothing here.
|
||||
case <-ctx.Done():
|
||||
ctxErr = ctx.Err()
|
||||
}
|
||||
|
||||
return ctxErr
|
||||
}
|
73
tls/pipe/tlsconfig.go
Normal file
73
tls/pipe/tlsconfig.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
// createStubCertificate creates a stub TLS certificate for the pipe server.
|
||||
// This stub cert is generated when the user does not specify any certificate.
|
||||
func createStubCertificate(tlsServerName string) (cert *tls.Certificate, err error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{tlsServerName},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tlsCert, nil
|
||||
}
|
||||
|
||||
func publicKey(priv any) (pub any) {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
85
tls/pipe/websocket.go
Normal file
85
tls/pipe/websocket.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
// wsConn represents a WebSocket connection that's been already initialized.
|
||||
type wsConn struct {
|
||||
rwc io.ReadWriteCloser
|
||||
remoteAddr net.Addr
|
||||
r *wsutil.Reader
|
||||
w *wsutil.Writer
|
||||
}
|
||||
|
||||
// newWsConn creates a wrapper over the existing network connection that is
|
||||
// able to send/read messages using WebSocket protocol.
|
||||
func newWsConn(rwc io.ReadWriteCloser, remoteAddr net.Addr, state ws.State) (c *wsConn) {
|
||||
r := wsutil.NewReader(rwc, state)
|
||||
w := wsutil.NewWriter(rwc, state, ws.OpBinary)
|
||||
|
||||
return &wsConn{
|
||||
rwc: rwc,
|
||||
remoteAddr: remoteAddr,
|
||||
r: r,
|
||||
w: w,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ io.ReadWriteCloser = (*wsConn)(nil)
|
||||
|
||||
// Read implements the io.ReadWriteCloser interface for *wsConn.
|
||||
func (w *wsConn) Read(b []byte) (n int, err error) {
|
||||
n, err = w.r.Read(b)
|
||||
if err == wsutil.ErrNoFrameAdvance {
|
||||
log.Debug("Reading the next WebSocket frame from %v", w.remoteAddr)
|
||||
|
||||
hdr, fErr := w.r.NextFrame()
|
||||
if fErr != nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
"Received WebSocket frame with opcode=%d len=%d fin=%v from %v",
|
||||
hdr.OpCode,
|
||||
hdr.Length,
|
||||
hdr.Fin,
|
||||
w.remoteAddr,
|
||||
)
|
||||
|
||||
// Reading again after the frame has been read.
|
||||
n, err = w.r.Read(b)
|
||||
|
||||
// EOF in the case of wsutil.Reader does not mean that the connection is
|
||||
// closed, it only means that the current frame is finished.
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements the io.ReadWriteCloser interface for *wsConn.
|
||||
func (w *wsConn) Write(b []byte) (n int, err error) {
|
||||
log.Debug("Writing data len=%d to the WebSocket %v", len(b), w.remoteAddr)
|
||||
|
||||
n, err = w.w.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = w.w.Flush()
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close implements the io.ReadWriteCloser interface for *wsConn.
|
||||
func (w *wsConn) Close() (err error) {
|
||||
return w.rwc.Close()
|
||||
}
|
149
tls/tunnel/msgreadwriter.go
Normal file
149
tls/tunnel/msgreadwriter.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxMessageLength is the maximum length that is safe to use.
|
||||
// TODO(leninalive): Make it configurable.
|
||||
const MaxMessageLength = 1320
|
||||
|
||||
// MinMessageLength is the minimum message size. If the message is smaller, it
|
||||
// will be padded with random bytes.
|
||||
const MinMessageLength = 100
|
||||
|
||||
// MaxPaddingLength is the maximum size of a random padding that's added to
|
||||
// every message.
|
||||
const MaxPaddingLength = 256
|
||||
|
||||
// MsgReadWriter is a wrapper over io.ReadWriter that encodes messages written
|
||||
// to and read from the base writer.
|
||||
type MsgReadWriter struct {
|
||||
base io.ReadWriter
|
||||
}
|
||||
|
||||
// NewMsgReadWriter creates a new instance of *MsgReadWriter.
|
||||
func NewMsgReadWriter(base io.ReadWriter) (rw *MsgReadWriter) {
|
||||
return &MsgReadWriter{base: base}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ io.ReadWriter = (*MsgReadWriter)(nil)
|
||||
|
||||
// Read implements the io.ReadWriter interface for *MsgReadWriter.
|
||||
func (rw *MsgReadWriter) Read(b []byte) (n int, err error) {
|
||||
// Read the main message (always goes first).
|
||||
msg, err := readPrefixed(rw.base)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Skip padding.
|
||||
_, err = readPrefixed(rw.base)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(b) < len(msg) {
|
||||
return 0, fmt.Errorf("message length %d is greater than the buffer size %d", len(msg), len(b))
|
||||
}
|
||||
|
||||
copy(b[:len(msg)], msg)
|
||||
|
||||
return len(msg), nil
|
||||
}
|
||||
|
||||
// Write implements the io.ReadWriter interface for *MsgReadWriter.
|
||||
func (rw *MsgReadWriter) Write(b []byte) (n int, err error) {
|
||||
// Create random padding to make it harder to understand what's inside
|
||||
// the tunnel.
|
||||
minLength := MinMessageLength - len(b)
|
||||
if minLength <= 0 {
|
||||
minLength = 1
|
||||
}
|
||||
maxLength := MaxPaddingLength
|
||||
if maxLength <= minLength {
|
||||
maxLength = minLength + 1
|
||||
}
|
||||
padding := createRandomPadding(minLength, maxLength)
|
||||
|
||||
// Pack the message before sending it.
|
||||
msg := pack(b, padding)
|
||||
|
||||
_, err = rw.base.Write(msg)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// pack packs the message to be sent over the tunnel.
|
||||
// Message looks like this:
|
||||
//
|
||||
// <2 bytes>: body length
|
||||
// body
|
||||
// <2 bytes>: padding length
|
||||
// padding
|
||||
func pack(b, padding []byte) (msg []byte) {
|
||||
msg = make([]byte, len(b)+len(padding)+4)
|
||||
|
||||
binary.BigEndian.PutUint16(msg[:2], uint16(len(b)))
|
||||
copy(msg[2:], b)
|
||||
binary.BigEndian.PutUint16(msg[len(b)+2:len(b)+4], uint16(len(padding)))
|
||||
copy(msg[len(b)+4:], padding)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// readPrefixed reads a 2-byte prefixed byte array from the reader.
|
||||
func readPrefixed(r io.Reader) (b []byte, err error) {
|
||||
var length uint16
|
||||
err = binary.Read(r, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if length > MaxMessageLength {
|
||||
// Warn the user that this may not work correctly.
|
||||
log.Error(
|
||||
"Warning: received message of length %d larger than %d, considering reducing the MTU",
|
||||
length,
|
||||
MaxMessageLength,
|
||||
)
|
||||
}
|
||||
|
||||
b = make([]byte, length)
|
||||
_, err = io.ReadFull(r, b)
|
||||
|
||||
return b, err
|
||||
}
|
||||
|
||||
// createRandomPadding creates a random padding array.
|
||||
func createRandomPadding(minLength int, maxLength int) (b []byte) {
|
||||
// Generate a random length for the slice between minLength and maxLength.
|
||||
lengthBuf := make([]byte, 1)
|
||||
_, err := rand.Read(lengthBuf)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to use crypto/rand: %v", err)
|
||||
}
|
||||
length := int(lengthBuf[0])
|
||||
|
||||
// Ensure the length is within our desired range.
|
||||
length = (length % (maxLength - minLength)) + minLength
|
||||
|
||||
// Create a slice of the random length.
|
||||
b = make([]byte, length)
|
||||
|
||||
// Fill the slice with random bytes.
|
||||
_, err = rand.Read(b)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to use crypto/rand: %v", err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
52
tls/tunnel/tunnel.go
Normal file
52
tls/tunnel/tunnel.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
// Package tunnel implements the tunneling logic for copying data between two
|
||||
// network connections both sides.
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Tunnel passes data between two connections.
|
||||
func Tunnel(pipeName string, left io.ReadWriter, right io.ReadWriter) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go pipe(fmt.Sprintf("%s left->right", pipeName), left, right, wg)
|
||||
go pipe(fmt.Sprintf("%s left<-right", pipeName), right, left, wg)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// pipe copies data from reader r to writer w.
|
||||
func pipe(pipeName string, r io.Reader, w io.Writer, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
buf := make([]byte, 65536)
|
||||
var n int
|
||||
var err error
|
||||
|
||||
for {
|
||||
n, err = r.Read(buf)
|
||||
|
||||
if err != nil {
|
||||
log.Debug("failed to read: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("%s: copying %d bytes", pipeName, n)
|
||||
|
||||
_, err = w.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Debug("failed to write: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
281
tls/udp/listener.go
Normal file
281
tls/udp/listener.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
// Package udp implements helper structures for working with UDP.
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Listener is a struct that implements net.Listener interface for working
|
||||
// with UDP. This is achieved by maintaining an internal "nat-like" table
|
||||
// with destinations.
|
||||
type Listener struct {
|
||||
conn *net.UDPConn
|
||||
|
||||
// natTable is a table which maps peer addresses to udpConn structs.
|
||||
// Whenever a new packet is received, Listener looks up if there's
|
||||
// already a udpConn for the peer address and either creates a new one
|
||||
// or adds the packet to the existing one.
|
||||
natTable map[string]*udpConn
|
||||
natTableMu sync.Mutex
|
||||
|
||||
chanAccept chan *udpConn
|
||||
chanClosed chan struct{}
|
||||
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
}
|
||||
|
||||
// Listen creates a new *Listener and is supposed to be a function similar
|
||||
// to net.Listen, but for UDP only.
|
||||
func Listen(network, addr string) (l *Listener, err error) {
|
||||
listenAddr, err := net.ResolveUDPAddr(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l = &Listener{
|
||||
natTable: map[string]*udpConn{},
|
||||
chanAccept: make(chan *udpConn, 256),
|
||||
chanClosed: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
l.conn, err = net.ListenUDP(network, listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go l.readLoop()
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// type check.
|
||||
var _ net.Listener = (*Listener)(nil)
|
||||
|
||||
// Accept implements the net.Listener interface for *Listener.
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
if l.isClosed() {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case conn = <-l.chanAccept:
|
||||
return conn, nil
|
||||
case <-l.chanClosed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements the net.Listener interface for *Listener.
|
||||
func (l *Listener) Close() (err error) {
|
||||
if l.isClosed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.closedMu.Lock()
|
||||
l.closed = true
|
||||
l.closedMu.Unlock()
|
||||
|
||||
close(l.chanClosed)
|
||||
|
||||
l.natTableMu.Lock()
|
||||
for _, c := range l.natTable {
|
||||
log.OnCloserError(c, log.DEBUG)
|
||||
}
|
||||
l.natTableMu.Unlock()
|
||||
|
||||
return l.conn.Close()
|
||||
}
|
||||
|
||||
// Addr implements the net.Listener interface for *Listener.
|
||||
func (l *Listener) Addr() (addr net.Addr) {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// isClosed returns true if the listener is already closed.
|
||||
func (l *Listener) isClosed() (ok bool) {
|
||||
l.closedMu.Lock()
|
||||
defer l.closedMu.Unlock()
|
||||
|
||||
return l.closed
|
||||
}
|
||||
|
||||
// readLoop implements the listener logic, it reads incoming data and passes it
|
||||
// to the corresponding udpConn. When a new udpConn is created, it is written
|
||||
// to the chanAccept channel.
|
||||
func (l *Listener) readLoop() {
|
||||
buf := make([]byte, 65536)
|
||||
|
||||
for !l.isClosed() {
|
||||
n, addr, err := l.conn.ReadFromUDP(buf)
|
||||
|
||||
if err != nil || n == 0 {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(leninalive): Handle errors better here.
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
msg := make([]byte, n)
|
||||
copy(msg, buf[:n])
|
||||
l.acceptMsg(addr, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// acceptMsg passes the message to the corresponding udpConn.
|
||||
func (l *Listener) acceptMsg(addr *net.UDPAddr, msg []byte) {
|
||||
l.natTableMu.Lock()
|
||||
defer l.natTableMu.Unlock()
|
||||
|
||||
key := addr.String()
|
||||
conn, _ := l.natTable[key]
|
||||
if conn == nil || conn.isClosed() {
|
||||
conn = newUDPConn(addr, l.conn)
|
||||
l.natTable[key] = conn
|
||||
|
||||
l.chanAccept <- conn
|
||||
}
|
||||
|
||||
conn.addMsg(msg)
|
||||
}
|
||||
|
||||
// udpConn represents a connection with a single peer.
|
||||
type udpConn struct {
|
||||
peerAddr *net.UDPAddr
|
||||
conn *net.UDPConn
|
||||
|
||||
remaining []byte
|
||||
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
|
||||
chanMsg chan []byte
|
||||
chanClosed chan struct{}
|
||||
}
|
||||
|
||||
// newUDPConn creates a new *udpConn for the specified peer.
|
||||
func newUDPConn(peerAddr *net.UDPAddr, baseConn *net.UDPConn) (conn *udpConn) {
|
||||
return &udpConn{
|
||||
peerAddr: peerAddr,
|
||||
conn: baseConn,
|
||||
chanMsg: make(chan []byte, 256),
|
||||
chanClosed: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// addMsg adds a new byte array that can be then read from this connection.
|
||||
func (c *udpConn) addMsg(b []byte) {
|
||||
c.chanMsg <- b
|
||||
}
|
||||
|
||||
// isClosed returns true if the connection is closed.
|
||||
func (c *udpConn) isClosed() (ok bool) {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ net.Conn = (*udpConn)(nil)
|
||||
|
||||
// Read implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) Read(b []byte) (n int, err error) {
|
||||
n = c.readRemaining(b)
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case buf := <-c.chanMsg:
|
||||
c.remaining = buf
|
||||
n = c.readRemaining(b)
|
||||
|
||||
return n, nil
|
||||
case <-c.chanClosed:
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
// readRemaining reads remaining bytes that were not yet read.
|
||||
func (c *udpConn) readRemaining(b []byte) (n int) {
|
||||
if c.remaining == nil || len(c.remaining) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
if len(c.remaining) >= len(b) {
|
||||
n = len(b)
|
||||
|
||||
copy(b, c.remaining[:n])
|
||||
c.remaining = c.remaining[n:]
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
n = len(c.remaining)
|
||||
|
||||
copy(b[:n], c.remaining)
|
||||
c.remaining = nil
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Write implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) Write(b []byte) (n int, err error) {
|
||||
return c.conn.WriteToUDP(b, c.peerAddr)
|
||||
}
|
||||
|
||||
// Close implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) Close() (err error) {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
close(c.chanClosed)
|
||||
|
||||
// Do not close the underlying UDP connection as it's shared with other
|
||||
// udpConn objects.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) LocalAddr() (addr net.Addr) {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) RemoteAddr() (addr net.Addr) {
|
||||
return c.peerAddr
|
||||
}
|
||||
|
||||
// SetDeadline implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) SetDeadline(_ time.Time) (err error) {
|
||||
// TODO(leninalive): Implement it.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) SetReadDeadline(_ time.Time) (err error) {
|
||||
// TODO(leninalive): Implement it.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the net.Conn interface for *udpConn.
|
||||
func (c *udpConn) SetWriteDeadline(_ time.Time) (err error) {
|
||||
// TODO(leninalive): Implement it.
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Add table
Reference in a new issue