mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-16 22:16:55 +02:00
Merge 186c62412b
into c00bda9200
This commit is contained in:
commit
3ece72108b
17 changed files with 1015 additions and 12 deletions
331
conn/bind_std_tcp.go
Normal file
331
conn/bind_std_tcp.go
Normal file
|
@ -0,0 +1,331 @@
|
|||
/*
|
||||
* Copyright (c) 2022. Proton AG
|
||||
*
|
||||
* This file is part of ProtonVPN.
|
||||
*
|
||||
* ProtonVPN is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonVPN is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
var lastErrorTimestamp time.Time
|
||||
|
||||
type StdNetBindTcp struct {
|
||||
mu sync.Mutex
|
||||
|
||||
useTls bool
|
||||
tcp *net.TCPConn
|
||||
tls *tls.UConn
|
||||
endpoint *StdNetEndpoint
|
||||
currentPacket *bytes.Reader
|
||||
closed bool
|
||||
log *Logger
|
||||
errorChan chan<- error
|
||||
protectSocket func(fd int) int
|
||||
|
||||
tunsafe *TunSafeData
|
||||
}
|
||||
|
||||
//goland:noinspection GoUnusedExportedFunction
|
||||
func CreateStdNetBind(socketType string, log *Logger, errorChan chan<- error) Bind {
|
||||
if socketType == "udp" {
|
||||
return NewStdNetBind()
|
||||
} else {
|
||||
return &StdNetBindTcp{tunsafe: NewTunSafeData(), useTls: socketType == "tls", log: log, errorChan: errorChan}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBindTcp) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBindTcp) GetOffloadInfo() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) ParseEndpoint(s string) (Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err == nil {
|
||||
bind.endpoint = &StdNetEndpoint{AddrPort: e}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return asEndpoint(e), err
|
||||
}
|
||||
|
||||
func dialTcp(addr string, protectSocket func(fd int) int) (*net.TCPConn, int, error) {
|
||||
dialer := net.Dialer{Timeout: 5 * time.Second}
|
||||
netConn, err := dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
conn := netConn.(*net.TCPConn)
|
||||
conn.SetLinger(0)
|
||||
|
||||
// Retrieve port.
|
||||
laddr := conn.LocalAddr()
|
||||
taddr, err := net.ResolveTCPAddr(
|
||||
laddr.Network(),
|
||||
laddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn, taddr.Port, nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) upgradeToTls() error {
|
||||
tlsConf := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: randomServerName(),
|
||||
}
|
||||
|
||||
conn := tls.UClient(bind.tcp, tlsConf, tls.HelloChrome_Auto)
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
bind.log.Verbosef("TLS: Starting handshake")
|
||||
err := conn.Handshake()
|
||||
bind.log.Verbosef("TLS: Handshake result: %v", err)
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
// On some devices (e.g. Samsung S21 FE) we see first WireGuard handshake failing on TLS socket and adding small
|
||||
// delay seems to fix that - issue is likely with timing on the server side, but couldn't find server-side fix.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if err == nil {
|
||||
bind.tls = conn
|
||||
} else {
|
||||
bind.onSocketError(err)
|
||||
conn.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
bind.log.Verbosef("TCP/TLS: Open %d", uport)
|
||||
bind.closed = false
|
||||
return []ReceiveFunc{bind.makeReceiveFunc()}, uport, nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) initTcp() error {
|
||||
var err error
|
||||
|
||||
if bind.tcp != nil {
|
||||
return ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
var tcp *net.TCPConn
|
||||
|
||||
tcp, _, err = dialTcp(bind.endpoint.DstToString(), bind.protectSocket)
|
||||
bind.log.Verbosef("TCP dial result: %v", err)
|
||||
if err != nil {
|
||||
bind.onSocketError(err)
|
||||
return err
|
||||
}
|
||||
bind.tcp = tcp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) Close() error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
bind.log.Verbosef("TCP/TLS: Close")
|
||||
bind.closed = true
|
||||
err := bind.closeInternal()
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) closeInternal() error {
|
||||
var err error
|
||||
if bind.tls != nil {
|
||||
err = bind.tls.Close()
|
||||
bind.tls = nil
|
||||
}
|
||||
if bind.tcp != nil {
|
||||
err = bind.tcp.Close()
|
||||
bind.tcp = nil
|
||||
}
|
||||
bind.tunsafe.clear()
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) getConn() (net.Conn, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
if bind.closed {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
conn, err := bind.getConnInternal()
|
||||
if err != nil {
|
||||
bind.closed = true
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) getConnInternal() (net.Conn, error) {
|
||||
if bind.tcp == nil {
|
||||
err := bind.initTcp()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if !bind.useTls {
|
||||
return bind.tcp, nil
|
||||
}
|
||||
if bind.tls == nil {
|
||||
err := bind.upgradeToTls()
|
||||
if err != nil {
|
||||
bind.closeInternal()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bind.tls, nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) makeReceiveFunc() ReceiveFunc {
|
||||
return func(packets [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
var err error
|
||||
eps[0] = bind.endpoint
|
||||
if bind.currentPacket == nil || bind.currentPacket.Len() == 0 {
|
||||
var conn net.Conn
|
||||
conn, err = bind.getConn()
|
||||
if err != nil {
|
||||
bind.logError("recv getConn", err)
|
||||
return 0, err
|
||||
}
|
||||
err = bind.readNextPacket(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
bind.onSocketError(err)
|
||||
bind.logError("recv", err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
n, err := bind.currentPacket.Read(packets[0])
|
||||
if err != nil {
|
||||
bind.logError("read packet", err)
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) readNextPacket(conn net.Conn) error {
|
||||
tunSafeHeader := make([]byte, tunSafeHeaderSize)
|
||||
_, err := io.ReadFull(conn, tunSafeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tunSafeType, payloadSize := parseTunSafeHeader(tunSafeHeader)
|
||||
wgPacket, offset, err := bind.tunsafe.prepareWgPacket(tunSafeType, payloadSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(conn, wgPacket[offset:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bind.tunsafe.onRecvPacket(tunSafeType, wgPacket)
|
||||
bind.currentPacket = bytes.NewReader(wgPacket)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) Send(buff [][]byte, endpoint Endpoint) error {
|
||||
conn, err := bind.getConn()
|
||||
if err != nil {
|
||||
bind.logError("send conn", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// As single tcp socket can send only to single destination. We assume endpoint passed to ParseEndpoint will be
|
||||
// the same.
|
||||
boundEndpoint := asEndpoint(bind.endpoint.AddrPort)
|
||||
if endpoint != boundEndpoint {
|
||||
return errors.New("StdNetBindTcp.Send endpoints mismatch")
|
||||
}
|
||||
for i := range buff {
|
||||
tunSafePacket := bind.tunsafe.wgToTunSafe(buff[i])
|
||||
_, err = conn.Write(tunSafePacket)
|
||||
if err != nil {
|
||||
bind.onSocketError(err)
|
||||
bind.logError("send", err)
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) SetMark(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) onSocketError(err error) {
|
||||
if err != nil && !bind.closed {
|
||||
bind.errorChan <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) logError(t string, err error) {
|
||||
if time.Now().After(lastErrorTimestamp.Add(5 * time.Second)) {
|
||||
lastErrorTimestamp = time.Now()
|
||||
bind.log.Errorf("TCP/TLS error %s: %v", t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||
// but Endpoints are immutable, so we can re-use them.
|
||||
var endpointPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(map[netip.AddrPort]Endpoint)
|
||||
},
|
||||
}
|
||||
|
||||
// asEndpoint returns an Endpoint containing ap.
|
||||
func asEndpoint(ap netip.AddrPort) Endpoint {
|
||||
m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
|
||||
defer endpointPool.Put(m)
|
||||
e, ok := m[ap]
|
||||
if !ok {
|
||||
e = Endpoint(&StdNetEndpoint{AddrPort: ap})
|
||||
m[ap] = e
|
||||
}
|
||||
return e
|
||||
}
|
|
@ -32,3 +32,11 @@ func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
return -1, err
|
||||
}
|
||||
|
|
|
@ -59,6 +59,11 @@ type Bind interface {
|
|||
GetOffloadInfo() string
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
Verbosef func(format string, args ...any)
|
||||
Errorf func(format string, args ...any)
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
// tied to a single network interface. Used by wireguard-windows.
|
||||
type BindSocketToInterface interface {
|
||||
|
|
|
@ -7,4 +7,6 @@
|
|||
|
||||
package conn
|
||||
|
||||
func NewDefaultBind() Bind { return NewStdNetBind() }
|
||||
func NewDefaultBind(logger *Logger, errorChan chan<- error) Bind {
|
||||
return CreateStdNetBind("tls", logger, errorChan)
|
||||
}
|
||||
|
|
175
conn/tcp_tls_utils.go
Normal file
175
conn/tcp_tls_utils.go
Normal file
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
* Copyright (c) 2022. Proton AG
|
||||
*
|
||||
* This file is part of ProtonVPN.
|
||||
*
|
||||
* ProtonVPN is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonVPN is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
cryptoRand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
var wgDataPrefix = []byte{4, 0, 0, 0}
|
||||
var wgDataHeaderSize = 16
|
||||
var wgDataPrefixSize = 8 // Wireguard data header without counter
|
||||
|
||||
var tunSafeHeaderSize = 2
|
||||
var tunSafeNormalType = uint8(0b00)
|
||||
var tunSafeDataType = uint8(0b10)
|
||||
|
||||
type TunSafeData struct {
|
||||
wgSendPrefix []byte
|
||||
wgSendCount uint64
|
||||
wgRecvPrefix []byte
|
||||
wgRecvCount uint64
|
||||
}
|
||||
|
||||
var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"}
|
||||
|
||||
func NewTunSafeData() *TunSafeData {
|
||||
return &TunSafeData{
|
||||
wgRecvPrefix: make([]byte, 8),
|
||||
wgSendPrefix: make([]byte, 8),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns (type, size)
|
||||
func parseTunSafeHeader(header []byte) (byte, int) {
|
||||
tunSafeType := header[0] >> 6
|
||||
size := (int(header[0])&0b00111111)<<8 | int(header[1])
|
||||
return tunSafeType, size
|
||||
}
|
||||
|
||||
func (tunSafe *TunSafeData) clear() {
|
||||
tunSafe.wgSendCount = 0
|
||||
tunSafe.wgRecvCount = 0
|
||||
}
|
||||
|
||||
func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) {
|
||||
buffer := new(bytes.Buffer)
|
||||
buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount))
|
||||
buffer.Write(tunSafe.wgRecvPrefix)
|
||||
_ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount)
|
||||
copy(wgPacket, buffer.Bytes())
|
||||
}
|
||||
|
||||
func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) {
|
||||
var wgPacket []byte
|
||||
offset := 0
|
||||
switch tunSafeType {
|
||||
case tunSafeNormalType:
|
||||
wgPacket = make([]byte, payloadSize)
|
||||
case tunSafeDataType:
|
||||
offset = wgDataHeaderSize
|
||||
wgPacket = make([]byte, payloadSize+offset)
|
||||
tunSafe.writeWgHeader(wgPacket)
|
||||
default:
|
||||
return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type")
|
||||
}
|
||||
return wgPacket, offset, nil
|
||||
}
|
||||
|
||||
func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) {
|
||||
if tunSafeType == tunSafeNormalType {
|
||||
isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
|
||||
if isWgDataPacket {
|
||||
copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize])
|
||||
countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize])
|
||||
_ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount)
|
||||
}
|
||||
}
|
||||
tunSafe.wgRecvCount++
|
||||
}
|
||||
|
||||
func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte {
|
||||
wgLen := len(wgPacket)
|
||||
if wgLen < wgDataHeaderSize {
|
||||
return wgToTunSafeNormal(wgPacket)
|
||||
}
|
||||
wgPrefix := wgPacket[:wgDataPrefixSize]
|
||||
var wgCount uint64
|
||||
_ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount)
|
||||
prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix)
|
||||
if prefixMatch && wgCount == tunSafe.wgSendCount+1 {
|
||||
tunSafe.wgSendCount += 1
|
||||
return wgToTunSafeData(wgPacket)
|
||||
} else {
|
||||
isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
|
||||
if isWgDataPacket {
|
||||
tunSafe.wgSendPrefix = wgPrefix
|
||||
tunSafe.wgSendCount = wgCount
|
||||
}
|
||||
return wgToTunSafeNormal(wgPacket)
|
||||
}
|
||||
}
|
||||
|
||||
func wgToTunSafeNormal(wgPacket []byte) []byte {
|
||||
payloadSize := len(wgPacket)
|
||||
result := make([]byte, payloadSize+tunSafeHeaderSize)
|
||||
|
||||
// Tunsafe normal header
|
||||
result[0] = uint8(payloadSize >> 8)
|
||||
result[1] = uint8(payloadSize & 0xff)
|
||||
|
||||
// Full packet
|
||||
copy(result[tunSafeHeaderSize:], wgPacket)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func wgToTunSafeData(wgPacket []byte) []byte {
|
||||
payloadSize := len(wgPacket) - wgDataHeaderSize
|
||||
result := make([]byte, payloadSize+tunSafeHeaderSize)
|
||||
|
||||
// TunSafe data header
|
||||
result[0] = uint8(0b10<<6 | payloadSize>>8)
|
||||
result[1] = uint8(payloadSize & 0xff)
|
||||
|
||||
// Packet without header
|
||||
copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:])
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func randomServerName() string {
|
||||
charNum := int('z') - int('a') + 1
|
||||
size := 3 + randInt(10)
|
||||
name := make([]byte, size)
|
||||
for i := range name {
|
||||
name[i] = byte(int('a') + randInt(charNum))
|
||||
}
|
||||
return string(name) + "." + randItem(topLevelDomains)
|
||||
}
|
||||
|
||||
func randItem(list []string) string {
|
||||
return list[randInt(len(list))]
|
||||
}
|
||||
|
||||
func randInt(n int) int {
|
||||
size, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n)))
|
||||
if err == nil {
|
||||
return int(size.Int64())
|
||||
}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
return rand.Intn(n)
|
||||
}
|
|
@ -6,6 +6,7 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -95,6 +96,9 @@ type Device struct {
|
|||
isASecOn abool.AtomicBool
|
||||
aSecMux sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
|
||||
handshakeStateChan chan<- HandshakeState
|
||||
allowedSrcAddresses []net.IP
|
||||
}
|
||||
|
||||
type aSecCfgType struct {
|
||||
|
@ -110,6 +114,14 @@ type aSecCfgType struct {
|
|||
transportPacketMagicHeader uint32
|
||||
}
|
||||
|
||||
type HandshakeState int
|
||||
|
||||
const (
|
||||
HandshakeInit HandshakeState = iota
|
||||
HandshakeSuccess = iota
|
||||
HandshakeFail = iota
|
||||
)
|
||||
|
||||
// deviceState represents the state of a Device.
|
||||
// There are three states: down, up, closed.
|
||||
// Transitions:
|
||||
|
@ -189,6 +201,7 @@ func (device *Device) changeState(want deviceState) (err error) {
|
|||
// upLocked attempts to bring the device up and reports whether it succeeded.
|
||||
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
||||
func (device *Device) upLocked() error {
|
||||
device.handshakeStateChan <- HandshakeInit
|
||||
if err := device.BindUpdate(); err != nil {
|
||||
device.log.Errorf("Unable to update bind: %v", err)
|
||||
return err
|
||||
|
@ -301,9 +314,18 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, handshakeStateChan chan<- HandshakeState /*, allowedSrcAddresses string*/) *Device {
|
||||
device := new(Device)
|
||||
device.state.state.Store(uint32(deviceStateDown))
|
||||
device.handshakeStateChan = handshakeStateChan
|
||||
/*var allowedSources = strings.Split(allowedSrcAddresses, ",")
|
||||
device.allowedSrcAddresses = make([]net.IP, len(allowedSources))
|
||||
for i, source := range allowedSources {
|
||||
ip := net.ParseIP(source)
|
||||
if ip != nil {
|
||||
device.allowedSrcAddresses[i] = ip
|
||||
}
|
||||
}*/
|
||||
device.closed = make(chan struct{})
|
||||
device.log = logger
|
||||
device.net.bind = bind
|
||||
|
@ -419,6 +441,7 @@ func (device *Device) Close() {
|
|||
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
close(device.handshakeStateChan)
|
||||
}
|
||||
|
||||
func (device *Device) Wait() chan struct{} {
|
||||
|
|
|
@ -237,7 +237,8 @@ func genTestPair(
|
|||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||
level = LogLevelError
|
||||
}
|
||||
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)),
|
||||
make(chan HandshakeState))
|
||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||
p.dev.Close()
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
@ -16,8 +17,7 @@ import (
|
|||
// They do not require a trailing newline in the format.
|
||||
// If nil, that level of logging will be silent.
|
||||
type Logger struct {
|
||||
Verbosef func(format string, args ...any)
|
||||
Errorf func(format string, args ...any)
|
||||
conn.Logger
|
||||
}
|
||||
|
||||
// Log levels for use with NewLogger.
|
||||
|
@ -34,7 +34,7 @@ func DiscardLogf(format string, args ...any) {}
|
|||
// It logs at the specified log level and above.
|
||||
// It decorates log lines with the log level, date, time, and prepend.
|
||||
func NewLogger(level int, prepend string) *Logger {
|
||||
logger := &Logger{DiscardLogf, DiscardLogf}
|
||||
logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}}
|
||||
logf := func(prefix string) func(string, ...any) {
|
||||
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ func randDevice(t *testing.T) *Device {
|
|||
}
|
||||
tun := tuntest.NewChannelTUN()
|
||||
logger := NewLogger(LogLevelError, "")
|
||||
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
|
||||
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState))
|
||||
device.SetPrivateKey(sk)
|
||||
return device
|
||||
}
|
||||
|
|
|
@ -406,6 +406,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
// update endpoint
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
device.handshakeStateChan <- HandshakeSuccess
|
||||
device.log.Verbosef("%v - Received handshake initiation", peer)
|
||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||
|
||||
|
@ -434,6 +435,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
// update endpoint
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
device.handshakeStateChan <- HandshakeSuccess
|
||||
device.log.Verbosef("%v - Received handshake response", peer)
|
||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||
|
||||
|
|
|
@ -175,6 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
|
||||
err = peer.SendBuffers(sendBuffer)
|
||||
if err != nil {
|
||||
peer.device.handshakeStateChan <- HandshakeFail
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
peer.timersHandshakeInitiated()
|
||||
|
@ -318,12 +319,14 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
|
||||
// lookup peer
|
||||
var peer *Peer
|
||||
var src []byte
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
src = elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case 6:
|
||||
|
@ -331,6 +334,7 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
continue
|
||||
}
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
src = elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
default:
|
||||
|
@ -340,6 +344,13 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Drop packets with unexpected src IP.
|
||||
if device.allowedSrcAddresses != nil && device.isUnexpectedSrcIP(src) {
|
||||
//device.log.Verbosef("Dropping packet with unexpected src IP: %v (allowed = %v)", src, device.allowedSrcAddresses)
|
||||
continue
|
||||
}
|
||||
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetOutboundElementsContainer()
|
||||
|
@ -383,6 +394,15 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
}
|
||||
}
|
||||
|
||||
func (device *Device) isUnexpectedSrcIP(src []byte) bool {
|
||||
for _, allowed := range device.allowedSrcAddresses {
|
||||
if allowed.Equal(src) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||
for {
|
||||
select {
|
||||
|
|
247
device/statemanager.go
Normal file
247
device/statemanager.go
Normal file
|
@ -0,0 +1,247 @@
|
|||
/*
|
||||
* Copyright (c) 2022. Proton AG
|
||||
*
|
||||
* This file is part of ProtonVPN.
|
||||
*
|
||||
* ProtonVPN is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonVPN is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var initialRestartDelay = 4 * time.Second
|
||||
var maxRestartDelay = 32 * time.Second
|
||||
var resetRestartDelay = 10 * time.Minute
|
||||
var timeNow = time.Now
|
||||
|
||||
// WireGuardStateManager handles enabling/disabling WireGuard in response to network availability changes, serves
|
||||
// connection state to the client and resets WireGuard connection in response to socket and handshake errors.
|
||||
//
|
||||
// Client should call SetNetworkAvailable every time network changes - WireGuard will remain inactive until
|
||||
// SetNetworkAvailable(true) is called. When SetNetworkAvailable(true) is called twice in a row it'll be interpreted
|
||||
// as network change and trigger reset of the connection (on TCP/TLS socket).
|
||||
//
|
||||
// GetState is blocking and therefore should run in dedicated thread in a loop. After Close is called GetState will
|
||||
// return immediately with WireGuardDisabled.
|
||||
type WireGuardStateManager struct {
|
||||
HandshakeStateChan chan HandshakeState
|
||||
SocketErrChan chan error
|
||||
networkAvailableChan chan bool
|
||||
closeChan chan bool
|
||||
|
||||
stateChan chan WireGuardState
|
||||
isNetAvailable bool
|
||||
|
||||
lastRestart time.Time
|
||||
transmission string
|
||||
|
||||
log *Logger
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
startedTimestamp time.Time
|
||||
nextRestartDelay time.Duration
|
||||
}
|
||||
|
||||
type WireGuardState int
|
||||
|
||||
const (
|
||||
WireGuardDisabled WireGuardState = iota
|
||||
WireGuardConnecting
|
||||
WireGuardConnected
|
||||
WireGuardError
|
||||
WireGuardWaitingForNetwork
|
||||
)
|
||||
|
||||
type BaseDevice interface {
|
||||
Up() error
|
||||
Down() error
|
||||
}
|
||||
|
||||
//goland:noinspection GoUnusedExportedFunction
|
||||
func NewWireGuardStateManager(log *Logger, transmission string) *WireGuardStateManager {
|
||||
return &WireGuardStateManager{
|
||||
networkAvailableChan: make(chan bool, 100),
|
||||
SocketErrChan: make(chan error, 100),
|
||||
HandshakeStateChan: make(chan HandshakeState, 100),
|
||||
closeChan: make(chan bool, 1),
|
||||
stateChan: make(chan WireGuardState, 1),
|
||||
transmission: transmission,
|
||||
log: log,
|
||||
nextRestartDelay: initialRestartDelay,
|
||||
lastRestart: timeNow(),
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) Start(device BaseDevice) {
|
||||
go man.handlerLoop(device)
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) GetState() WireGuardState {
|
||||
state, ok := <-man.stateChan
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) Close() {
|
||||
man.log.Verbosef("StateManager: closing")
|
||||
man.closed = true
|
||||
go func() {
|
||||
man.closeChan <- true
|
||||
man.stateChan <- WireGuardDisabled
|
||||
close(man.stateChan)
|
||||
}()
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) SetNetworkAvailable(available bool) {
|
||||
man.networkAvailableChan <- available
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) handlerLoop(device BaseDevice) {
|
||||
man.log.Verbosef("StateManager: start loop")
|
||||
// Ugly way of emulating optional bool type
|
||||
var wasNetAvailablePtr *bool = nil
|
||||
for {
|
||||
select {
|
||||
case netAvailable := <-man.networkAvailableChan:
|
||||
man.onNetworkAvailabilityChange(device, wasNetAvailablePtr, netAvailable)
|
||||
man.isNetAvailable = netAvailable
|
||||
wasNetAvailablePtr = &man.isNetAvailable
|
||||
case socketErr := <-man.SocketErrChan:
|
||||
if man.isNetAvailable {
|
||||
man.handleSocketErr(device, socketErr)
|
||||
}
|
||||
case handshakeState := <-man.HandshakeStateChan:
|
||||
if man.isNetAvailable {
|
||||
man.handleHandshakeState(device, handshakeState)
|
||||
}
|
||||
case <-man.closeChan:
|
||||
man.log.Verbosef("StateManager: end loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) onNetworkAvailabilityChange(device BaseDevice, wasAvailable *bool, available bool) {
|
||||
if !available {
|
||||
man.postState(WireGuardWaitingForNetwork)
|
||||
}
|
||||
if available && wasAvailable == nil {
|
||||
man.log.Verbosef("StateManager: network on")
|
||||
man.setActive(device, true)
|
||||
man.startedTimestamp = timeNow()
|
||||
} else if available && *wasAvailable && !man.startedTimestamp.IsZero() &&
|
||||
timeNow().After(man.startedTimestamp.Add(5*time.Second)) {
|
||||
// Ignore network changes at the very beginning of connection as those might be false positive
|
||||
// (VPN tunnel opening)
|
||||
man.log.Verbosef("StateManager: network change detected")
|
||||
man.maybeRestart(device)
|
||||
} else if available && !*wasAvailable {
|
||||
man.log.Verbosef("StateManager: network back")
|
||||
man.setActive(device, true)
|
||||
} else if !available && wasAvailable != nil && *wasAvailable {
|
||||
man.log.Verbosef("StateManager: network gone")
|
||||
man.setActive(device, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) setActive(device BaseDevice, activate bool) {
|
||||
man.mu.Lock()
|
||||
defer man.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if activate {
|
||||
man.postState(WireGuardConnecting)
|
||||
err = device.Up()
|
||||
} else {
|
||||
err = device.Down()
|
||||
}
|
||||
if err != nil {
|
||||
man.log.Errorf("StateManager: setActive(%t) error %v", activate, err)
|
||||
man.postState(WireGuardError)
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) handleSocketErr(device BaseDevice, err error) {
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset by peer") {
|
||||
man.log.Errorf("StateManager: %s", errStr)
|
||||
man.maybeRestart(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) handleHandshakeState(device BaseDevice, state HandshakeState) {
|
||||
switch state {
|
||||
case HandshakeInit:
|
||||
man.postState(WireGuardConnecting)
|
||||
case HandshakeSuccess:
|
||||
man.postState(WireGuardConnected)
|
||||
case HandshakeFail:
|
||||
man.postState(WireGuardError)
|
||||
man.maybeRestart(device)
|
||||
}
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) maybeRestart(device BaseDevice) {
|
||||
if man.transmission == "udp" {
|
||||
return
|
||||
}
|
||||
|
||||
man.mu.Lock()
|
||||
defer man.mu.Unlock()
|
||||
|
||||
if man.shouldRestart() {
|
||||
man.log.Verbosef("StateManager: restarting")
|
||||
man.postState(WireGuardConnecting)
|
||||
device.Down()
|
||||
if !man.closed {
|
||||
device.Up()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Don't restart too often, grow delay exponentially up to a limit and after some time reset to small initial value
|
||||
func (man *WireGuardStateManager) shouldRestart() bool {
|
||||
now := timeNow()
|
||||
restart := now.After(man.lastRestart.Add(man.nextRestartDelay))
|
||||
if restart {
|
||||
if now.After(man.lastRestart.Add(resetRestartDelay)) {
|
||||
man.nextRestartDelay = initialRestartDelay
|
||||
} else {
|
||||
man.nextRestartDelay *= 2
|
||||
if man.nextRestartDelay > maxRestartDelay {
|
||||
man.nextRestartDelay = maxRestartDelay
|
||||
}
|
||||
}
|
||||
man.lastRestart = now
|
||||
}
|
||||
return restart
|
||||
}
|
||||
|
||||
func (man *WireGuardStateManager) postState(state WireGuardState) {
|
||||
go func() {
|
||||
if !man.closed && (man.isNetAvailable || state == WireGuardWaitingForNetwork) {
|
||||
man.stateChan <- state
|
||||
}
|
||||
}()
|
||||
}
|
152
device/statemanager_test.go
Normal file
152
device/statemanager_test.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright (c) 2022. Proton AG
|
||||
*
|
||||
* This file is part of ProtonVPN.
|
||||
*
|
||||
* ProtonVPN is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* ProtonVPN is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var timeMs int64 = 0
|
||||
var mockDevice MockDevice
|
||||
var manager *WireGuardStateManager
|
||||
var lastState WireGuardState
|
||||
|
||||
type MockDevice struct {
|
||||
isUp bool
|
||||
upCount int
|
||||
}
|
||||
|
||||
func (dev *MockDevice) Up() error {
|
||||
dev.isUp = true
|
||||
dev.upCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dev *MockDevice) Down() error {
|
||||
dev.isUp = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func setup() {
|
||||
timeMs = 0
|
||||
timeNow = func() time.Time { return time.UnixMilli(timeMs) }
|
||||
mockDevice.isUp = false
|
||||
|
||||
manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), "tcp")
|
||||
manager.Start(&mockDevice)
|
||||
lastState = WireGuardDisabled
|
||||
go func() {
|
||||
for lastState != -1 {
|
||||
lastState = manager.GetState()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func setdown() {
|
||||
manager.Close()
|
||||
}
|
||||
|
||||
func TestWireGuardStateManager_shouldRestart(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
setup()
|
||||
defer setdown()
|
||||
|
||||
assert.Equal(initialRestartDelay, manager.nextRestartDelay)
|
||||
|
||||
assert.Equal(false, manager.shouldRestart())
|
||||
timeMs += initialRestartDelay.Milliseconds()
|
||||
assert.Equal(false, manager.shouldRestart())
|
||||
timeMs += 1
|
||||
assert.Equal(true, manager.shouldRestart())
|
||||
|
||||
assert.Equal(2*initialRestartDelay, manager.nextRestartDelay)
|
||||
assert.Equal(false, manager.shouldRestart())
|
||||
timeMs += 2 * initialRestartDelay.Milliseconds()
|
||||
assert.Equal(false, manager.shouldRestart())
|
||||
timeMs += 1
|
||||
assert.Equal(true, manager.shouldRestart())
|
||||
|
||||
timeMs += resetRestartDelay.Milliseconds() + 1
|
||||
assert.Equal(true, manager.shouldRestart())
|
||||
assert.Equal(initialRestartDelay, manager.nextRestartDelay)
|
||||
}
|
||||
|
||||
func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
setup()
|
||||
defer setdown()
|
||||
|
||||
assert.Equal(false, mockDevice.isUp)
|
||||
manager.SetNetworkAvailable(true)
|
||||
time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking
|
||||
assert.Equal(true, mockDevice.isUp)
|
||||
assert.Equal(WireGuardConnecting, lastState)
|
||||
manager.SetNetworkAvailable(false)
|
||||
time.Sleep(time.Millisecond)
|
||||
assert.Equal(WireGuardWaitingForNetwork, lastState)
|
||||
assert.Equal(false, mockDevice.isUp)
|
||||
}
|
||||
|
||||
func TestWireGuardStateManager_happyConnectionPath(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
setup()
|
||||
defer setdown()
|
||||
|
||||
manager.SetNetworkAvailable(true)
|
||||
time.Sleep(time.Millisecond)
|
||||
manager.HandshakeStateChan <- HandshakeSuccess
|
||||
time.Sleep(time.Millisecond)
|
||||
assert.Equal(WireGuardConnected, lastState)
|
||||
assert.Equal(true, mockDevice.isUp)
|
||||
}
|
||||
|
||||
func TestWireGuardStateManager_handshakeFailCausesRestart(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
setup()
|
||||
defer setdown()
|
||||
|
||||
manager.SetNetworkAvailable(true)
|
||||
time.Sleep(time.Millisecond)
|
||||
manager.HandshakeStateChan <- HandshakeFail
|
||||
time.Sleep(time.Millisecond)
|
||||
assert.Equal(WireGuardError, lastState)
|
||||
timeMs += initialRestartDelay.Milliseconds() + 1
|
||||
manager.HandshakeStateChan <- HandshakeFail
|
||||
time.Sleep(time.Millisecond)
|
||||
assert.Equal(WireGuardConnecting, lastState)
|
||||
assert.Equal(2, mockDevice.upCount)
|
||||
}
|
||||
|
||||
func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
setup()
|
||||
defer setdown()
|
||||
|
||||
manager.SetNetworkAvailable(true)
|
||||
timeMs += initialRestartDelay.Milliseconds() + 1
|
||||
time.Sleep(time.Millisecond)
|
||||
manager.SocketErrChan <- errors.New("broken pipe")
|
||||
time.Sleep(time.Millisecond)
|
||||
assert.Equal(WireGuardConnecting, lastState)
|
||||
assert.Equal(2, mockDevice.upCount)
|
||||
}
|
|
@ -78,6 +78,7 @@ func (peer *Peer) timersActive() bool {
|
|||
|
||||
func expiredRetransmitHandshake(peer *Peer) {
|
||||
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
|
||||
peer.device.handshakeStateChan <- HandshakeFail
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
||||
|
||||
if peer.timersActive() {
|
||||
|
@ -97,6 +98,7 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||
}
|
||||
} else {
|
||||
peer.timers.handshakeAttempts.Add(1)
|
||||
peer.device.handshakeStateChan <- HandshakeFail
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
|
|
10
go.mod
10
go.mod
|
@ -3,6 +3,8 @@ module github.com/amnezia-vpn/amneziawg-go
|
|||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/refraction-networking/utls v1.1.5
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/tevino/abool/v2 v2.1.0
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/net v0.21.0
|
||||
|
@ -15,3 +17,11 @@ require (
|
|||
github.com/google/btree v1.0.1 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/klauspost/compress v1.15.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
31
go.sum
31
go.sum
|
@ -1,16 +1,41 @@
|
|||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
|
||||
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.1.5 h1:JtrojoNhbUQkBqEg05sP3gDgDj6hIEAAVKbI9lx4n6w=
|
||||
github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
||||
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
|
|
4
main.go
4
main.go
|
@ -222,11 +222,11 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||
errs := make(chan error)
|
||||
device := device.NewDevice(tdev, conn.NewDefaultBind(&logger.Logger, errs), logger, make(chan device.HandshakeState))
|
||||
|
||||
logger.Verbosef("Device started")
|
||||
|
||||
errs := make(chan error)
|
||||
term := make(chan os.Signal, 1)
|
||||
|
||||
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
||||
|
|
Loading…
Add table
Reference in a new issue