This commit is contained in:
Iurii Egorov 2024-05-16 09:54:10 +03:00 committed by GitHub
commit 3ece72108b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1015 additions and 12 deletions

331
conn/bind_std_tcp.go Normal file
View file

@ -0,0 +1,331 @@
/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package conn
import (
"bytes"
"errors"
"io"
"net"
"net/netip"
"sync"
"time"
tls "github.com/refraction-networking/utls"
)
var lastErrorTimestamp time.Time
type StdNetBindTcp struct {
mu sync.Mutex
useTls bool
tcp *net.TCPConn
tls *tls.UConn
endpoint *StdNetEndpoint
currentPacket *bytes.Reader
closed bool
log *Logger
errorChan chan<- error
protectSocket func(fd int) int
tunsafe *TunSafeData
}
//goland:noinspection GoUnusedExportedFunction
func CreateStdNetBind(socketType string, log *Logger, errorChan chan<- error) Bind {
if socketType == "udp" {
return NewStdNetBind()
} else {
return &StdNetBindTcp{tunsafe: NewTunSafeData(), useTls: socketType == "tls", log: log, errorChan: errorChan}
}
}
func (s *StdNetBindTcp) BatchSize() int {
return 1
}
func (s *StdNetBindTcp) GetOffloadInfo() string {
return ""
}
func (bind *StdNetBindTcp) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err == nil {
bind.endpoint = &StdNetEndpoint{AddrPort: e}
}
if err != nil {
return nil, err
}
return asEndpoint(e), err
}
func dialTcp(addr string, protectSocket func(fd int) int) (*net.TCPConn, int, error) {
dialer := net.Dialer{Timeout: 5 * time.Second}
netConn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, 0, err
}
conn := netConn.(*net.TCPConn)
conn.SetLinger(0)
// Retrieve port.
laddr := conn.LocalAddr()
taddr, err := net.ResolveTCPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
_ = conn.Close()
return nil, 0, err
}
return conn, taddr.Port, nil
}
func (bind *StdNetBindTcp) upgradeToTls() error {
tlsConf := &tls.Config{
InsecureSkipVerify: true,
ServerName: randomServerName(),
}
conn := tls.UClient(bind.tcp, tlsConf, tls.HelloChrome_Auto)
conn.SetDeadline(time.Now().Add(5 * time.Second))
bind.log.Verbosef("TLS: Starting handshake")
err := conn.Handshake()
bind.log.Verbosef("TLS: Handshake result: %v", err)
conn.SetDeadline(time.Time{})
// On some devices (e.g. Samsung S21 FE) we see first WireGuard handshake failing on TLS socket and adding small
// delay seems to fix that - issue is likely with timing on the server side, but couldn't find server-side fix.
time.Sleep(100 * time.Millisecond)
if err == nil {
bind.tls = conn
} else {
bind.onSocketError(err)
conn.Close()
}
return err
}
func (bind *StdNetBindTcp) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
bind.log.Verbosef("TCP/TLS: Open %d", uport)
bind.closed = false
return []ReceiveFunc{bind.makeReceiveFunc()}, uport, nil
}
func (bind *StdNetBindTcp) initTcp() error {
var err error
if bind.tcp != nil {
return ErrBindAlreadyOpen
}
var tcp *net.TCPConn
tcp, _, err = dialTcp(bind.endpoint.DstToString(), bind.protectSocket)
bind.log.Verbosef("TCP dial result: %v", err)
if err != nil {
bind.onSocketError(err)
return err
}
bind.tcp = tcp
return nil
}
func (bind *StdNetBindTcp) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
bind.log.Verbosef("TCP/TLS: Close")
bind.closed = true
err := bind.closeInternal()
return err
}
func (bind *StdNetBindTcp) closeInternal() error {
var err error
if bind.tls != nil {
err = bind.tls.Close()
bind.tls = nil
}
if bind.tcp != nil {
err = bind.tcp.Close()
bind.tcp = nil
}
bind.tunsafe.clear()
return err
}
func (bind *StdNetBindTcp) getConn() (net.Conn, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
if bind.closed {
return nil, net.ErrClosed
}
conn, err := bind.getConnInternal()
if err != nil {
bind.closed = true
}
return conn, err
}
func (bind *StdNetBindTcp) getConnInternal() (net.Conn, error) {
if bind.tcp == nil {
err := bind.initTcp()
if err != nil {
return nil, err
}
}
if !bind.useTls {
return bind.tcp, nil
}
if bind.tls == nil {
err := bind.upgradeToTls()
if err != nil {
bind.closeInternal()
return nil, err
}
}
return bind.tls, nil
}
func (bind *StdNetBindTcp) makeReceiveFunc() ReceiveFunc {
return func(packets [][]byte, sizes []int, eps []Endpoint) (int, error) {
var err error
eps[0] = bind.endpoint
if bind.currentPacket == nil || bind.currentPacket.Len() == 0 {
var conn net.Conn
conn, err = bind.getConn()
if err != nil {
bind.logError("recv getConn", err)
return 0, err
}
err = bind.readNextPacket(conn)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
bind.onSocketError(err)
bind.logError("recv", err)
}
return 0, err
}
}
n, err := bind.currentPacket.Read(packets[0])
if err != nil {
bind.logError("read packet", err)
return 0, err
}
sizes[0] = n
return 1, err
}
}
func (bind *StdNetBindTcp) readNextPacket(conn net.Conn) error {
tunSafeHeader := make([]byte, tunSafeHeaderSize)
_, err := io.ReadFull(conn, tunSafeHeader)
if err != nil {
return err
}
tunSafeType, payloadSize := parseTunSafeHeader(tunSafeHeader)
wgPacket, offset, err := bind.tunsafe.prepareWgPacket(tunSafeType, payloadSize)
if err != nil {
return err
}
_, err = io.ReadFull(conn, wgPacket[offset:])
if err != nil {
return err
}
bind.tunsafe.onRecvPacket(tunSafeType, wgPacket)
bind.currentPacket = bytes.NewReader(wgPacket)
return nil
}
func (bind *StdNetBindTcp) Send(buff [][]byte, endpoint Endpoint) error {
conn, err := bind.getConn()
if err != nil {
bind.logError("send conn", err)
return err
}
// As single tcp socket can send only to single destination. We assume endpoint passed to ParseEndpoint will be
// the same.
boundEndpoint := asEndpoint(bind.endpoint.AddrPort)
if endpoint != boundEndpoint {
return errors.New("StdNetBindTcp.Send endpoints mismatch")
}
for i := range buff {
tunSafePacket := bind.tunsafe.wgToTunSafe(buff[i])
_, err = conn.Write(tunSafePacket)
if err != nil {
bind.onSocketError(err)
bind.logError("send", err)
break
}
}
return err
}
func (bind *StdNetBindTcp) SetMark(_ uint32) error {
return nil
}
func (bind *StdNetBindTcp) onSocketError(err error) {
if err != nil && !bind.closed {
bind.errorChan <- err
}
}
func (bind *StdNetBindTcp) logError(t string, err error) {
if time.Now().After(lastErrorTimestamp.Add(5 * time.Second)) {
lastErrorTimestamp = time.Now()
bind.log.Errorf("TCP/TLS error %s: %v", t, err)
}
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = Endpoint(&StdNetEndpoint{AddrPort: ap})
m[ap] = e
}
return e
}

View file

@ -32,3 +32,11 @@ func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
}
return
}
func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) {
return -1, err
}
func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) {
return -1, err
}

View file

@ -59,6 +59,11 @@ type Bind interface {
GetOffloadInfo() string
}
type Logger struct {
Verbosef func(format string, args ...any)
Errorf func(format string, args ...any)
}
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {

View file

@ -7,4 +7,6 @@
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }
func NewDefaultBind(logger *Logger, errorChan chan<- error) Bind {
return CreateStdNetBind("tls", logger, errorChan)
}

175
conn/tcp_tls_utils.go Normal file
View file

@ -0,0 +1,175 @@
/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package conn
import (
"bytes"
cryptoRand "crypto/rand"
"encoding/binary"
"errors"
"math/big"
"math/rand"
"time"
)
var wgDataPrefix = []byte{4, 0, 0, 0}
var wgDataHeaderSize = 16
var wgDataPrefixSize = 8 // Wireguard data header without counter
var tunSafeHeaderSize = 2
var tunSafeNormalType = uint8(0b00)
var tunSafeDataType = uint8(0b10)
type TunSafeData struct {
wgSendPrefix []byte
wgSendCount uint64
wgRecvPrefix []byte
wgRecvCount uint64
}
var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"}
func NewTunSafeData() *TunSafeData {
return &TunSafeData{
wgRecvPrefix: make([]byte, 8),
wgSendPrefix: make([]byte, 8),
}
}
// Returns (type, size)
func parseTunSafeHeader(header []byte) (byte, int) {
tunSafeType := header[0] >> 6
size := (int(header[0])&0b00111111)<<8 | int(header[1])
return tunSafeType, size
}
func (tunSafe *TunSafeData) clear() {
tunSafe.wgSendCount = 0
tunSafe.wgRecvCount = 0
}
func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) {
buffer := new(bytes.Buffer)
buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount))
buffer.Write(tunSafe.wgRecvPrefix)
_ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount)
copy(wgPacket, buffer.Bytes())
}
func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) {
var wgPacket []byte
offset := 0
switch tunSafeType {
case tunSafeNormalType:
wgPacket = make([]byte, payloadSize)
case tunSafeDataType:
offset = wgDataHeaderSize
wgPacket = make([]byte, payloadSize+offset)
tunSafe.writeWgHeader(wgPacket)
default:
return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type")
}
return wgPacket, offset, nil
}
func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) {
if tunSafeType == tunSafeNormalType {
isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
if isWgDataPacket {
copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize])
countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize])
_ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount)
}
}
tunSafe.wgRecvCount++
}
func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte {
wgLen := len(wgPacket)
if wgLen < wgDataHeaderSize {
return wgToTunSafeNormal(wgPacket)
}
wgPrefix := wgPacket[:wgDataPrefixSize]
var wgCount uint64
_ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount)
prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix)
if prefixMatch && wgCount == tunSafe.wgSendCount+1 {
tunSafe.wgSendCount += 1
return wgToTunSafeData(wgPacket)
} else {
isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
if isWgDataPacket {
tunSafe.wgSendPrefix = wgPrefix
tunSafe.wgSendCount = wgCount
}
return wgToTunSafeNormal(wgPacket)
}
}
func wgToTunSafeNormal(wgPacket []byte) []byte {
payloadSize := len(wgPacket)
result := make([]byte, payloadSize+tunSafeHeaderSize)
// Tunsafe normal header
result[0] = uint8(payloadSize >> 8)
result[1] = uint8(payloadSize & 0xff)
// Full packet
copy(result[tunSafeHeaderSize:], wgPacket)
return result
}
func wgToTunSafeData(wgPacket []byte) []byte {
payloadSize := len(wgPacket) - wgDataHeaderSize
result := make([]byte, payloadSize+tunSafeHeaderSize)
// TunSafe data header
result[0] = uint8(0b10<<6 | payloadSize>>8)
result[1] = uint8(payloadSize & 0xff)
// Packet without header
copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:])
return result
}
func randomServerName() string {
charNum := int('z') - int('a') + 1
size := 3 + randInt(10)
name := make([]byte, size)
for i := range name {
name[i] = byte(int('a') + randInt(charNum))
}
return string(name) + "." + randItem(topLevelDomains)
}
func randItem(list []string) string {
return list[randInt(len(list))]
}
func randInt(n int) int {
size, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n)))
if err == nil {
return int(size.Int64())
}
rand.Seed(time.Now().UnixNano())
return rand.Intn(n)
}

View file

@ -6,6 +6,7 @@
package device
import (
"net"
"runtime"
"sync"
"sync/atomic"
@ -95,6 +96,9 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
handshakeStateChan chan<- HandshakeState
allowedSrcAddresses []net.IP
}
type aSecCfgType struct {
@ -110,6 +114,14 @@ type aSecCfgType struct {
transportPacketMagicHeader uint32
}
type HandshakeState int
const (
HandshakeInit HandshakeState = iota
HandshakeSuccess = iota
HandshakeFail = iota
)
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
@ -189,6 +201,7 @@ func (device *Device) changeState(want deviceState) (err error) {
// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
device.handshakeStateChan <- HandshakeInit
if err := device.BindUpdate(); err != nil {
device.log.Errorf("Unable to update bind: %v", err)
return err
@ -301,9 +314,18 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil
}
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, handshakeStateChan chan<- HandshakeState /*, allowedSrcAddresses string*/) *Device {
device := new(Device)
device.state.state.Store(uint32(deviceStateDown))
device.handshakeStateChan = handshakeStateChan
/*var allowedSources = strings.Split(allowedSrcAddresses, ",")
device.allowedSrcAddresses = make([]net.IP, len(allowedSources))
for i, source := range allowedSources {
ip := net.ParseIP(source)
if ip != nil {
device.allowedSrcAddresses[i] = ip
}
}*/
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
@ -419,6 +441,7 @@ func (device *Device) Close() {
device.log.Verbosef("Device closed")
close(device.closed)
close(device.handshakeStateChan)
}
func (device *Device) Wait() chan struct{} {

View file

@ -237,7 +237,8 @@ func genTestPair(
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)),
make(chan HandshakeState))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()

View file

@ -6,6 +6,7 @@
package device
import (
"github.com/amnezia-vpn/amneziawg-go/conn"
"log"
"os"
)
@ -16,8 +17,7 @@ import (
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
Verbosef func(format string, args ...any)
Errorf func(format string, args ...any)
conn.Logger
}
// Log levels for use with NewLogger.
@ -34,7 +34,7 @@ func DiscardLogf(format string, args ...any) {}
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf}
logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}}
logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
}

View file

@ -39,7 +39,7 @@ func randDevice(t *testing.T) *Device {
}
tun := tuntest.NewChannelTUN()
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState))
device.SetPrivateKey(sk)
return device
}

View file

@ -406,6 +406,7 @@ func (device *Device) RoutineHandshake(id int) {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
device.handshakeStateChan <- HandshakeSuccess
device.log.Verbosef("%v - Received handshake initiation", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
@ -434,6 +435,7 @@ func (device *Device) RoutineHandshake(id int) {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
device.handshakeStateChan <- HandshakeSuccess
device.log.Verbosef("%v - Received handshake response", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))

View file

@ -175,6 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
peer.timersHandshakeInitiated()
@ -318,12 +319,14 @@ func (device *Device) RoutineReadFromTUN() {
// lookup peer
var peer *Peer
var src []byte
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
src = elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6:
@ -331,6 +334,7 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
src = elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
@ -340,6 +344,13 @@ func (device *Device) RoutineReadFromTUN() {
if peer == nil {
continue
}
// Drop packets with unexpected src IP.
if device.allowedSrcAddresses != nil && device.isUnexpectedSrcIP(src) {
//device.log.Verbosef("Dropping packet with unexpected src IP: %v (allowed = %v)", src, device.allowedSrcAddresses)
continue
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsContainer()
@ -383,6 +394,15 @@ func (device *Device) RoutineReadFromTUN() {
}
}
func (device *Device) isUnexpectedSrcIP(src []byte) bool {
for _, allowed := range device.allowedSrcAddresses {
if allowed.Equal(src) {
return false
}
}
return true
}
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {

247
device/statemanager.go Normal file
View file

@ -0,0 +1,247 @@
/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package device
import (
"strings"
"sync"
"time"
)
var initialRestartDelay = 4 * time.Second
var maxRestartDelay = 32 * time.Second
var resetRestartDelay = 10 * time.Minute
var timeNow = time.Now
// WireGuardStateManager handles enabling/disabling WireGuard in response to network availability changes, serves
// connection state to the client and resets WireGuard connection in response to socket and handshake errors.
//
// Client should call SetNetworkAvailable every time network changes - WireGuard will remain inactive until
// SetNetworkAvailable(true) is called. When SetNetworkAvailable(true) is called twice in a row it'll be interpreted
// as network change and trigger reset of the connection (on TCP/TLS socket).
//
// GetState is blocking and therefore should run in dedicated thread in a loop. After Close is called GetState will
// return immediately with WireGuardDisabled.
type WireGuardStateManager struct {
HandshakeStateChan chan HandshakeState
SocketErrChan chan error
networkAvailableChan chan bool
closeChan chan bool
stateChan chan WireGuardState
isNetAvailable bool
lastRestart time.Time
transmission string
log *Logger
mu sync.Mutex
closed bool
startedTimestamp time.Time
nextRestartDelay time.Duration
}
type WireGuardState int
const (
WireGuardDisabled WireGuardState = iota
WireGuardConnecting
WireGuardConnected
WireGuardError
WireGuardWaitingForNetwork
)
type BaseDevice interface {
Up() error
Down() error
}
//goland:noinspection GoUnusedExportedFunction
func NewWireGuardStateManager(log *Logger, transmission string) *WireGuardStateManager {
return &WireGuardStateManager{
networkAvailableChan: make(chan bool, 100),
SocketErrChan: make(chan error, 100),
HandshakeStateChan: make(chan HandshakeState, 100),
closeChan: make(chan bool, 1),
stateChan: make(chan WireGuardState, 1),
transmission: transmission,
log: log,
nextRestartDelay: initialRestartDelay,
lastRestart: timeNow(),
}
}
func (man *WireGuardStateManager) Start(device BaseDevice) {
go man.handlerLoop(device)
}
func (man *WireGuardStateManager) GetState() WireGuardState {
state, ok := <-man.stateChan
if !ok {
return -1
}
return state
}
func (man *WireGuardStateManager) Close() {
man.log.Verbosef("StateManager: closing")
man.closed = true
go func() {
man.closeChan <- true
man.stateChan <- WireGuardDisabled
close(man.stateChan)
}()
}
func (man *WireGuardStateManager) SetNetworkAvailable(available bool) {
man.networkAvailableChan <- available
}
func (man *WireGuardStateManager) handlerLoop(device BaseDevice) {
man.log.Verbosef("StateManager: start loop")
// Ugly way of emulating optional bool type
var wasNetAvailablePtr *bool = nil
for {
select {
case netAvailable := <-man.networkAvailableChan:
man.onNetworkAvailabilityChange(device, wasNetAvailablePtr, netAvailable)
man.isNetAvailable = netAvailable
wasNetAvailablePtr = &man.isNetAvailable
case socketErr := <-man.SocketErrChan:
if man.isNetAvailable {
man.handleSocketErr(device, socketErr)
}
case handshakeState := <-man.HandshakeStateChan:
if man.isNetAvailable {
man.handleHandshakeState(device, handshakeState)
}
case <-man.closeChan:
man.log.Verbosef("StateManager: end loop")
return
}
}
}
func (man *WireGuardStateManager) onNetworkAvailabilityChange(device BaseDevice, wasAvailable *bool, available bool) {
if !available {
man.postState(WireGuardWaitingForNetwork)
}
if available && wasAvailable == nil {
man.log.Verbosef("StateManager: network on")
man.setActive(device, true)
man.startedTimestamp = timeNow()
} else if available && *wasAvailable && !man.startedTimestamp.IsZero() &&
timeNow().After(man.startedTimestamp.Add(5*time.Second)) {
// Ignore network changes at the very beginning of connection as those might be false positive
// (VPN tunnel opening)
man.log.Verbosef("StateManager: network change detected")
man.maybeRestart(device)
} else if available && !*wasAvailable {
man.log.Verbosef("StateManager: network back")
man.setActive(device, true)
} else if !available && wasAvailable != nil && *wasAvailable {
man.log.Verbosef("StateManager: network gone")
man.setActive(device, false)
}
}
func (man *WireGuardStateManager) setActive(device BaseDevice, activate bool) {
man.mu.Lock()
defer man.mu.Unlock()
var err error
if activate {
man.postState(WireGuardConnecting)
err = device.Up()
} else {
err = device.Down()
}
if err != nil {
man.log.Errorf("StateManager: setActive(%t) error %v", activate, err)
man.postState(WireGuardError)
}
}
func (man *WireGuardStateManager) handleSocketErr(device BaseDevice, err error) {
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset by peer") {
man.log.Errorf("StateManager: %s", errStr)
man.maybeRestart(device)
}
}
}
func (man *WireGuardStateManager) handleHandshakeState(device BaseDevice, state HandshakeState) {
switch state {
case HandshakeInit:
man.postState(WireGuardConnecting)
case HandshakeSuccess:
man.postState(WireGuardConnected)
case HandshakeFail:
man.postState(WireGuardError)
man.maybeRestart(device)
}
}
func (man *WireGuardStateManager) maybeRestart(device BaseDevice) {
if man.transmission == "udp" {
return
}
man.mu.Lock()
defer man.mu.Unlock()
if man.shouldRestart() {
man.log.Verbosef("StateManager: restarting")
man.postState(WireGuardConnecting)
device.Down()
if !man.closed {
device.Up()
}
}
}
// Don't restart too often, grow delay exponentially up to a limit and after some time reset to small initial value
func (man *WireGuardStateManager) shouldRestart() bool {
now := timeNow()
restart := now.After(man.lastRestart.Add(man.nextRestartDelay))
if restart {
if now.After(man.lastRestart.Add(resetRestartDelay)) {
man.nextRestartDelay = initialRestartDelay
} else {
man.nextRestartDelay *= 2
if man.nextRestartDelay > maxRestartDelay {
man.nextRestartDelay = maxRestartDelay
}
}
man.lastRestart = now
}
return restart
}
func (man *WireGuardStateManager) postState(state WireGuardState) {
go func() {
if !man.closed && (man.isNetAvailable || state == WireGuardWaitingForNetwork) {
man.stateChan <- state
}
}()
}

152
device/statemanager_test.go Normal file
View file

@ -0,0 +1,152 @@
/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package device
import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
var timeMs int64 = 0
var mockDevice MockDevice
var manager *WireGuardStateManager
var lastState WireGuardState
type MockDevice struct {
isUp bool
upCount int
}
func (dev *MockDevice) Up() error {
dev.isUp = true
dev.upCount++
return nil
}
func (dev *MockDevice) Down() error {
dev.isUp = false
return nil
}
func setup() {
timeMs = 0
timeNow = func() time.Time { return time.UnixMilli(timeMs) }
mockDevice.isUp = false
manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), "tcp")
manager.Start(&mockDevice)
lastState = WireGuardDisabled
go func() {
for lastState != -1 {
lastState = manager.GetState()
}
}()
}
func setdown() {
manager.Close()
}
func TestWireGuardStateManager_shouldRestart(t *testing.T) {
assert := assert.New(t)
setup()
defer setdown()
assert.Equal(initialRestartDelay, manager.nextRestartDelay)
assert.Equal(false, manager.shouldRestart())
timeMs += initialRestartDelay.Milliseconds()
assert.Equal(false, manager.shouldRestart())
timeMs += 1
assert.Equal(true, manager.shouldRestart())
assert.Equal(2*initialRestartDelay, manager.nextRestartDelay)
assert.Equal(false, manager.shouldRestart())
timeMs += 2 * initialRestartDelay.Milliseconds()
assert.Equal(false, manager.shouldRestart())
timeMs += 1
assert.Equal(true, manager.shouldRestart())
timeMs += resetRestartDelay.Milliseconds() + 1
assert.Equal(true, manager.shouldRestart())
assert.Equal(initialRestartDelay, manager.nextRestartDelay)
}
func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) {
assert := assert.New(t)
setup()
defer setdown()
assert.Equal(false, mockDevice.isUp)
manager.SetNetworkAvailable(true)
time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking
assert.Equal(true, mockDevice.isUp)
assert.Equal(WireGuardConnecting, lastState)
manager.SetNetworkAvailable(false)
time.Sleep(time.Millisecond)
assert.Equal(WireGuardWaitingForNetwork, lastState)
assert.Equal(false, mockDevice.isUp)
}
func TestWireGuardStateManager_happyConnectionPath(t *testing.T) {
assert := assert.New(t)
setup()
defer setdown()
manager.SetNetworkAvailable(true)
time.Sleep(time.Millisecond)
manager.HandshakeStateChan <- HandshakeSuccess
time.Sleep(time.Millisecond)
assert.Equal(WireGuardConnected, lastState)
assert.Equal(true, mockDevice.isUp)
}
func TestWireGuardStateManager_handshakeFailCausesRestart(t *testing.T) {
assert := assert.New(t)
setup()
defer setdown()
manager.SetNetworkAvailable(true)
time.Sleep(time.Millisecond)
manager.HandshakeStateChan <- HandshakeFail
time.Sleep(time.Millisecond)
assert.Equal(WireGuardError, lastState)
timeMs += initialRestartDelay.Milliseconds() + 1
manager.HandshakeStateChan <- HandshakeFail
time.Sleep(time.Millisecond)
assert.Equal(WireGuardConnecting, lastState)
assert.Equal(2, mockDevice.upCount)
}
func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) {
assert := assert.New(t)
setup()
defer setdown()
manager.SetNetworkAvailable(true)
timeMs += initialRestartDelay.Milliseconds() + 1
time.Sleep(time.Millisecond)
manager.SocketErrChan <- errors.New("broken pipe")
time.Sleep(time.Millisecond)
assert.Equal(WireGuardConnecting, lastState)
assert.Equal(2, mockDevice.upCount)
}

View file

@ -78,6 +78,7 @@ func (peer *Peer) timersActive() bool {
func expiredRetransmitHandshake(peer *Peer) {
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
@ -97,6 +98,7 @@ func expiredRetransmitHandshake(peer *Peer) {
}
} else {
peer.timers.handshakeAttempts.Add(1)
peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */

10
go.mod
View file

@ -3,6 +3,8 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.20
require (
github.com/refraction-networking/utls v1.1.5
github.com/stretchr/testify v1.8.0
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.19.0
golang.org/x/net v0.21.0
@ -15,3 +17,11 @@ require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.15.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

31
go.sum
View file

@ -1,16 +1,41 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.1.5 h1:JtrojoNhbUQkBqEg05sP3gDgDj6hIEAAVKbI9lx4n6w=
github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

View file

@ -222,11 +222,11 @@ func main() {
return
}
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
errs := make(chan error)
device := device.NewDevice(tdev, conn.NewDefaultBind(&logger.Logger, errs), logger, make(chan device.HandshakeState))
logger.Verbosef("Device started")
errs := make(chan error)
term := make(chan os.Signal, 1)
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)