amneziawg-go/device/statemanager.go
2024-03-11 18:19:02 +03:00

247 lines
7 KiB
Go

/*
* Copyright (c) 2022. Proton AG
*
* This file is part of ProtonVPN.
*
* ProtonVPN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProtonVPN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
*/
package device
import (
"strings"
"sync"
"time"
)
// Automatic-restart back-off used by shouldRestart: a restart is allowed only
// once nextRestartDelay (starting at initialRestartDelay) has passed since the
// previous one; after each restart the delay doubles, capped at maxRestartDelay,
// unless more than resetRestartDelay elapsed since the previous restart, in which
// case it drops back to initialRestartDelay.
var (
	initialRestartDelay = 4 * time.Second
	maxRestartDelay     = 32 * time.Second
	resetRestartDelay   = 10 * time.Minute
)

// timeNow is the clock used by the state manager. It is kept in a variable
// rather than calling time.Now directly, presumably so tests can substitute a
// fake clock.
var timeNow = time.Now
// WireGuardStateManager handles enabling/disabling WireGuard in response to network availability changes, serves
// the connection state to the client and restarts the WireGuard connection in response to socket and handshake
// errors.
//
// The client should call SetNetworkAvailable every time the network changes; WireGuard remains inactive until
// SetNetworkAvailable(true) is called. When SetNetworkAvailable(true) is called twice in a row it is interpreted
// as a network change and may trigger a restart of the connection. Automatic restarts only happen when the
// transmission is not "udp" (i.e. on a TCP/TLS socket), network changes within 5 seconds of the connection start
// are ignored, and restarts are rate-limited by an exponential back-off (see shouldRestart).
//
// GetState blocks and should therefore be called in a loop on a dedicated thread. After Close is called, GetState
// returns WireGuardDisabled (possibly after a state that was already queued) and then -1 on every later call.
type WireGuardStateManager struct {
	// Events reported by the device; handlerLoop ignores them while the network is unavailable.
	HandshakeStateChan chan HandshakeState
	SocketErrChan      chan error
	// Availability reports queued by SetNetworkAvailable and consumed by handlerLoop.
	networkAvailableChan chan bool
	// Tells handlerLoop to return; written by Close.
	closeChan chan bool
	// States handed out by GetState; closed by Close.
	stateChan chan WireGuardState
	// Last availability report processed by handlerLoop.
	// NOTE(review): also read from postState goroutines without synchronization.
	isNetAvailable bool
	// Time of the last automatic restart (the creation time initially); used by shouldRestart.
	lastRestart time.Time
	// Transmission mode; "udp" disables automatic restarts in maybeRestart.
	transmission string
	log          *Logger
	// Held by setActive and maybeRestart while the device is being brought up/down.
	mu sync.Mutex
	// Set by Close; suppresses further state posts and the Up() at the end of a restart.
	// NOTE(review): written by Close without holding mu and read from other goroutines.
	closed bool
	// Start of the connection as recorded by onNetworkAvailabilityChange; while it is zero,
	// or within 5 seconds after it, network changes do not trigger a restart.
	startedTimestamp time.Time
	// Minimum wait before the next automatic restart; grown/reset by shouldRestart.
	nextRestartDelay time.Duration
}
// WireGuardState is the connection state reported to the client by GetState.
type WireGuardState int

// Connection states, numbered from 0 in declaration order. Keep the order
// stable: the numeric values are what GetState hands to the client.
const (
	// WireGuardDisabled is posted by Close as the final state.
	WireGuardDisabled WireGuardState = iota
	// WireGuardConnecting is posted while the device is being brought up or
	// restarted, and when a handshake is initiated.
	WireGuardConnecting
	// WireGuardConnected is posted after a successful handshake.
	WireGuardConnected
	// WireGuardError is posted when a handshake fails or the device cannot be
	// brought up/down.
	WireGuardError
	// WireGuardWaitingForNetwork is posted whenever the client reports that the
	// network is unavailable.
	WireGuardWaitingForNetwork
)
// BaseDevice is the part of a WireGuard device the state manager drives:
// bringing it up and tearing it down. Both operations report failure through
// their error result.
type BaseDevice interface {
	// Up activates the device.
	Up() error
	// Down deactivates the device.
	Down() error
}
// NewWireGuardStateManager creates a state manager that logs through log and
// uses transmission to decide whether automatic restarts apply ("udp" disables
// them). Nothing runs until Start is called.
//
//goland:noinspection GoUnusedExportedFunction
func NewWireGuardStateManager(log *Logger, transmission string) *WireGuardStateManager {
	man := &WireGuardStateManager{
		log:          log,
		transmission: transmission,
	}

	// Input queues hold up to 100 pending events before producers block.
	man.networkAvailableChan = make(chan bool, 100)
	man.SocketErrChan = make(chan error, 100)
	man.HandshakeStateChan = make(chan HandshakeState, 100)

	// Single-slot channels for the close signal and the state handed to GetState.
	man.closeChan = make(chan bool, 1)
	man.stateChan = make(chan WireGuardState, 1)

	// The restart back-off starts at its initial delay, measured from creation.
	man.nextRestartDelay = initialRestartDelay
	man.lastRestart = timeNow()

	return man
}
// Start launches the handler loop for device on its own goroutine and returns
// immediately. The loop keeps running until Close is called.
func (man *WireGuardStateManager) Start(device BaseDevice) {
	go func() {
		man.handlerLoop(device)
	}()
}
// GetState blocks until the next connection state is posted and returns it.
// After Close has delivered WireGuardDisabled and closed the state channel,
// every call returns -1 immediately (a value outside the WireGuardState
// constants).
func (man *WireGuardStateManager) GetState() WireGuardState {
	if state, open := <-man.stateChan; open {
		return state
	}
	return -1
}
// Close stops the handler loop and makes GetState return WireGuardDisabled
// (after any state that is already queued) followed by -1 on every later call.
// It does not block; the shutdown signalling runs on a separate goroutine.
// Calling Close again is a no-op.
func (man *WireGuardStateManager) Close() {
	// A second Close used to start another goroutine that sends on (or closes)
	// stateChan after the first one already closed it, which panics inside a
	// goroutine and takes the whole process down.
	// NOTE(review): closed is still accessed without synchronization; this guard
	// covers repeated calls, not concurrent ones.
	if man.closed {
		return
	}
	man.log.Verbosef("StateManager: closing")
	man.closed = true
	go func() {
		man.closeChan <- true
		man.stateChan <- WireGuardDisabled
		close(man.stateChan)
	}()
}
// SetNetworkAvailable queues a network availability report for the handler
// loop, which processes it asynchronously. Reporting true twice in a row is
// treated as a network change. The queue is buffered (100 reports when built by
// NewWireGuardStateManager); once it is full this call blocks until the loop
// drains it.
func (man *WireGuardStateManager) SetNetworkAvailable(available bool) {
	queue := man.networkAvailableChan
	queue <- available
}
// handlerLoop processes availability reports, socket errors and handshake
// updates for device until Close signals closeChan. Socket errors and
// handshake updates are dropped while the network is reported unavailable.
func (man *WireGuardStateManager) handlerLoop(device BaseDevice) {
	man.log.Verbosef("StateManager: start loop")

	// previous stands in for an optional bool: nil until the first availability
	// report, afterwards it points at man.isNetAvailable, which still holds the
	// previous report when onNetworkAvailabilityChange is called.
	var previous *bool

	for {
		select {
		case <-man.closeChan:
			man.log.Verbosef("StateManager: end loop")
			return

		case available := <-man.networkAvailableChan:
			man.onNetworkAvailabilityChange(device, previous, available)
			man.isNetAvailable = available
			previous = &man.isNetAvailable

		case sockErr := <-man.SocketErrChan:
			if man.isNetAvailable {
				man.handleSocketErr(device, sockErr)
			}

		case hs := <-man.HandshakeStateChan:
			if man.isNetAvailable {
				man.handleHandshakeState(device, hs)
			}
		}
	}
}
// onNetworkAvailabilityChange reacts to one availability report from the client.
//
// wasAvailable is nil for the very first report; otherwise it points at the
// previous report. Any false report posts WireGuardWaitingForNetwork. In addition:
//   - first report true: activate the device and record startedTimestamp;
//   - true after true, more than startupGracePeriod after startedTimestamp:
//     treat it as a network change and possibly restart the connection;
//   - true after false: activate the device again (recording startedTimestamp
//     if it has not been recorded yet);
//   - false after true: deactivate the device.
func (man *WireGuardStateManager) onNetworkAvailabilityChange(device BaseDevice, wasAvailable *bool, available bool) {
	// Network changes right after the connection starts may be false positives
	// (e.g. the VPN tunnel itself opening), so they are ignored for this long.
	const startupGracePeriod = 5 * time.Second

	if !available {
		man.postState(WireGuardWaitingForNetwork)
	}

	switch {
	case available && wasAvailable == nil:
		man.log.Verbosef("StateManager: network on")
		man.setActive(device, true)
		man.startedTimestamp = timeNow()

	case available && *wasAvailable && !man.startedTimestamp.IsZero() &&
		timeNow().After(man.startedTimestamp.Add(startupGracePeriod)):
		man.log.Verbosef("StateManager: network change detected")
		man.maybeRestart(device)

	case available && !*wasAvailable:
		man.log.Verbosef("StateManager: network back")
		man.setActive(device, true)
		// If the very first report was false, the device is first activated here.
		// Without recording the start time, startedTimestamp would stay zero
		// forever and network-change restarts would never happen.
		if man.startedTimestamp.IsZero() {
			man.startedTimestamp = timeNow()
		}

	case !available && wasAvailable != nil && *wasAvailable:
		man.log.Verbosef("StateManager: network gone")
		man.setActive(device, false)
	}
}
// setActive brings device up (activate == true, posting WireGuardConnecting
// first) or down, while holding man.mu. A failure is logged and reported to
// the client as WireGuardError.
func (man *WireGuardStateManager) setActive(device BaseDevice, activate bool) {
	man.mu.Lock()
	defer man.mu.Unlock()

	err := func() error {
		if !activate {
			return device.Down()
		}
		man.postState(WireGuardConnecting)
		return device.Up()
	}()
	if err == nil {
		return
	}

	man.log.Errorf("StateManager: setActive(%t) error %v", activate, err)
	man.postState(WireGuardError)
}
// handleSocketErr logs err and asks for a connection restart when its message
// says the stream connection broke ("broken pipe" or "connection reset by
// peer"). A nil err, or any other error, is ignored.
func (man *WireGuardStateManager) handleSocketErr(device BaseDevice, err error) {
	if err == nil {
		return
	}

	msg := err.Error()
	for _, marker := range []string{"broken pipe", "connection reset by peer"} {
		if !strings.Contains(msg, marker) {
			continue
		}
		man.log.Errorf("StateManager: %s", msg)
		man.maybeRestart(device)
		return
	}
}
// handleHandshakeState maps a handshake update onto a client-visible state.
// A failed handshake is reported as WireGuardError and may also trigger a
// connection restart; other handshake states are ignored.
func (man *WireGuardStateManager) handleHandshakeState(device BaseDevice, state HandshakeState) {
	switch state {
	case HandshakeFail:
		man.postState(WireGuardError)
		man.maybeRestart(device)
	case HandshakeSuccess:
		man.postState(WireGuardConnected)
	case HandshakeInit:
		man.postState(WireGuardConnecting)
	}
}
// maybeRestart takes the device down and up again if the restart back-off
// allows it (see shouldRestart). Restarts are never performed for the "udp"
// transmission. If Close has been called by the time the device is down, it is
// not brought back up. Failures are logged; a failed Up is reported to the
// client as WireGuardError, consistent with setActive.
func (man *WireGuardStateManager) maybeRestart(device BaseDevice) {
	if man.transmission == "udp" {
		return
	}

	man.mu.Lock()
	defer man.mu.Unlock()

	if !man.shouldRestart() {
		return
	}

	man.log.Verbosef("StateManager: restarting")
	man.postState(WireGuardConnecting)

	// Previously both errors were discarded, leaving the client stuck in
	// "Connecting" with no trace of why the restart failed.
	if err := device.Down(); err != nil {
		man.log.Errorf("StateManager: restart Down error %v", err)
	}
	if man.closed {
		return
	}
	if err := device.Up(); err != nil {
		man.log.Errorf("StateManager: restart Up error %v", err)
		man.postState(WireGuardError)
	}
}
// Don't restart too often, grow delay exponentially up to a limit and after some time reset to small initial value
// shouldRestart reports whether enough time has passed since the previous
// restart to allow another one, and when it does, records the restart and
// updates the back-off: the next delay doubles (capped at maxRestartDelay),
// or resets to initialRestartDelay if more than resetRestartDelay passed since
// the previous restart. Callers hold man.mu.
func (man *WireGuardStateManager) shouldRestart() bool {
	now := timeNow()
	earliest := man.lastRestart.Add(man.nextRestartDelay)
	if !now.After(earliest) {
		return false
	}

	switch doubled := man.nextRestartDelay * 2; {
	case now.After(man.lastRestart.Add(resetRestartDelay)):
		man.nextRestartDelay = initialRestartDelay
	case doubled > maxRestartDelay:
		man.nextRestartDelay = maxRestartDelay
	default:
		man.nextRestartDelay = doubled
	}

	man.lastRestart = now
	return true
}
// postState hands state to GetState asynchronously, so the caller never blocks
// on a slow client. The state is dropped if the manager is closed, or if the
// network is unavailable (WireGuardWaitingForNetwork is always delivered).
// Because each post runs on its own goroutine, delivery order between posts is
// not guaranteed.
// NOTE(review): closed and isNetAvailable are read here without synchronization.
func (man *WireGuardStateManager) postState(state WireGuardState) {
	go func() {
		// Close closes stateChan. A send that was still waiting for a GetState
		// reader at that moment (stateChan holds one value), or that passed the
		// closed check just before Close ran, panics with "send on closed
		// channel" and would crash the whole process. Such a late state is
		// meaningless after Close, so it is dropped. The send is the only
		// operation in this goroutine that can panic.
		defer func() {
			_ = recover()
		}()

		if man.closed {
			return
		}
		if !man.isNetAvailable && state != WireGuardWaitingForNetwork {
			return
		}
		man.stateChan <- state
	}()
}