amneziawg-go/conn/bind_std_tcp.go
2024-03-11 19:26:44 +03:00

331 lines
7.5 KiB
Go

/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package conn
import (
"bytes"
"errors"
"io"
"net"
"net/netip"
"sync"
"time"
tls "github.com/refraction-networking/utls"
)
var (
	// lastErrorTimestamp records when logError last emitted a message. logError
	// uses it to rate-limit TCP/TLS error logging to one line per five seconds.
	// NOTE(review): this is package-level, so it is shared by every StdNetBindTcp,
	// and it is read/written without synchronization.
	lastErrorTimestamp time.Time
)
// StdNetBindTcp is a Bind that carries WireGuard packets over a single TCP
// connection, optionally upgraded to TLS, using TunSafe framing.
type StdNetBindTcp struct {
	// mu guards closed, tcp and tls; Open, Close and getConn hold it.
	mu sync.Mutex

	// useTls selects whether the TCP connection is upgraded to TLS before use.
	useTls bool
	// tcp is the dialed connection; nil until first use and after closeInternal.
	tcp *net.TCPConn
	// tls wraps tcp after a successful handshake; stays nil in plain-TCP mode.
	tls *tls.UConn

	// endpoint is the single destination recorded by ParseEndpoint.
	endpoint *StdNetEndpoint
	// currentPacket holds the not-yet-consumed part of the last received packet.
	currentPacket *bytes.Reader

	// closed is set by Close (and by getConn after a failed connect) and cleared by Open.
	closed bool
	log    *Logger
	// errorChan receives socket errors forwarded by onSocketError.
	errorChan chan<- error
	// protectSocket is forwarded to dialTcp.
	protectSocket func(fd int) int
	// tunsafe holds the TunSafe framing state used on send and receive.
	tunsafe *TunSafeData
}
// CreateStdNetBind returns a Bind for the requested socket type. "udp" yields
// the standard UDP bind. Any other value yields the TCP-based bind, which also
// wraps the connection in TLS when socketType is "tls". Socket errors of the
// TCP bind are forwarded to errorChan.
//
//goland:noinspection GoUnusedExportedFunction
func CreateStdNetBind(socketType string, log *Logger, errorChan chan<- error) Bind {
	if socketType == "udp" {
		return NewStdNetBind()
	}
	return &StdNetBindTcp{
		tunsafe:   NewTunSafeData(),
		useTls:    socketType == "tls",
		log:       log,
		errorChan: errorChan,
	}
}
// BatchSize reports how many packets one receive or send call handles. This
// stream-based bind always processes a single packet at a time.
func (bind *StdNetBindTcp) BatchSize() int {
	const packetsPerBatch = 1
	return packetsPerBatch
}
// GetOffloadInfo reports offload details for this bind; it always returns an
// empty string.
func (bind *StdNetBindTcp) GetOffloadInfo() string {
	var info string
	return info
}
// ParseEndpoint parses s as an "ip:port" address. On success it records the
// address as the bind's single destination and returns the cached Endpoint for
// it. On failure the bind is left untouched and the parse error is returned.
func (bind *StdNetBindTcp) ParseEndpoint(s string) (Endpoint, error) {
	addrPort, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	bind.endpoint = &StdNetEndpoint{AddrPort: addrPort}
	return asEndpoint(addrPort), nil
}
// dialTcp opens a TCP connection to addr (5 second dial timeout). It returns
// the connection together with the local port it was bound to.
//
// Linger is set to 0, so closing the connection discards any unsent data
// instead of waiting for it to drain.
//
// NOTE(review): protectSocket is accepted but not applied to the socket here.
// If sockets must be protected (e.g. excluded from VPN routing), this needs
// wiring through net.Dialer.Control — confirm intended behavior.
func dialTcp(addr string, protectSocket func(fd int) int) (*net.TCPConn, int, error) {
	dialer := net.Dialer{Timeout: 5 * time.Second}
	netConn, err := dialer.Dial("tcp", addr)
	if err != nil {
		return nil, 0, err
	}
	conn, ok := netConn.(*net.TCPConn)
	if !ok {
		_ = netConn.Close()
		return nil, 0, errors.New("dialTcp: unexpected connection type")
	}
	// Best effort: failing to change linger does not make the connection unusable.
	_ = conn.SetLinger(0)

	// A TCP connection's local address is already a *net.TCPAddr; no need to
	// format it to a string and resolve it back.
	laddr, ok := conn.LocalAddr().(*net.TCPAddr)
	if !ok {
		_ = conn.Close()
		return nil, 0, errors.New("dialTcp: unexpected local address type")
	}
	return conn, laddr.Port, nil
}
// upgradeToTls performs a TLS handshake over bind.tcp using a Chrome-like
// ClientHello and a random server name. On success the TLS connection is
// stored in bind.tls.
//
// On failure the error is forwarded via onSocketError, the TLS connection is
// closed, and the error is returned. The caller (getConnInternal) then tears
// down the TCP connection.
//
// NOTE(review): InsecureSkipVerify disables certificate verification. This is
// presumably intentional because TLS here only wraps WireGuard traffic, which
// authenticates peers itself — confirm.
func (bind *StdNetBindTcp) upgradeToTls() error {
	tlsConf := &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         randomServerName(),
	}
	conn := tls.UClient(bind.tcp, tlsConf, tls.HelloChrome_Auto)

	// Bound the handshake: this runs with bind.mu held (via getConn), so a
	// stalled server must not block the bind indefinitely.
	_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
	bind.log.Verbosef("TLS: Starting handshake")
	err := conn.Handshake()
	bind.log.Verbosef("TLS: Handshake result: %v", err)
	if err != nil {
		// Fail fast: the post-handshake delay below only matters for a
		// connection that will be used, and sleeping here would hold bind.mu.
		bind.onSocketError(err)
		_ = conn.Close()
		return err
	}
	// Clear the handshake deadline so later reads and writes do not time out.
	_ = conn.SetDeadline(time.Time{})

	// On some devices (e.g. Samsung S21 FE) we see first WireGuard handshake failing on TLS socket and adding small
	// delay seems to fix that - issue is likely with timing on the server side, but couldn't find server-side fix.
	time.Sleep(100 * time.Millisecond)
	bind.tls = conn
	return nil
}
// Open marks the bind as open and returns its single receive function. It does
// not create a socket: the connection is dialed lazily by getConn on first
// use. The requested port is returned unchanged and no error is reported.
func (bind *StdNetBindTcp) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	bind.log.Verbosef("TCP/TLS: Open %d", uport)
	bind.closed = false
	receivers := []ReceiveFunc{bind.makeReceiveFunc()}
	return receivers, uport, nil
}
// initTcp dials the bind's endpoint and stores the connection in bind.tcp. It
// returns ErrBindAlreadyOpen when a connection already exists. Dial failures
// are forwarded via onSocketError and returned. Reached with bind.mu held
// (getConn -> getConnInternal).
func (bind *StdNetBindTcp) initTcp() error {
	if bind.tcp != nil {
		return ErrBindAlreadyOpen
	}

	conn, _, err := dialTcp(bind.endpoint.DstToString(), bind.protectSocket)
	bind.log.Verbosef("TCP dial result: %v", err)
	if err != nil {
		bind.onSocketError(err)
		return err
	}
	bind.tcp = conn
	return nil
}
// Close marks the bind as closed (so getConn refuses to reconnect until the
// next Open) and tears down any open TCP/TLS connection. It returns the result
// of that teardown.
func (bind *StdNetBindTcp) Close() error {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	bind.log.Verbosef("TCP/TLS: Close")
	bind.closed = true
	return bind.closeInternal()
}
// closeInternal closes the TLS and TCP connections (whichever exist), clears
// them, and resets the TunSafe framing state. It is called with bind.mu held.
//
// Closing the TLS connection also closes the TCP socket underneath it, so the
// following TCP close reports net.ErrClosed. That expected error is ignored;
// otherwise Close would always report a spurious failure in TLS mode. The first
// real error encountered is returned.
func (bind *StdNetBindTcp) closeInternal() error {
	var firstErr error
	if bind.tls != nil {
		firstErr = bind.tls.Close()
		bind.tls = nil
	}
	if bind.tcp != nil {
		if err := bind.tcp.Close(); err != nil && !errors.Is(err, net.ErrClosed) && firstErr == nil {
			firstErr = err
		}
		bind.tcp = nil
	}
	bind.tunsafe.clear()
	return firstErr
}
// getConn returns the live connection (TLS-wrapped when enabled), establishing
// it first if necessary. It fails with net.ErrClosed once the bind is closed.
// Any failure to establish the connection also marks the bind closed, so later
// calls fail fast until Open is called again.
func (bind *StdNetBindTcp) getConn() (net.Conn, error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	if bind.closed {
		return nil, net.ErrClosed
	}
	conn, err := bind.getConnInternal()
	if err == nil {
		return conn, nil
	}
	bind.closed = true
	return nil, err
}
// getConnInternal lazily builds the connection. It dials TCP when none exists
// and, in TLS mode, performs the TLS upgrade on top of it. If the upgrade fails,
// the whole connection is torn down so the next attempt starts from a fresh
// dial. Called with bind.mu held.
func (bind *StdNetBindTcp) getConnInternal() (net.Conn, error) {
	if bind.tcp == nil {
		if err := bind.initTcp(); err != nil {
			return nil, err
		}
	}

	switch {
	case !bind.useTls:
		return bind.tcp, nil
	case bind.tls != nil:
		return bind.tls, nil
	}

	if err := bind.upgradeToTls(); err != nil {
		_ = bind.closeInternal() // best effort; the upgrade error is what gets reported
		return nil, err
	}
	return bind.tls, nil
}
// makeReceiveFunc builds the bind's ReceiveFunc. Each call delivers data for at
// most one packet in packets[0], always attributed to the bind's endpoint.
// When the buffered packet is exhausted, the next TunSafe-framed packet is
// first read from the connection. If packets[0] is smaller than the packet,
// the rest is returned by subsequent calls.
func (bind *StdNetBindTcp) makeReceiveFunc() ReceiveFunc {
	return func(packets [][]byte, sizes []int, eps []Endpoint) (int, error) {
		eps[0] = bind.endpoint

		needsRefill := bind.currentPacket == nil || bind.currentPacket.Len() == 0
		if needsRefill {
			conn, connErr := bind.getConn()
			if connErr != nil {
				bind.logError("recv getConn", connErr)
				return 0, connErr
			}
			if readErr := bind.readNextPacket(conn); readErr != nil {
				// A closed socket is the normal shutdown path; only report other failures.
				if !errors.Is(readErr, net.ErrClosed) {
					bind.onSocketError(readErr)
					bind.logError("recv", readErr)
				}
				return 0, readErr
			}
		}

		n, err := bind.currentPacket.Read(packets[0])
		if err != nil {
			bind.logError("read packet", err)
			return 0, err
		}
		sizes[0] = n
		return 1, nil
	}
}
// readNextPacket reads one TunSafe-framed packet from conn: a fixed-size
// header (tunSafeHeaderSize bytes), then the payload it describes. The
// resulting WireGuard packet becomes bind.currentPacket, which makeReceiveFunc
// then drains.
func (bind *StdNetBindTcp) readNextPacket(conn net.Conn) error {
	header := make([]byte, tunSafeHeaderSize)
	if _, err := io.ReadFull(conn, header); err != nil {
		return err
	}
	packetType, payloadSize := parseTunSafeHeader(header)

	// The payload from the wire is read into packet[offset:]; bytes before
	// offset are presumably filled by prepareWgPacket from framing state — confirm.
	packet, offset, err := bind.tunsafe.prepareWgPacket(packetType, payloadSize)
	if err != nil {
		return err
	}
	if _, err = io.ReadFull(conn, packet[offset:]); err != nil {
		return err
	}

	bind.tunsafe.onRecvPacket(packetType, packet)
	bind.currentPacket = bytes.NewReader(packet)
	return nil
}
// Send writes each WireGuard packet in buff to the connection using TunSafe
// framing, establishing the connection first if needed.
//
// A single TCP/TLS socket can reach only one destination, so endpoint must
// carry the address that was passed to ParseEndpoint. Otherwise an error is
// returned without dialing. On the first write error, the error is forwarded
// via onSocketError, logged, and returned; the remaining packets are not sent.
func (bind *StdNetBindTcp) Send(buff [][]byte, endpoint Endpoint) error {
	// Compare addresses, not Endpoint identity. asEndpoint caches values in a
	// sync.Pool, which may drop or duplicate its maps, and the receive path hands
	// out bind.endpoint itself. The same address can therefore be represented by
	// distinct *StdNetEndpoint pointers.
	dst, ok := endpoint.(*StdNetEndpoint)
	if !ok || dst == nil || bind.endpoint == nil || dst.AddrPort != bind.endpoint.AddrPort {
		return errors.New("StdNetBindTcp.Send endpoints mismatch")
	}

	conn, err := bind.getConn()
	if err != nil {
		bind.logError("send conn", err)
		return err
	}
	for _, packet := range buff {
		if _, err = conn.Write(bind.tunsafe.wgToTunSafe(packet)); err != nil {
			bind.onSocketError(err)
			bind.logError("send", err)
			return err
		}
	}
	return nil
}
// SetMark is a no-op for the TCP/TLS bind: the mark is ignored and nil is
// always returned.
func (bind *StdNetBindTcp) SetMark(uint32) error {
	var noErr error
	return noErr
}
// onSocketError forwards err to errorChan. Nothing is sent when err is nil,
// when the bind has been closed, or when no error channel was configured (a
// send on a nil channel would block forever).
//
// The send blocks until the receiver takes the error (or buffer space is free).
// NOTE(review): some callers (initTcp/upgradeToTls via getConn) hold bind.mu
// here, and the receive/send paths read bind.closed without the lock. A
// consumer that calls Close before draining errorChan could deadlock — confirm
// errorChan is buffered or always drained.
func (bind *StdNetBindTcp) onSocketError(err error) {
	if err == nil || bind.closed || bind.errorChan == nil {
		return
	}
	bind.errorChan <- err
}
// logError logs a TCP/TLS error tagged with t. Logging is rate-limited through
// the package-level lastErrorTimestamp: at most one message per five seconds
// across all binds, and errors arriving inside that window are dropped.
func (bind *StdNetBindTcp) logError(t string, err error) {
	const quietPeriod = 5 * time.Second
	nextAllowed := lastErrorTimestamp.Add(quietPeriod)
	if !time.Now().After(nextAllowed) {
		return
	}
	lastErrorTimestamp = time.Now()
	bind.log.Errorf("TCP/TLS error %s: %v", t, err)
}
// endpointPool holds reusable maps from netip.AddrPort to Endpoint. Wrapping an
// AddrPort in an Endpoint allocates, and Endpoints are immutable, so cached
// values can be handed out again instead of allocating new ones. sync.Pool may
// discard its maps at any time, so this cache is best effort only.
var endpointPool = sync.Pool{
	New: func() any {
		cache := make(map[netip.AddrPort]Endpoint)
		return cache
	},
}
// asEndpoint returns an Endpoint wrapping ap. It reuses the instance cached in
// the map borrowed from endpointPool when one exists, and otherwise creates and
// caches a new one. The map is returned to the pool before the function exits.
func asEndpoint(ap netip.AddrPort) Endpoint {
	cache := endpointPool.Get().(map[netip.AddrPort]Endpoint)
	defer endpointPool.Put(cache)

	if cached, found := cache[ap]; found {
		return cached
	}
	fresh := Endpoint(&StdNetEndpoint{AddrPort: ap})
	cache[ap] = fresh
	return fresh
}