tls refactoring

Signed-off-by: Iurii Egorov <ye@amnezia.org>
This commit is contained in:
Iurii Egorov 2024-05-01 18:32:40 +03:00
parent 2d725d0ca7
commit 93e92cfadc
No known key found for this signature in database
GPG key ID: B08A11B4E8F59276
9 changed files with 1353 additions and 11 deletions

View file

@ -17,8 +17,8 @@ import (
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tls/pipe"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/leninalive/udptlspipe/pipe"
"github.com/tevino/abool/v2"
)

6
go.mod
View file

@ -5,7 +5,8 @@ go 1.21.6
toolchain go1.21.8
require (
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f
github.com/gobwas/ws v1.3.2
github.com/refraction-networking/utls v1.6.2
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.19.0
golang.org/x/net v0.21.0
@ -15,15 +16,12 @@ require (
)
require (
github.com/AdguardTeam/golibs v0.20.0 // indirect
github.com/andybalholm/brotli v1.0.6 // indirect
github.com/cloudflare/circl v1.3.7 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gobwas/ws v1.3.2 // indirect
github.com/google/btree v1.0.1 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/quic-go/quic-go v0.40.1 // indirect
github.com/refraction-networking/utls v1.6.2 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)

6
go.sum
View file

@ -1,7 +1,3 @@
github.com/AdguardTeam/golibs v0.20.0 h1:A9FIdYq7wUKhFYy3z+YZ/Aw5oFUYgW+xgaVAJ0pnnPY=
github.com/AdguardTeam/golibs v0.20.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w=
github.com/ameshkov/udptlspipe v1.3.1 h1:e+eC2Yb+04KPzH9b/Uktwn6W6lw5CgbFdHnGfAaofx8=
github.com/ameshkov/udptlspipe v1.3.1/go.mod h1:UnpDx2J//7WS/RRe5hb2UVZpwJzHga95ArLkPS9aRBk=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU=
@ -26,8 +22,6 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f h1:VR2M22cXDtgp78N1mkCmxiXj1zYIP9ScUXS8gMHi6Vs=
github.com/leninalive/udptlspipe v0.0.0-20240313123600-80348db0072f/go.mod h1:U3O6PfEGIxmmxAkOucn8Ty1akGF/1N1lDPeHPLCz3Cg=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=

710
tls/pipe/server.go Normal file
View file

@ -0,0 +1,710 @@
// Package pipe implements the pipe logic, i.e. listening for TLS or UDP
// connections and proxying data to the target destination.
package pipe
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/gobwas/ws"
"github.com/amnezia-vpn/amneziawg-go/tls/tunnel"
"github.com/amnezia-vpn/amneziawg-go/tls/udp"
tls "github.com/refraction-networking/utls"
"golang.org/x/net/proxy"
)
// defaultSNI is the default server name that will be used in both the client
// TLS ClientHello and the server's certificate when no TLS configuration is
// configured.
const defaultSNI = "example.org"
// upgradeTimeout is the read timeout for the first auth packet.
const upgradeTimeout = time.Second * 60
// Server represents an udptlspipe pipe. Depending on whether it is created in
// server- or client- mode, it listens to TLS or UDP connections and pipes the
// data to the destination.
type Server struct {
listenAddr string
destinationAddr string
dialer proxy.Dialer
serverMode bool
probeReverseProxyURL string
probeReverseProxyListen net.Listener
// tlsConfig to use for TLS connections. In server mode it also has the
// certificate that will be used.
tlsConfig *tls.Config
// password is a string that the server will search for in the first bytes.
// If not found, the server will return a stub web page.
password string
// listen is the TLS listener for incoming connections
listen net.Listener
// srcConns is a set that is used to track active incoming TCP connections.
srcConns map[net.Conn]struct{}
srcConnsMu *sync.Mutex
// dstConns is a set that is used to track active connections to the proxy
// destination.
dstConns map[net.Conn]struct{}
dstConnsMu *sync.Mutex
// Shutdown handling
// --
// lock protects started, tcpListener and udpListener.
lock sync.RWMutex
started bool
// wg tracks active workers. Stop won't finish until there is at least
// won't finish until there's at least one active worker.
wg sync.WaitGroup
}
// Config represents the server configuration.
type Config struct {
// ListenAddr is the address (ip:port) where the server will be listening
// to. Depending on the mode the server uses, it will either listen for TLS
// or UDP connections.
ListenAddr string
// DestinationAddr is the address (host:port) to where the server will try
// to connect. Depending on the mode the server uses, it will either
// connect to a TLS endpoint (the pipe server) or not.
DestinationAddr string
// Password enables authentication of the pipe clients. If set, it also
// enables active probing protection.
Password string
// ServerMode controls the way the pipe operates. When it's true, the pipe
// server operates in server mode, i.e. it accepts incoming TLS connections
// and proxies the data to the destination address over UDP. When it works
// in client mode, it is the other way around: accepts UDP traffic and
// proxies it to the destination pipe server over TLS.
ServerMode bool
// URL of a proxy server that can be used for proxying traffic to the
// destination.
ProxyURL string
// VerifyCertificate enables server certificate verification in client mode.
// If enabled, the client will verify the server certificate using the
// system root certs store.
VerifyCertificate bool
// TLSServerName configures the server name to send in TLS ClientHello when
// operating in client mode and the server name that will be used when
// generating a stub certificate. If not set, the default domain name will
// be used for these purposes.
TLSServerName string
// TLSCertificate is an optional field that allows to configure the TLS
// certificate to use when running in server mode. This option makes sense
// only for server mode. If not configured, the server will generate a stub
// self-signed certificate automatically.
TLSCertificate *tls.Certificate
// ProbeReverseProxyURL is the URL that will be used by the reverse HTTP
// proxy to respond to unauthorized or proxy requests. If not specified,
// it will respond with a stub page 403 Forbidden.
ProbeReverseProxyURL string
}
// createTLSConfig creates a TLS configuration as per the server configuration.
func createTLSConfig(config *Config) (tlsConfig *tls.Config, err error) {
serverName := config.TLSServerName
if serverName == "" {
log.Info("TLS server name is not configured, using %s by default", defaultSNI)
serverName = defaultSNI
}
if config.ServerMode {
tlsCert := config.TLSCertificate
if tlsCert == nil {
log.Info("Generating a stub certificate for %s", serverName)
tlsCert, err = createStubCertificate(serverName)
if err != nil {
return nil, fmt.Errorf("failed to generate a stub certificate: %w", err)
}
} else {
log.Info("Using the supplied TLS certificate")
}
tlsConfig = &tls.Config{
ServerName: serverName,
Certificates: []tls.Certificate{*tlsCert},
MinVersion: tls.VersionTLS12,
}
} else {
tlsConfig = &tls.Config{
InsecureSkipVerify: !config.VerifyCertificate,
ServerName: serverName,
}
}
return tlsConfig, nil
}
// NewServer creates a new instance of a *Server.
func NewServer(config *Config) (s *Server, err error) {
s = &Server{
listenAddr: config.ListenAddr,
destinationAddr: config.DestinationAddr,
password: config.Password,
probeReverseProxyURL: config.ProbeReverseProxyURL,
dialer: proxy.Direct,
serverMode: config.ServerMode,
srcConns: map[net.Conn]struct{}{},
srcConnsMu: &sync.Mutex{},
dstConns: map[net.Conn]struct{}{},
dstConnsMu: &sync.Mutex{},
}
s.tlsConfig, err = createTLSConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to prepare TLS configuration: %w", err)
}
if config.ProxyURL != "" {
var u *url.URL
u, err = url.Parse(config.ProxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
s.dialer, err = proxy.FromURL(u, s.dialer)
if err != nil {
return nil, fmt.Errorf("failed to initialize proxy dialer: %w", err)
}
}
return s, nil
}
// Addr returns the address the pipe listens to if it is started or nil.
func (s *Server) Addr() (addr net.Addr) {
if s.listen == nil {
return nil
}
return s.listen.Addr()
}
// Start starts the pipe, exits immediately if it failed to start
// listening. Start returns once all servers are considered up.
func (s *Server) Start() (err error) {
log.Info("Starting the server %s", s)
s.lock.Lock()
defer s.lock.Unlock()
if s.started {
return errors.New("Server is already started")
}
s.listen, err = s.createListener()
if err != nil {
return fmt.Errorf("failed to start pipe: %w", err)
}
if s.probeReverseProxyURL != "" {
err = s.startProbeReverseProxy()
if err != nil {
return fmt.Errorf("failed to start probe reverse proxy: %w", err)
}
}
s.wg.Add(1)
go s.serve()
s.started = true
log.Info("Server has been started")
return nil
}
// createListener creates a TLS listener in server mode and UDP listener in
// client mode.
func (s *Server) createListener() (l net.Listener, err error) {
if s.serverMode {
l, err = tls.Listen("tcp", s.listenAddr, s.tlsConfig)
if err != nil {
return nil, err
}
} else {
l, err = udp.Listen("udp", s.listenAddr)
if err != nil {
return nil, err
}
}
return l, nil
}
// startProbeReverseProxy starts a reverse HTTP proxy that will be used for
// answering unauthorized and probe requests. Returns the listener of that
// proxy. Original request URI will be appended to proxyURL.
func (s *Server) startProbeReverseProxy() (err error) {
proxyURL := s.probeReverseProxyURL
if _, err = url.Parse(proxyURL); err != nil {
return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err)
}
targetURL, err := url.Parse(s.probeReverseProxyURL)
if err != nil {
return fmt.Errorf("reverse proxy URL must be a valid URL: %w", err)
}
handler := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(targetURL)
r.Out.Host = targetURL.Host
},
}
srv := &http.Server{
ReadHeaderTimeout: upgradeTimeout,
Handler: handler,
}
s.probeReverseProxyListen, err = net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return fmt.Errorf("failed to start probe reverse proxy: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
log.Info("Starting probe reverse proxy")
sErr := srv.Serve(s.probeReverseProxyListen)
log.Info("Probe reverse proxy has been stopped due to: %v", sErr)
}()
return nil
}
// Shutdown stops the pipe and waits for all active connections to close.
func (s *Server) Shutdown(ctx context.Context) (err error) {
log.Info("Stopping the server %s", s)
s.stopServeLoop()
// Closing the udpConn thread.
log.OnCloserError(s.listen, log.DEBUG)
if s.probeReverseProxyListen != nil {
log.OnCloserError(s.probeReverseProxyListen, log.DEBUG)
}
// Closing active TCP connections.
s.closeConnections(s.srcConnsMu, s.srcConns)
// Closing active UDP connections.
s.closeConnections(s.dstConnsMu, s.dstConns)
// Wait until all worker threads finish working
err = s.waitShutdown(ctx)
log.Info("Server has been stopped")
return err
}
// closeConnections closes all active connections.
func (s *Server) closeConnections(mu *sync.Mutex, conns map[net.Conn]struct{}) {
mu.Lock()
defer mu.Unlock()
for c := range conns {
_ = c.SetReadDeadline(time.Unix(1, 0))
log.OnCloserError(c, log.DEBUG)
}
}
// stopServeLoop sets the started flag to false thus stopping the serving loop.
func (s *Server) stopServeLoop() {
s.lock.Lock()
defer s.lock.Unlock()
s.started = false
}
// type check
var _ fmt.Stringer = (*Server)(nil)
// String implements the fmt.Stringer interface for *Server.
func (s *Server) String() (str string) {
switch s.serverMode {
case true:
return fmt.Sprintf("tls://%s <-> udp://%s", s.listenAddr, s.destinationAddr)
default:
return fmt.Sprintf("udp://%s <-> tls://%s", s.listenAddr, s.destinationAddr)
}
}
// serve implements the pipe logic, i.e. accepts new connections and tunnels
// data to the destination.
func (s *Server) serve() {
defer s.wg.Done()
defer log.OnPanicAndExit("serve", 1)
defer log.OnCloserError(s.listen, log.DEBUG)
for s.isStarted() {
err := s.acceptConn()
if err != nil {
if !s.isStarted() {
return
}
log.Error("exit serve loop due to: %v", err)
return
}
}
}
// acceptConn accepts new incoming and tracks active connections.
func (s *Server) acceptConn() (err error) {
conn, err := s.listen.Accept()
if err != nil {
// This type of errors should not lead to stopping the server.
if errors.Is(os.ErrDeadlineExceeded, err) {
return nil
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return nil
}
return err
}
log.Debug("Accepted new connection from %s", conn.RemoteAddr())
s.saveSrcConn(conn)
s.wg.Add(1)
go s.serveConn(conn)
return nil
}
// saveSrcConn tracks the connection to allow unblocking reads on shutdown.
func (s *Server) saveSrcConn(conn net.Conn) {
s.srcConnsMu.Lock()
defer s.srcConnsMu.Unlock()
// Track the connection to allow unblocking reads on shutdown.
s.srcConns[conn] = struct{}{}
}
// closeSrcConn closes the source connection and cleans up after it.
func (s *Server) closeSrcConn(conn net.Conn) {
log.OnCloserError(conn, log.DEBUG)
s.srcConnsMu.Lock()
defer s.srcConnsMu.Unlock()
delete(s.srcConns, conn)
}
// saveDstConn tracks the connection to allow unblocking reads on shutdown.
func (s *Server) saveDstConn(conn net.Conn) {
s.dstConnsMu.Lock()
defer s.dstConnsMu.Unlock()
// Track the connection to allow unblocking reads on shutdown.
s.dstConns[conn] = struct{}{}
}
// closeDstConn closes the destination connection and cleans up after it.
func (s *Server) closeDstConn(conn net.Conn) {
// No destination connection opened yet, do nothing.
if conn != nil {
return
}
log.OnCloserError(conn, log.DEBUG)
s.dstConnsMu.Lock()
defer s.dstConnsMu.Unlock()
delete(s.dstConns, conn)
}
// readWriteCloser is a helper object that's used for replacing
// io.ReadWriteCloser when the server peeked into the connection.
type readWriteCloser struct {
io.Reader
io.Writer
io.Closer
}
// upgradeClientConn
func (s *Server) upgradeClientConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) {
log.Debug("Upgrading connection to %s", conn.RemoteAddr())
// Give up to 60 seconds on the upgrade and authentication.
_ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout))
defer func() {
// Remove the deadline when it's not required any more.
_ = conn.SetReadDeadline(time.Time{})
}()
u, err := url.Parse(fmt.Sprintf("wss://%s/?password=%s", s.tlsConfig.ServerName, s.password))
if err != nil {
return nil, err
}
var br *bufio.Reader
br, _, err = ws.DefaultDialer.Upgrade(conn, u)
if err != nil {
return nil, fmt.Errorf("failed to upgrade: %w", err)
}
if br != nil && br.Buffered() > 0 {
// If Upgrade returned a non-empty reader, then probably the server
// immediately sent some data. This is not the expected behavior so
// raise an error here.
return nil, fmt.Errorf("received initial data len=%d from the server", br.Buffered())
}
return newWsConn(
&readWriteCloser{
Reader: conn,
Writer: conn,
Closer: conn,
},
conn.RemoteAddr(),
ws.StateClientSide,
), nil
}
// respondToProbe writes a dummy response to the client if it's not authorized
// or if it's a probe.
func (s *Server) respondToProbe(rwc io.ReadWriteCloser, req *http.Request) {
if s.probeReverseProxyListen == nil {
log.Debug("No probe reverse proxy configured, respond with a dummy 403 page")
response := fmt.Sprintf("%s 403 Forbidden\r\n", req.Proto) +
"Server: nginx\r\n" +
fmt.Sprintf("Date: %s\r\n", time.Now().Format(http.TimeFormat)) +
"Content-Type: text/html\r\n" +
"Connection: close\r\n" +
"\r\n" +
"<html>\r\n" +
"<head><title>403 Forbidden</title></head>\r\n" +
"<center><h1>403 Forbidden</h1></center>\r\n" +
"<hr><center>nginx</center>\r\n" +
"</body>\r\n" +
"</html>\r\n"
_, _ = rwc.Write([]byte(response))
return
}
log.Debug("Probe reverse proxy is configured, tunnel data to it")
proxyConn, err := net.Dial("tcp", s.probeReverseProxyListen.Addr().String())
if err != nil {
log.Error("Failed to connect to the probe reverse proxy: %v", err)
return
}
s.saveDstConn(proxyConn)
tunnel.Tunnel("probeReverseProxy", rwc, proxyConn)
}
// upgradeServerConn attempts to upgrade the server connection and returns a
// rwc that wraps the original connection and can be used for tunneling data.
func (s *Server) upgradeServerConn(conn net.Conn) (rwc io.ReadWriteCloser, err error) {
log.Debug("Upgrading connection from %s", conn.RemoteAddr())
// Give up to 60 seconds on the upgrade and authentication.
_ = conn.SetReadDeadline(time.Now().Add(upgradeTimeout))
defer func() {
// Remove the deadline when it's not required any more.
_ = conn.SetReadDeadline(time.Time{})
}()
// bufio.Reader may read more than requested, so it's crucial to use
// TeeReader so that we could restore the bytes that has been read.
var buf bytes.Buffer
r := bufio.NewReader(io.TeeReader(conn, &buf))
req, err := http.ReadRequest(r)
if err != nil {
return nil, fmt.Errorf("cannot read HTTP request: %w", err)
}
// Now that authentication check has been done restore the peeked up data
// so that it could be used further.
originalRwc := &readWriteCloser{
Reader: io.MultiReader(bytes.NewReader(buf.Bytes()), conn),
Writer: conn,
Closer: conn,
}
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
s.respondToProbe(originalRwc, req)
return nil, fmt.Errorf("not a websocket")
}
clientPassword := req.URL.Query().Get("password")
if s.password != "" && clientPassword != s.password {
s.respondToProbe(originalRwc, req)
return nil, fmt.Errorf("wrong password: %s", clientPassword)
}
_, err = ws.Upgrade(originalRwc)
if err != nil {
return nil, fmt.Errorf("failed to upgrade WebSocket: %w", err)
}
return newWsConn(originalRwc, conn.RemoteAddr(), ws.StateServerSide), nil
}
// serveConn processes incoming connection, authenticates it and proxies the
// data from it to the destination address.
func (s *Server) serveConn(conn net.Conn) {
defer func() {
s.wg.Done()
s.closeSrcConn(conn)
}()
var rwc io.ReadWriteCloser = conn
if s.serverMode {
var err error
rwc, err = s.upgradeServerConn(conn)
if err != nil {
log.Error("failed to accept server conn: %v", err)
return
}
}
s.processConn(rwc)
}
// processConn processes the prepared server connection that is passed as rwc.
func (s *Server) processConn(rwc io.ReadWriteCloser) {
var dstConn net.Conn
defer s.closeDstConn(dstConn)
dstConn, err := s.dialDst()
if err != nil {
log.Error("failed to connect to %s: %v", s.destinationAddr, err)
return
}
s.saveDstConn(dstConn)
var dstRwc io.ReadWriteCloser = dstConn
if !s.serverMode {
dstRwc, err = s.upgradeClientConn(dstConn)
if err != nil {
log.Error("failed to upgrade: %v", err)
return
}
}
// Prepare ReadWriter objects for tunneling.
var srcRw, dstRw io.ReadWriter
srcRw = rwc
dstRw = dstRwc
// When the client communicates with the server it uses encoded messages so
// connection between them needs to be wrapped. In server mode it is the
// source connection, in client mode it is the destination connection.
if s.serverMode {
srcRw = tunnel.NewMsgReadWriter(srcRw)
} else {
dstRw = tunnel.NewMsgReadWriter(dstRw)
}
tunnel.Tunnel(s.String(), srcRw, dstRw)
}
// dialDst creates a connection to the destination. Depending on the mode the
// server operates in, it is either a TLS connection or a UDP connection.
func (s *Server) dialDst() (conn net.Conn, err error) {
if s.serverMode {
return s.dialer.Dial("udp", s.destinationAddr)
}
tcpConn, err := s.dialer.Dial("tcp", s.destinationAddr)
if err != nil {
return nil, fmt.Errorf("failed to open connection to %s: %w", s.destinationAddr, err)
}
tlsConn := tls.UClient(tcpConn, s.tlsConfig, tls.HelloAndroid_11_OkHttp)
err = tlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("cannot establish connection to %s: %w", s.destinationAddr, err)
}
return tlsConn, nil
}
// isStarted safely checks whether the pipe is started or not.
func (s *Server) isStarted() (started bool) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.started
}
// waitShutdown waits either until context deadline OR Server.wg.
func (s *Server) waitShutdown(ctx context.Context) (err error) {
// Using this channel to wait until all goroutines finish their work.
closed := make(chan struct{})
go func() {
defer log.OnPanic("waitShutdown")
// Wait until all active workers finished its work.
s.wg.Wait()
close(closed)
}()
var ctxErr error
select {
case <-closed:
// Do nothing here.
case <-ctx.Done():
ctxErr = ctx.Err()
}
return ctxErr
}

73
tls/pipe/tlsconfig.go Normal file
View file

@ -0,0 +1,73 @@
package pipe
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
tls "github.com/refraction-networking/utls"
)
// createStubCertificate creates a stub TLS certificate for the pipe server.
// This stub cert is generated when the user does not specify any certificate.
func createStubCertificate(tlsServerName string) (cert *tls.Certificate, err error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, err
}
notBefore := time.Now()
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{tlsServerName},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
template.DNSNames = append(template.DNSNames, tlsServerName)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
if err != nil {
return nil, err
}
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
tlsCert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
return nil, err
}
return &tlsCert, nil
}
func publicKey(priv any) (pub any) {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}

85
tls/pipe/websocket.go Normal file
View file

@ -0,0 +1,85 @@
package pipe
import (
"io"
"net"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
// wsConn represents a WebSocket connection that's been already initialized.
type wsConn struct {
rwc io.ReadWriteCloser
remoteAddr net.Addr
r *wsutil.Reader
w *wsutil.Writer
}
// newWsConn creates a wrapper over the existing network connection that is
// able to send/read messages using WebSocket protocol.
func newWsConn(rwc io.ReadWriteCloser, remoteAddr net.Addr, state ws.State) (c *wsConn) {
r := wsutil.NewReader(rwc, state)
w := wsutil.NewWriter(rwc, state, ws.OpBinary)
return &wsConn{
rwc: rwc,
remoteAddr: remoteAddr,
r: r,
w: w,
}
}
// type check
var _ io.ReadWriteCloser = (*wsConn)(nil)
// Read implements the io.ReadWriteCloser interface for *wsConn.
func (w *wsConn) Read(b []byte) (n int, err error) {
n, err = w.r.Read(b)
if err == wsutil.ErrNoFrameAdvance {
log.Debug("Reading the next WebSocket frame from %v", w.remoteAddr)
hdr, fErr := w.r.NextFrame()
if fErr != nil {
return 0, io.EOF
}
log.Debug(
"Received WebSocket frame with opcode=%d len=%d fin=%v from %v",
hdr.OpCode,
hdr.Length,
hdr.Fin,
w.remoteAddr,
)
// Reading again after the frame has been read.
n, err = w.r.Read(b)
// EOF in the case of wsutil.Reader does not mean that the connection is
// closed, it only means that the current frame is finished.
if err == io.EOF {
err = nil
}
}
return n, err
}
// Write implements the io.ReadWriteCloser interface for *wsConn.
func (w *wsConn) Write(b []byte) (n int, err error) {
log.Debug("Writing data len=%d to the WebSocket %v", len(b), w.remoteAddr)
n, err = w.w.Write(b)
if err != nil {
return 0, err
}
err = w.w.Flush()
return n, err
}
// Close implements the io.ReadWriteCloser interface for *wsConn.
func (w *wsConn) Close() (err error) {
return w.rwc.Close()
}

149
tls/tunnel/msgreadwriter.go Normal file
View file

@ -0,0 +1,149 @@
package tunnel
import (
"crypto/rand"
"encoding/binary"
"fmt"
"io"
)
// MaxMessageLength is the maximum length that is safe to use.
// TODO(leninalive): Make it configurable.
const MaxMessageLength = 1320
// MinMessageLength is the minimum message size. If the message is smaller, it
// will be padded with random bytes.
const MinMessageLength = 100
// MaxPaddingLength is the maximum size of a random padding that's added to
// every message.
const MaxPaddingLength = 256
// MsgReadWriter is a wrapper over io.ReadWriter that encodes messages written
// to and read from the base writer.
type MsgReadWriter struct {
base io.ReadWriter
}
// NewMsgReadWriter creates a new instance of *MsgReadWriter.
func NewMsgReadWriter(base io.ReadWriter) (rw *MsgReadWriter) {
return &MsgReadWriter{base: base}
}
// type check
var _ io.ReadWriter = (*MsgReadWriter)(nil)
// Read implements the io.ReadWriter interface for *MsgReadWriter.
func (rw *MsgReadWriter) Read(b []byte) (n int, err error) {
// Read the main message (always goes first).
msg, err := readPrefixed(rw.base)
if err != nil {
return 0, err
}
// Skip padding.
_, err = readPrefixed(rw.base)
if err != nil {
return 0, err
}
if len(b) < len(msg) {
return 0, fmt.Errorf("message length %d is greater than the buffer size %d", len(msg), len(b))
}
copy(b[:len(msg)], msg)
return len(msg), nil
}
// Write implements the io.ReadWriter interface for *MsgReadWriter.
func (rw *MsgReadWriter) Write(b []byte) (n int, err error) {
// Create random padding to make it harder to understand what's inside
// the tunnel.
minLength := MinMessageLength - len(b)
if minLength <= 0 {
minLength = 1
}
maxLength := MaxPaddingLength
if maxLength <= minLength {
maxLength = minLength + 1
}
padding := createRandomPadding(minLength, maxLength)
// Pack the message before sending it.
msg := pack(b, padding)
_, err = rw.base.Write(msg)
if err != nil {
return 0, err
}
return len(b), nil
}
// pack packs the message to be sent over the tunnel.
// Message looks like this:
//
// <2 bytes>: body length
// body
// <2 bytes>: padding length
// padding
func pack(b, padding []byte) (msg []byte) {
msg = make([]byte, len(b)+len(padding)+4)
binary.BigEndian.PutUint16(msg[:2], uint16(len(b)))
copy(msg[2:], b)
binary.BigEndian.PutUint16(msg[len(b)+2:len(b)+4], uint16(len(padding)))
copy(msg[len(b)+4:], padding)
return msg
}
// readPrefixed reads a 2-byte prefixed byte array from the reader.
func readPrefixed(r io.Reader) (b []byte, err error) {
var length uint16
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length > MaxMessageLength {
// Warn the user that this may not work correctly.
log.Error(
"Warning: received message of length %d larger than %d, considering reducing the MTU",
length,
MaxMessageLength,
)
}
b = make([]byte, length)
_, err = io.ReadFull(r, b)
return b, err
}
// createRandomPadding creates a random padding array.
func createRandomPadding(minLength int, maxLength int) (b []byte) {
// Generate a random length for the slice between minLength and maxLength.
lengthBuf := make([]byte, 1)
_, err := rand.Read(lengthBuf)
if err != nil {
log.Fatalf("Failed to use crypto/rand: %v", err)
}
length := int(lengthBuf[0])
// Ensure the length is within our desired range.
length = (length % (maxLength - minLength)) + minLength
// Create a slice of the random length.
b = make([]byte, length)
// Fill the slice with random bytes.
_, err = rand.Read(b)
if err != nil {
log.Fatalf("Failed to use crypto/rand: %v", err)
}
return b
}

52
tls/tunnel/tunnel.go Normal file
View file

@ -0,0 +1,52 @@
// Package tunnel implements the tunneling logic for copying data between two
// network connections both sides.
package tunnel
import (
"fmt"
"io"
"sync"
)
// Tunnel passes data between two connections.
func Tunnel(pipeName string, left io.ReadWriter, right io.ReadWriter) {
wg := &sync.WaitGroup{}
wg.Add(2)
go pipe(fmt.Sprintf("%s left->right", pipeName), left, right, wg)
go pipe(fmt.Sprintf("%s left<-right", pipeName), right, left, wg)
wg.Wait()
}
// pipe copies data from reader r to writer w.
func pipe(pipeName string, r io.Reader, w io.Writer, wg *sync.WaitGroup) {
defer wg.Done()
buf := make([]byte, 65536)
var n int
var err error
for {
n, err = r.Read(buf)
if err != nil {
log.Debug("failed to read: %v", err)
return
}
if n == 0 {
continue
}
log.Debug("%s: copying %d bytes", pipeName, n)
_, err = w.Write(buf[:n])
if err != nil {
log.Debug("failed to write: %v", err)
return
}
}
}

281
tls/udp/listener.go Normal file
View file

@ -0,0 +1,281 @@
// Package udp implements helper structures for working with UDP.
package udp
import (
"errors"
"net"
"sync"
"time"
)
// Listener is a struct that implements net.Listener interface for working
// with UDP. This is achieved by maintaining an internal "nat-like" table
// with destinations.
type Listener struct {
conn *net.UDPConn
// natTable is a table which maps peer addresses to udpConn structs.
// Whenever a new packet is received, Listener looks up if there's
// already a udpConn for the peer address and either creates a new one
// or adds the packet to the existing one.
natTable map[string]*udpConn
natTableMu sync.Mutex
chanAccept chan *udpConn
chanClosed chan struct{}
closed bool
closedMu sync.Mutex
}
// Listen creates a new *Listener and is supposed to be a function similar
// to net.Listen, but for UDP only.
func Listen(network, addr string) (l *Listener, err error) {
listenAddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return nil, err
}
l = &Listener{
natTable: map[string]*udpConn{},
chanAccept: make(chan *udpConn, 256),
chanClosed: make(chan struct{}, 1),
}
l.conn, err = net.ListenUDP(network, listenAddr)
if err != nil {
return nil, err
}
go l.readLoop()
return l, nil
}
// type check.
var _ net.Listener = (*Listener)(nil)
// Accept implements the net.Listener interface for *Listener.
func (l *Listener) Accept() (conn net.Conn, err error) {
if l.isClosed() {
return nil, net.ErrClosed
}
select {
case conn = <-l.chanAccept:
return conn, nil
case <-l.chanClosed:
return nil, net.ErrClosed
}
}
// Close implements the net.Listener interface for *Listener.
func (l *Listener) Close() (err error) {
if l.isClosed() {
return nil
}
l.closedMu.Lock()
l.closed = true
l.closedMu.Unlock()
close(l.chanClosed)
l.natTableMu.Lock()
for _, c := range l.natTable {
log.OnCloserError(c, log.DEBUG)
}
l.natTableMu.Unlock()
return l.conn.Close()
}
// Addr implements the net.Listener interface for *Listener.
func (l *Listener) Addr() (addr net.Addr) {
return l.conn.LocalAddr()
}
// isClosed returns true if the listener is already closed.
func (l *Listener) isClosed() (ok bool) {
l.closedMu.Lock()
defer l.closedMu.Unlock()
return l.closed
}
// readLoop implements the listener logic, it reads incoming data and passes it
// to the corresponding udpConn. When a new udpConn is created, it is written
// to the chanAccept channel.
func (l *Listener) readLoop() {
buf := make([]byte, 65536)
for !l.isClosed() {
n, addr, err := l.conn.ReadFromUDP(buf)
if err != nil || n == 0 {
if errors.Is(err, net.ErrClosed) {
return
}
// TODO(leninalive): Handle errors better here.
continue
}
msg := make([]byte, n)
copy(msg, buf[:n])
l.acceptMsg(addr, msg)
}
}
// acceptMsg passes the message to the corresponding udpConn.
func (l *Listener) acceptMsg(addr *net.UDPAddr, msg []byte) {
l.natTableMu.Lock()
defer l.natTableMu.Unlock()
key := addr.String()
conn, _ := l.natTable[key]
if conn == nil || conn.isClosed() {
conn = newUDPConn(addr, l.conn)
l.natTable[key] = conn
l.chanAccept <- conn
}
conn.addMsg(msg)
}
// udpConn represents a connection with a single peer.
type udpConn struct {
peerAddr *net.UDPAddr
conn *net.UDPConn
remaining []byte
closed bool
closedMu sync.Mutex
chanMsg chan []byte
chanClosed chan struct{}
}
// newUDPConn creates a new *udpConn for the specified peer.
func newUDPConn(peerAddr *net.UDPAddr, baseConn *net.UDPConn) (conn *udpConn) {
return &udpConn{
peerAddr: peerAddr,
conn: baseConn,
chanMsg: make(chan []byte, 256),
chanClosed: make(chan struct{}, 1),
}
}
// addMsg adds a new byte array that can be then read from this connection.
func (c *udpConn) addMsg(b []byte) {
c.chanMsg <- b
}
// isClosed returns true if the connection is closed.
func (c *udpConn) isClosed() (ok bool) {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
// type check
var _ net.Conn = (*udpConn)(nil)
// Read implements the net.Conn interface for *udpConn.
func (c *udpConn) Read(b []byte) (n int, err error) {
n = c.readRemaining(b)
if n > 0 {
return n, nil
}
select {
case buf := <-c.chanMsg:
c.remaining = buf
n = c.readRemaining(b)
return n, nil
case <-c.chanClosed:
return 0, net.ErrClosed
}
}
// readRemaining reads remaining bytes that were not yet read.
func (c *udpConn) readRemaining(b []byte) (n int) {
if c.remaining == nil || len(c.remaining) == 0 {
return 0
}
if len(c.remaining) >= len(b) {
n = len(b)
copy(b, c.remaining[:n])
c.remaining = c.remaining[n:]
return n
}
n = len(c.remaining)
copy(b[:n], c.remaining)
c.remaining = nil
return n
}
// Write implements the net.Conn interface for *udpConn.
func (c *udpConn) Write(b []byte) (n int, err error) {
return c.conn.WriteToUDP(b, c.peerAddr)
}
// Close implements the net.Conn interface for *udpConn.
func (c *udpConn) Close() (err error) {
c.closedMu.Lock()
defer c.closedMu.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.chanClosed)
// Do not close the underlying UDP connection as it's shared with other
// udpConn objects.
return nil
}
// LocalAddr implements the net.Conn interface for *udpConn.
func (c *udpConn) LocalAddr() (addr net.Addr) {
return c.conn.LocalAddr()
}
// RemoteAddr implements the net.Conn interface for *udpConn.
func (c *udpConn) RemoteAddr() (addr net.Addr) {
return c.peerAddr
}
// SetDeadline implements the net.Conn interface for *udpConn.
func (c *udpConn) SetDeadline(_ time.Time) (err error) {
// TODO(leninalive): Implement it.
return nil
}
// SetReadDeadline implements the net.Conn interface for *udpConn.
func (c *udpConn) SetReadDeadline(_ time.Time) (err error) {
// TODO(leninalive): Implement it.
return nil
}
// SetWriteDeadline implements the net.Conn interface for *udpConn.
func (c *udpConn) SetWriteDeadline(_ time.Time) (err error) {
// TODO(leninalive): Implement it.
return nil
}