Compare commits

..

No commits in common. "master" and "v0.2.8" have entirely different histories.

113 changed files with 747 additions and 2544 deletions

View file

@ -1,4 +1,4 @@
FROM golang:1.24.4 as awg
FROM golang:1.20 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
@ -6,8 +6,7 @@ RUN go mod download && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
ARG AWGTOOLS_RELEASE="1.0.20240213"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \

View file

@ -9,7 +9,7 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \

View file

@ -50,4 +50,3 @@ $ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd amneziawg-go
$ make
```

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -298,6 +298,11 @@ func (s *StdNetBind) BatchSize() int {
return 1
}
func (s *StdNetBind) GetOffloadInfo() string {
return fmt.Sprintf("ipv4TxOffload: %v, ipv4RxOffload: %v\nipv6TxOffload: %v, ipv6RxOffload: %v",
s.ipv4TxOffload, s.ipv4RxOffload, s.ipv6TxOffload, s.ipv6RxOffload)
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -328,6 +328,10 @@ func (bind *WinRingBind) BatchSize() int {
return 1
}
func (bind *WinRingBind) GetOffloadInfo() string {
return ""
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@ -91,6 +91,8 @@ func (c *ChannelBind) Close() error {
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) GetOffloadInfo() string { return "" }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
@ -55,6 +55,8 @@ type Bind interface {
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
GetOffloadInfo() string
}
// BindSocketToInterface is implemented by Bind objects that support being

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -13,35 +13,6 @@ import (
"golang.org/x/sys/unix"
)
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
func kernelVersion() (major, minor int) {
var uname unix.Utsname
if err := unix.Uname(&uname); err != nil {
return
}
var (
values [2]int
value, vi int
)
for _, c := range uname.Release {
if '0' <= c && c <= '9' {
value = (value * 10) + int(c-'0')
} else {
// Note that we're assuming N.N.N here.
// If we see anything else, we are likely to mis-parse it.
values[vi] = value
vi++
if vi >= len(values) {
break
}
value = 0
}
}
return values[0], values[1]
}
func init() {
controlFns = append(controlFns,
@ -86,24 +57,5 @@ func init() {
}
return err
},
// Attempt to enable UDP_GRO
func(network, address string, c syscall.RawConn) error {
// Kernels below 5.12 are missing 98184612aca0 ("net:
// udp: Add support for getsockopt(..., ..., UDP_GRO,
// ..., ...);"), which means we can't read this back
// later. We could pipe the return value through to
// the rest of the code, but UDP_GRO is kind of buggy
// anyway, so just gate this here.
major, minor := kernelVersion()
if major < 5 || (major == 5 && minor < 12) {
return nil
}
c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
})
return nil
},
)
}

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,11 +2,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(_ error) bool {
func errShouldDisableUDPGSO(err error) bool {
return false
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -3,13 +3,13 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -223,11 +223,19 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
}
}
func (node *trieEntry) remove() {
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
return
continue
}
bit := 0
if node.child[0] == nil {
@ -240,12 +248,12 @@ func (node *trieEntry) remove() {
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
return
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
return
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
@ -255,37 +263,6 @@ func (node *trieEntry) remove() {
node.zeroizePointers()
parent.zeroizePointers()
}
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var node *trieEntry
var exact bool
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
} else {
panic(errors.New("removing unknown address type"))
}
if !exact || node == nil || peer != node.peer {
return
}
node.remove()
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
elem.Value.(*trieEntry).remove()
}
}
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
var peers []*Peer
var allowedIPs AllowedIPs
rng := rand.New(rand.NewSource(1))
rand.Seed(1)
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ {
var addr4 [4]byte
rng.Read(addr4[:])
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rng.Read(addr6[:])
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
for p = 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rng.Read(addr4[:])
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
}
var addr6 [16]byte
rng.Read(addr6[:])
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
}
}
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
root := parentIndirection{&trie, 2}
rng := rand.New(rand.NewSource(1))
rand.Seed(1)
const AddressLength = 4
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rng.Read(addr[:])
cidr := uint8(rng.Uint32() % (AddressLength * 8))
index := rng.Int() % peerNumber
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
root.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rng.Read(addr[:])
rand.Read(addr[:])
trie.lookup(addr[:])
}
}
@ -101,10 +101,6 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
@ -180,21 +176,6 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
insert(a, 1, 0, 0, 0, 32)
insert(a, 192, 0, 0, 0, 24)
assertEQ(a, 1, 0, 0, 0)
assertEQ(a, 192, 0, 0, 1)
remove(a, 192, 0, 0, 0, 32)
assertEQ(a, 192, 0, 0, 1)
remove(nil, 192, 0, 0, 0, 24)
assertEQ(a, 192, 0, 0, 1)
remove(b, 192, 0, 0, 0, 24)
assertEQ(a, 192, 0, 0, 1)
remove(a, 192, 0, 0, 0, 24)
assertNEQ(a, 192, 0, 0, 1)
remove(a, 1, 0, 0, 0, 32)
assertNEQ(a, 1, 0, 0, 0)
}
/* Test ported from kernel implementation:
@ -230,15 +211,6 @@ func TestTrieIPv6(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
@ -251,18 +223,6 @@ func TestTrieIPv6(t *testing.T) {
}
}
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := allowedIPs.Lookup(addr)
if p == peer {
t.Error("Assert NEQ failed")
}
}
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
@ -284,21 +244,4 @@ func TestTrieIPv6(t *testing.T) {
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
}

View file

@ -1,144 +0,0 @@
package awg
import (
"bytes"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"github.com/tevino/abool"
)
type aSecCfgType struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitHeaderJunkSize int
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
// InitPacketMagicHeader Limit
// ResponsePacketMagicHeader Limit
// UnderloadPacketMagicHeader Limit
// TransportPacketMagicHeader Limit
}
type Limit struct {
Min uint32
Max uint32
HeaderType uint32
}
func NewLimit(min, max, headerType uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
Min: min,
Max: max,
HeaderType: headerType,
}, nil
}
func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
// tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType)
// var min, max, headerType uint32
// _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType)
// if err != nil {
// return Limit{}, fmt.Errorf("invalid magic header format: %s", value)
// }
limits := strings.Split(value, "-")
if len(limits) != 2 {
return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
}
min, err := strconv.ParseUint(limits[0], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err)
}
max, err := strconv.ParseUint(limits[1], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err)
}
limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
if err != nil {
return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err)
}
return limit, nil
}
type Limits []Limit
func NewLimits(limits []Limit) Limits {
slices.SortFunc(limits, func(a, b Limit) int {
if a.Min < b.Min {
return -1
} else if a.Min > b.Min {
return 1
}
return 0
})
return Limits(limits)
}
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) {
extraSize := 0
if len(optExtraSize) == 1 {
extraSize = optExtraSize[0]
}
var junk []byte
protocol.ASecMux.RLock()
if junkSize != 0 {
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
protocol.ASecMux.RUnlock()
return nil, err
}
junk = writer.Bytes()
}
protocol.ASecMux.RUnlock()
return junk, nil
}

View file

@ -1,37 +0,0 @@
package internal
type mockGenerator struct {
size int
}
func NewMockGenerator(size int) mockGenerator {
return mockGenerator{size: size}
}
func (m mockGenerator) Generate() []byte {
return make([]byte, m.size)
}
func (m mockGenerator) Size() int {
return m.size
}
func (m mockGenerator) Name() string {
return "mock"
}
type mockByteGenerator struct {
data []byte
}
func NewMockByteGenerator(data []byte) mockByteGenerator {
return mockByteGenerator{data: data}
}
func (bg mockByteGenerator) Generate() []byte {
return bg.data
}
func (bg mockByteGenerator) Size() int {
return len(bg.data)
}

View file

@ -1,70 +0,0 @@
package awg
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
aSecCfg aSecCfgType
cha8Rand *v2.ChaCha8
}
// TODO: refactor param to only pass the junk related params
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
if jc.aSecCfg.JunkPacketCount == 0 {
return nil
}
for range jc.aSecCfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
*junks = append(*junks, junk)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
),
) + jc.aSecCfg.JunkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
// TODO: use a memory pool to allocate
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

View file

@ -1,115 +0,0 @@
package awg
import (
"bytes"
"fmt"
"testing"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
jc, err := NewJunkCreator(aSecCfgType{
IsSet: true,
JunkPacketCount: 5,
JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000,
InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40,
InitPacketMagicHeader: 123456,
ResponsePacketMagicHeader: 67543,
UnderloadPacketMagicHeader: 32345,
TransportPacketMagicHeader: 123123,
})
if err != nil {
t.Errorf("failed to create junk creator %v", err)
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
err := jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
got > jc.aSecCfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.aSecCfg.JunkPacketMinSize,
jc.aSecCfg.JunkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("valid", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.AppendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Error("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -1,73 +0,0 @@
package awg
import (
"errors"
"time"
"github.com/tevino/abool"
"go.uber.org/atomic"
)
// TODO: atomic?/ and better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
// TODO
var WaitResponse = struct {
Channel chan struct{}
ShouldWait *abool.AtomicBool
}{
make(chan struct{}, 1),
abool.New(),
}
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkPacketGenerators
ControlledJunk TagJunkPacketGenerators
nextItime time.Time
ITimeout time.Duration // seconds
IsSet bool
}
func (handler *SpecialHandshakeHandler) Validate() error {
var errs []error
if err := handler.SpecialJunk.Validate(); err != nil {
errs = append(errs, err)
}
if err := handler.ControlledJunk.Validate(); err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() {
return nil
}
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
} else if !handler.isTimeToSendSpecial() {
return nil
}
rv := handler.SpecialJunk.GeneratePackets()
handler.nextItime = time.Now().Add(handler.ITimeout)
return rv
}
func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
return time.Now().After(handler.nextItime)
}
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
if !handler.ControlledJunk.IsDefined() {
return nil
}
return handler.ControlledJunk.GeneratePackets()
}

View file

@ -1,190 +0,0 @@
package awg
import (
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
v2 "math/rand/v2"
// "go.uber.org/atomic"
)
type Generator interface {
Generate() []byte
Size() int
}
type newGenerator func(string) (Generator, error)
type BytesGenerator struct {
value []byte
size int
}
func (bg *BytesGenerator) Generate() []byte {
return bg.value
}
func (bg *BytesGenerator) Size() int {
return bg.size
}
func newBytesGenerator(param string) (Generator, error) {
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
if !hasPrefix {
return nil, fmt.Errorf("not correct hex: %s", param)
}
hex, err := hexToBytes(param)
if err != nil {
return nil, fmt.Errorf("hexToBytes: %w", err)
}
return &BytesGenerator{value: hex, size: len(hex)}, nil
}
func hexToBytes(hexStr string) ([]byte, error) {
hexStr = strings.TrimPrefix(hexStr, "0x")
hexStr = strings.TrimPrefix(hexStr, "0X")
// Ensure even length (pad with leading zero if needed)
if len(hexStr)%2 != 0 {
hexStr = "0" + hexStr
}
return hex.DecodeString(hexStr)
}
type RandomPacketGenerator struct {
cha8Rand *v2.ChaCha8
size int
}
func (rpg *RandomPacketGenerator) Generate() []byte {
junk := make([]byte, rpg.size)
rpg.cha8Rand.Read(junk)
return junk
}
func (rpg *RandomPacketGenerator) Size() int {
return rpg.size
}
func newRandomPacketGenerator(param string) (Generator, error) {
size, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("random packet parse int: %w", err)
}
if size > 1000 {
return nil, fmt.Errorf("random packet size must be less than 1000")
}
buf := make([]byte, 32)
_, err = crand.Read(buf)
if err != nil {
return nil, fmt.Errorf("random packet crand read: %w", err)
}
return &RandomPacketGenerator{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
}
type TimestampGenerator struct {
}
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
return buf
}
func (tg *TimestampGenerator) Size() int {
return 8
}
func newTimestampGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
}
return &TimestampGenerator{}, nil
}
type WaitTimeoutGenerator struct {
waitTimeout time.Duration
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
time.Sleep(wtg.waitTimeout)
return []byte{}
}
func (wtg *WaitTimeoutGenerator) Size() int {
return 0
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
timeout, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
if timeout > 5000 {
return nil, fmt.Errorf("timeout must be less than 5000ms")
}
return &WaitTimeoutGenerator{
waitTimeout: time.Duration(timeout) * time.Millisecond,
}, nil
}
type PacketCounterGenerator struct {
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
return buf
}
func (c *PacketCounterGenerator) Size() int {
return 8
}
func newPacketCounterGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
return &PacketCounterGenerator{}, nil
}
type WaitResponseGenerator struct {
}
func (c *WaitResponseGenerator) Generate() []byte {
WaitResponse.ShouldWait.Set()
<-WaitResponse.Channel
WaitResponse.ShouldWait.UnSet()
return []byte{}
}
func (c *WaitResponseGenerator) Size() int {
return 0
}
func newWaitResponseGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
}
return &WaitResponseGenerator{}, nil
}

View file

@ -1,189 +0,0 @@
package awg
import (
"encoding/binary"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func Test_newBytesGenerator(t *testing.T) {
type args struct {
param string
}
tests := []struct {
name string
args args
want []byte
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "wrong start",
args: args{
param: "123456",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with X",
args: args{
param: "0X12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with x",
args: args{
param: "0x12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "valid hex",
args: args{
param: "0xf6ab3267fa",
},
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
},
{
name: "valid hex with odd length",
args: args{
param: "0xfab3267fa",
},
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newBytesGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
gotValues := got.Generate()
require.Equal(t, tt.want, gotValues)
})
}
}
func Test_newRandomPacketGenerator(t *testing.T) {
type args struct {
param string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "not an int",
args: args{
param: "x",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must be less than 1000"),
},
{
name: "valid",
args: args{
param: "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newRandomPacketGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
first := got.Generate()
second := got.Generate()
require.NotEqual(t, first, second)
})
}
}
func TestPacketCounterGenerator(t *testing.T) {
tests := []struct {
name string
param string
wantErr bool
}{
{
name: "Valid empty param",
param: "",
wantErr: false,
},
{
name: "Invalid non-empty param",
param: "anything",
wantErr: true,
},
}
for _, tc := range tests {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gen, err := newPacketCounterGenerator(tc.param)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, 8, gen.Size())
// Reset counter to known value for test
initialCount := uint64(42)
PacketCounter.Store(initialCount)
output := gen.Generate()
require.Equal(t, 8, len(output))
// Verify counter value in output
counterValue := binary.BigEndian.Uint64(output)
require.Equal(t, initialCount, counterValue)
// Increment counter and verify change
PacketCounter.Add(1)
output = gen.Generate()
counterValue = binary.BigEndian.Uint64(output)
require.Equal(t, initialCount+1, counterValue)
})
}
}

View file

@ -1,59 +0,0 @@
package awg
import (
"fmt"
"strconv"
)
type TagJunkPacketGenerator struct {
name string
tagValue string
packetSize int
generators []Generator
}
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{
name: name,
tagValue: tagValue,
generators: make([]Generator, 0, size),
}
}
func (tg *TagJunkPacketGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (tg *TagJunkPacketGenerator) Name() string {
return tg.name
}
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
if len(tg.name) != 2 {
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
}
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
return 0, fmt.Errorf("name 2 char should be an int %w", err)
}
return index, nil
}
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
return IpcFields{
Key: tg.name,
Value: tg.tagValue,
}
}

View file

@ -1,210 +0,0 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestNewTagJunkGenerator(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
genName string
size int
expected TagJunkPacketGenerator
}{
{
name: "Create new generator with empty name",
genName: "",
size: 0,
expected: TagJunkPacketGenerator{
name: "",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with valid name",
genName: "T1",
size: 0,
expected: TagJunkPacketGenerator{
name: "T1",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with non-zero size",
genName: "T2",
size: 5,
expected: TagJunkPacketGenerator{
name: "T2",
packetSize: 0,
generators: make([]Generator, 5),
},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
require.Equal(t, tc.expected.name, result.name)
require.Equal(t, tc.expected.packetSize, result.packetSize)
require.Equal(t, cap(result.generators), len(tc.expected.generators))
})
}
}
func TestTagJunkGeneratorAppend(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
initialState TagJunkPacketGenerator
mockSize int
expectedLength int
expectedSize int
}{
{
name: "Append to empty generator",
initialState: newTagJunkPacketGenerator("T1", "", 0),
mockSize: 5,
expectedLength: 1,
expectedSize: 5,
},
{
name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{
name: "T2",
packetSize: 10,
generators: make([]Generator, 2),
},
mockSize: 7,
expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.initialState
mockGen := internal.NewMockGenerator(tc.mockSize)
tg.append(mockGen)
require.Equal(t, tc.expectedLength, len(tg.generators))
require.Equal(t, tc.expectedSize, tg.packetSize)
})
}
}
func TestTagJunkGeneratorGenerate(t *testing.T) {
t.Parallel()
// Create mock generators for testing
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
testCases := []struct {
name string
setupGenerator func() TagJunkPacketGenerator
expected []byte
}{
{
name: "Generate with empty generators",
setupGenerator: func() TagJunkPacketGenerator {
return newTagJunkPacketGenerator("T1", "", 0)
},
expected: []byte{},
},
{
name: "Generate with single generator",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T2", "", 0)
tg.append(mockGen1)
return tg
},
expected: []byte{0x01, 0x02},
},
{
name: "Generate with multiple generators",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T3", "", 0)
tg.append(mockGen1)
tg.append(mockGen2)
return tg
},
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.setupGenerator()
result := tg.generatePacket()
require.Equal(t, tc.expected, result)
})
}
}
func TestTagJunkGeneratorNameIndex(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
generatorName string
expectedIndex int
expectError bool
}{
{
name: "Valid name with digit",
generatorName: "T5",
expectedIndex: 5,
expectError: false,
},
{
name: "Invalid name - too short",
generatorName: "T",
expectError: true,
},
{
name: "Invalid name - too long",
generatorName: "T55",
expectError: true,
},
{
name: "Invalid name - non-digit second character",
generatorName: "TX",
expectError: true,
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := TagJunkPacketGenerator{name: tc.generatorName}
index, err := tg.nameIndex()
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectedIndex, index)
}
})
}
}

View file

@ -1,66 +0,0 @@
package awg
import "fmt"
type TagJunkPacketGenerators struct {
tagGenerators []TagJunkPacketGenerator
length int
DefaultJunkCount int // Jc
}
func (generators *TagJunkPacketGenerators) AppendGenerator(
generator TagJunkPacketGenerator,
) {
generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++
}
func (generators *TagJunkPacketGenerators) IsDefined() bool {
return len(generators.tagGenerators) > 0
}
// validate that packets were defined consecutively
func (generators *TagJunkPacketGenerators) Validate() error {
seen := make([]bool, len(generators.tagGenerators))
for _, generator := range generators.tagGenerators {
index, err := generator.nameIndex()
if index > len(generators.tagGenerators) {
return fmt.Errorf("junk packet index should be consecutive")
}
if err != nil {
return fmt.Errorf("name index: %w", err)
} else {
seen[index-1] = true
}
}
for _, found := range seen {
if !found {
return fmt.Errorf("junk packet index should be consecutive")
}
}
return nil
}
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators {
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
PacketCounter.Inc()
}
PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv
}
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
rv := make([]IpcFields, 0, len(tg.tagGenerators))
for _, generator := range tg.tagGenerators {
rv = append(rv, generator.IpcGetFields())
}
return rv
}

View file

@ -1,149 +0,0 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
tests := []struct {
name string
generator TagJunkPacketGenerator
}{
{
name: "append single generator",
generator: newTagJunkPacketGenerator("t1", "", 10),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
// Initial length should be 0
require.Equal(t, 0, generators.length)
require.Empty(t, generators.tagGenerators)
// After append, length should be 1 and generator should be added
generators.AppendGenerator(tt.generator)
require.Equal(t, 1, generators.length)
require.Len(t, generators.tagGenerators, 1)
require.Equal(t, tt.generator, generators.tagGenerators[0])
})
}
}
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
tests := []struct {
name string
generators []TagJunkPacketGenerator
wantErr bool
errMsg string
}{
{
name: "bad start",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "non-consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t2", "", 10),
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
newTagJunkPacketGenerator("t5", "", 10),
},
},
{
name: "nameIndex error",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("error", "", 10),
},
wantErr: true,
errMsg: "name must be 2 character long",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
for _, gen := range tt.generators {
generators.AppendGenerator(gen)
}
err := generators.Validate()
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
return
}
require.NoError(t, err)
})
}
}
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
mockByte1 := []byte{0x01, 0x02}
mockByte2 := []byte{0x03, 0x04, 0x05}
mockGen1 := internal.NewMockByteGenerator(mockByte1)
mockGen2 := internal.NewMockByteGenerator(mockByte2)
tests := []struct {
name string
setupGenerator func() []TagJunkPacketGenerator
expected [][]byte
}{
{
name: "generate with no default junk",
setupGenerator: func() []TagJunkPacketGenerator {
tg1 := newTagJunkPacketGenerator("t1", "", 0)
tg1.append(mockGen1)
tg1.append(mockGen2)
tg2 := newTagJunkPacketGenerator("t2", "", 0)
tg2.append(mockGen2)
tg2.append(mockGen1)
return []TagJunkPacketGenerator{tg1, tg2}
},
expected: [][]byte{
append(mockByte1, mockByte2...),
append(mockByte2, mockByte1...),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
tagGenerators := tt.setupGenerator()
for _, gen := range tagGenerators {
generators.AppendGenerator(gen)
}
result := generators.GeneratePackets()
require.Equal(t, result, tt.expected)
})
}
}

View file

@ -1,112 +0,0 @@
package awg
import (
"fmt"
"maps"
"regexp"
"strings"
)
type IpcFields struct{ Key, Value string }
type EnumTag string
const (
BytesEnumTag EnumTag = "b"
CounterEnumTag EnumTag = "c"
TimestampEnumTag EnumTag = "t"
RandomBytesEnumTag EnumTag = "r"
WaitTimeoutEnumTag EnumTag = "wt"
WaitResponseEnumTag EnumTag = "wr"
)
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
// WaitResponseEnumTag: newWaitResponseGenerator,
}
// helper map to determine enumTags are unique
var uniqueTags = map[EnumTag]bool{
CounterEnumTag: false,
TimestampEnumTag: false,
}
type Tag struct {
Name EnumTag
Param string
}
func parseTag(input string) (Tag, error) {
// Regular expression to match <tagname optional_param>
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
match := re.FindStringSubmatch(input)
tag := Tag{
Name: EnumTag(match[1]),
}
if len(match) > 2 && match[2] != "" {
tag.Param = strings.TrimSpace(match[2])
}
return tag, nil
}
func Parse(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
}
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
maps.Copy(uniqueTagCheck, uniqueTags)
// skip byproduct of split
inputSlice = inputSlice[1:]
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice {
if len(inputParam) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf(
"empty tag in input: %s",
inputSlice,
)
} else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
}
tag, _ := parseTag(inputParam)
creator, ok := generatorCreator[tag.Name]
if !ok {
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
}
if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
return TagJunkPacketGenerator{}, fmt.Errorf(
"tag %s needs to be unique",
tag.Name,
)
}
uniqueTagCheck[tag.Name] = true
}
generator, err := creator(tag.Param)
if err != nil {
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
}
// TODO: handle counter tag
// if tag.Name == CounterEnumTag {
// packetCounter, ok := generator.(*PacketCounterGenerator)
// if !ok {
// log.Fatalf("packet counter generator expected, got %T", generator)
// }
// PacketCounter = packetCounter.counter
// }
rv.append(generator)
}
return rv, nil
}

View file

@ -1,77 +0,0 @@
package awg
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestParse(t *testing.T) {
type args struct {
name string
input string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "invalid name",
args: args{name: "apple", input: ""},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "empty",
args: args{name: "i1", input: ""},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "extra >",
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "extra <",
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "empty <>",
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "invalid tag",
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
wantErr: fmt.Errorf("invalid tag"),
},
{
name: "counter uniqueness violation",
args: args{name: "i1", input: "<c><c>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{
name: "timestamp uniqueness violation",
args: args{name: "i1", input: "<t><t>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{
name: "valid",
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.args.name, tt.args.input)
// TODO: ErrorAs doesn't work as you think
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
return
}
require.Nil(t, err)
})
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,60 +1,24 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
)
type Version uint8
const (
VersionDefault Version = iota
VersionAwg
VersionAwgSpecialHandshake
)
// TODO:
type AtomicVersion struct {
value atomic.Uint32
}
func NewAtomicVersion(v Version) *AtomicVersion {
av := &AtomicVersion{}
av.Store(v)
return av
}
func (av *AtomicVersion) Load() Version {
return Version(av.value.Load())
}
func (av *AtomicVersion) Store(v Version) {
av.value.Store(uint32(v))
}
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
return av.value.CompareAndSwap(uint32(old), uint32(new))
}
func (av *AtomicVersion) Swap(new Version) Version {
return Version(av.value.Swap(uint32(new)))
}
type Device struct {
state struct {
// state holds the device's state. It is accessed atomically.
@ -128,8 +92,22 @@ type Device struct {
closed chan struct{}
log *Logger
version Version
awg awg.Protocol
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
}
type aSecCfgType struct {
isSet bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
@ -569,6 +547,7 @@ func (device *Device) BindUpdate() error {
}
device.log.Verbosef("UDP bind has been updated")
device.log.Verbosef(netc.bind.GetOffloadInfo())
return nil
}
@ -578,261 +557,250 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = DefaultMessageInitiationType
MessageResponseType = DefaultMessageResponseType
MessageCookieReplyType = DefaultMessageCookieReplyType
MessageTransportType = DefaultMessageTransportType
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
var errs []error
if !tempASecCfg.isSet {
return err
}
isASecOn := false
device.awg.ASecMux.Lock()
if tempAwg.ASecCfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
if tempAwg.ASecCfg.JunkPacketCount != 0 {
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.ASecCfg.JunkPacketCount > 0 &&
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
if device.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.ASecCfg.JunkPacketMinSize = 0
device.awg.ASecCfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
))
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
tempAwg.ASecCfg.JunkPacketMinSize,
))
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
)
}
} else {
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
if tempASecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.InitHeaderJunkSize,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
),
)
}
} else {
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
if tempASecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.ResponseHeaderJunkSize,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
),
)
}
} else {
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
if tempASecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
}
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
}
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
MessageInitiationType = 1
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
MessageResponseType = 2
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
MessageCookieReplyType = 3
}
if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
MessageTransportType = 4
}
isSameHeaderMap := map[uint32]struct{}{
MessageInitiationType: {},
MessageResponseType: {},
MessageCookieReplyType: {},
MessageTransportType: {},
}
isSameMap := map[uint32]bool{}
isSameMap[MessageInitiationType] = true
isSameMap[MessageResponseType] = true
isSameMap[MessageCookieReplyType] = true
isSameMap[MessageTransportType] = true
// size will be different if same values
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
if len(isSameMap) != 4 {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
),
)
}
isSameSizeMap := map[int]struct{}{
newInitSize: {},
newResponseSize: {},
newCookieSize: {},
newTransportSize: {},
}
if len(isSameSizeMap) != 4 {
errs = append(errs, ipcErrorf(
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
`new init size:%d; and new response size:%d; should differ; %w`,
newInitSize,
newResponseSize,
newCookieSize,
newTransportSize,
),
err,
)
} else {
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
)
}
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
newCookieSize: MessageCookieReplyType,
newTransportSize: MessageTransportType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.awg.IsASecOn.SetTo(isASecOn)
var err error
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
if err != nil {
errs = append(errs, err)
}
device.isASecOn.SetTo(isASecOn)
device.aSecMux.Unlock()
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
device.awg.ASecMux.Unlock()
return errors.Join(errs...)
return err
}

View file

@ -1,28 +1,25 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"
"testing"
"time"
"go.uber.org/atomic"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
@ -53,7 +50,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@ -65,8 +62,7 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
args0 := append([]string(nil), cfg...)
args0 = append(args0, []string{
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
@ -74,16 +70,12 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
}...)
cfgs[0] = uapiCfg(args0...)
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
args1 := append([]string(nil), cfg...)
args1 = append(args1, []string{
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
@ -91,9 +83,66 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
}...)
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
cfgs[1] = uapiCfg(args1...)
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
@ -136,10 +185,9 @@ func (pair *testPair) Send(
// pong is the new ping
p0, p1 = p1, p0
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(6 * time.Second)
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
var err error
select {
@ -166,12 +214,14 @@ func (pair *testPair) Send(
// genTestPair creates a testPair.
func genTestPair(
tb testing.TB,
realSocket bool,
extraCfg ...string,
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
cfg, endpointCfg = genConfigs(tb, extraCfg...)
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@ -215,7 +265,7 @@ func genTestPair(
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true)
pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@ -224,23 +274,9 @@ func TestTwoDevicePing(t *testing.T) {
})
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestAWGDevicePing(t *testing.T) {
func TestTwoDevicePingASecurity(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"s3", "50",
"s4", "5",
"h1", "123456",
"h2", "67543",
"h3", "123123",
"h4", "32345",
)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@ -249,58 +285,13 @@ func TestAWGDevicePing(t *testing.T) {
})
}
// Needs to be stopped with Ctrl-C
func TestAWGHandshakeDevicePing(t *testing.T) {
t.Skip("This test is intended to be run manually, not as part of the test suite.")
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
isRunning := atomic.NewBool(true)
go func() {
<-signalContext.Done()
fmt.Println("Waiting to finish")
isRunning.Store(false)
}()
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
"i2", "<b 0xf6ab3267fa><r 100>",
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
"j2", "<c><b 0xf6ab><t><wt 1000>",
"j3", "<t><b 0xf6ab><c><r 10>",
"itime", "60",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",
// "s1", "30",
// "s2", "40",
// "h1", "123456",
// "h2", "67543",
// "h4", "32345",
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
for isRunning.Load() {
pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second)
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
for isRunning.Load() {
pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second)
}
})
}
func TestUpDown(t *testing.T) {
goroutineLeakCheck(t)
const itrials = 50
const otrials = 10
for n := 0; n < otrials; n++ {
pair := genTestPair(t, false)
pair := genTestPair(t, false, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@ -334,7 +325,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true)
pair := genTestPair(t, true, false)
done := make(chan struct{})
const warmupIters = 10
@ -415,7 +406,7 @@ func TestConcurrencySafety(t *testing.T) {
}
func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true)
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -429,7 +420,7 @@ func BenchmarkLatency(b *testing.B) {
}
func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true)
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -473,7 +464,7 @@ func BenchmarkThroughput(b *testing.B) {
}
func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true)
pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -52,18 +52,11 @@ const (
WGLabelCookie = "cookie--"
)
const (
DefaultMessageInitiationType uint32 = 1
DefaultMessageResponseType uint32 = 2
DefaultMessageCookieReplyType uint32 = 3
DefaultMessageTransportType uint32 = 4
)
var (
MessageInitiationType uint32 = DefaultMessageInitiationType
MessageResponseType uint32 = DefaultMessageResponseType
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
MessageTransportType uint32 = DefaultMessageTransportType
MessageInitiationType uint32 = 1
MessageResponseType uint32 = 2
MessageCookieReplyType uint32 = 3
MessageTransportType uint32 = 4
)
const (
@ -82,10 +75,9 @@ const (
MessageTransportOffsetContent = 16
)
var (
packetSizeToMsgType map[int]uint32
msgTypeToJunkSize map[uint32]int
)
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
@ -205,12 +197,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -264,12 +256,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
if msg.Type != MessageInitiationType {
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -384,9 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
msg.Type = MessageResponseType
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -436,12 +428,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
if msg.Type != MessageResponseType {
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
// lookup handshake by receiver

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -13,7 +13,6 @@ import (
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
)
type Peer struct {
@ -114,16 +113,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()

View file

@ -1,19 +1,20 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync"
"sync/atomic"
)
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count uint32 // Get calls not yet Put back
count atomic.Uint32
max uint32
}
@ -26,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count >= p.max {
for p.count.Load() >= p.max {
p.cond.Wait()
}
p.count++
p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
@ -40,9 +41,7 @@ func (p *WaitPool) Put(x any) {
if p.max == 0 {
return
}
p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.count.Add(^uint32(0))
p.cond.Signal()
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -32,9 +32,7 @@ func TestWaitPool(t *testing.T) {
wg.Add(workers)
var max atomic.Uint32
updateMax := func() {
p.lock.Lock()
count := p.count
p.lock.Unlock()
count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -137,14 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
}
// check size of packet
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAWG() {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
@ -153,29 +149,19 @@ func (device *Device) RoutineReceiveIncoming(
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("transport packet lined up with another msg type")
device.log.Verbosef("Transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType {
// probably a junk packet
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
device.log.Verbosef("ASec: Received message with unknown type")
continue
}
// remove junk from bufsArrs by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy(bufsArrs[i][:size], packet[transportJunkSize:])
size -= transportJunkSize
// need to reinitialize packet as well
packet = packet[:size]
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
}
switch msgType {
// check if transport
@ -259,7 +245,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@ -318,7 +304,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.awg.ASecMux.RLock()
device.aSecMux.RLock()
// handle cookie fields and ratelimiting
@ -470,7 +456,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.awg.ASecMux.RUnlock()
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -9,6 +9,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"math/rand"
"net"
"os"
"sync"
@ -124,30 +125,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.ASecMux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
if junks != nil {
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
}
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.ASecMux.RUnlock()
} else {
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
@ -163,11 +146,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
@ -182,7 +173,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendAndCountBuffers(sendBuffer)
err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@ -203,13 +194,22 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@ -229,7 +229,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@ -252,19 +252,11 @@ func (device *Device) SendHandshakeCookie(
return err
}
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
junkedHeader = append(junkedHeader, writer.Bytes()...)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@ -477,6 +469,31 @@ top:
}
}
func (peer *Peer) createJunkPackets() ([][]byte, error) {
if peer.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
packetSize := rand.Intn(
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
) + peer.device.aSecCfg.junkPacketMinSize
junk, err := randomJunkWithSize(packetSize)
if err != nil {
peer.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
func (peer *Peer) FlushStagedPackets() {
for {
select {
@ -577,7 +594,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
continue
}
dataSent := false
@ -585,14 +601,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
if err != nil {
device.log.Errorf("%v - %v", device, err)
continue
}
elem.packet = append(junkedHeader, elem.packet...)
}
bufs = append(bufs, elem.packet)
}
@ -600,11 +608,10 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendAndCountBuffers(bufs)
err := peer.SendBuffers(bufs)
if dataSent {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)

View file

@ -7,6 +7,6 @@ import (
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@ -9,7 +9,7 @@
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code remains platform dependent.
* So this code is remains platform dependent.
*/
package device
@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@ -18,7 +18,6 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@ -98,51 +97,33 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.isAWG() {
if device.awg.ASecCfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.awg.ASecCfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
}
if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
}
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
}
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
for _, field := range controlledJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
if device.awg.HandshakeHandler.ITimeout != 0 {
sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second)
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
@ -199,13 +180,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempAwg := awg.Protocol{}
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempAwg)
err := device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
@ -236,7 +217,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempAwg)
err = device.handleDeviceLine(key, value, &tempASecCfg)
} else {
err = device.handlePeerLine(peer, key, value)
}
@ -244,7 +225,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempAwg)
err = device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
@ -256,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@ -297,11 +278,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
case "replace_peers":
if value != "true" {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
@ -309,138 +286,80 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
tempAwg.ASecCfg.IsSet = true
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.ASecCfg.IsSet = true
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.ASecCfg.IsSet = true
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.ASecCfg.IsSet = true
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
}
tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
generators, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.IsSet = true
case "j1", "j2", "j3":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}
generators, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.IsSet = true
case "itime":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty itime")
return nil
}
itime, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
}
device.log.Verbosef("UAPI: Updating itime")
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
tempAwg.HandshakeHandler.IsSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
@ -513,11 +432,7 @@ func (device *Device) handlePeerLine(
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set update only, invalid value: %v",
value,
)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
@ -563,11 +478,7 @@ func (device *Device) handlePeerLine(
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set persistent keepalive interval: %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@ -578,11 +489,7 @@ func (device *Device) handlePeerLine(
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to replace allowedips, invalid value: %v",
value,
)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
}
if peer.dummy {
return nil
@ -590,14 +497,7 @@ func (device *Device) handlePeerLine(
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
add := true
verb := "Adding"
if len(value) > 0 && value[0] == '-' {
add = false
verb = "Removing"
value = value[1:]
}
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
@ -605,11 +505,7 @@ func (device *Device) handlePeerLine(
if peer.dummy {
return nil
}
if add {
device.allowedips.Insert(prefix, peer.Peer)
} else {
device.allowedips.Remove(prefix, peer.Peer)
}
case "protocol_version":
if value != "1" {
@ -661,11 +557,7 @@ func (device *Device) IpcHandle(socket net.Conn) {
return
}
if nextByte != '\n' {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"trailing character in UAPI get: %q",
nextByte,
)
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
break
}
err = device.IpcGetOperation(buffered.Writer)

25
device/util.go Normal file
View file

@ -0,0 +1,25 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
)
func appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
func randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := crand.Read(junk)
return junk, err
}

27
device/util_test.go Normal file
View file

@ -0,0 +1,27 @@
package device
import (
"bytes"
"fmt"
"testing"
)
func Test_randomJunktWithSize(t *testing.T) {
junk, err := randomJunkWithSize(30)
fmt.Println(string(junk), len(junk), err)
}
func Test_appendJunk(t *testing.T) {
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main

20
go.mod
View file

@ -1,23 +1,17 @@
module github.com/amnezia-vpn/amneziawg-go
go 1.24.4
go 1.20
require (
github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0
github.com/tevino/abool/v2 v2.1.0
go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.39.0
golang.org/x/net v0.41.0
golang.org/x/sys v0.33.0
golang.org/x/crypto v0.19.0
golang.org/x/net v0.21.0
golang.org/x/sys v0.17.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/time v0.9.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)

48
go.sum
View file

@ -1,40 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package replay

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package rwcancel implements cancelable read/write operations on
@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool {
func (rw *RWCancel) ReadyWrite() bool {
closeFd := int32(rw.closingReader.Fd())
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}}
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
var err error
for {
_, err = unix.Poll(pollFds, -1)

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun

View file

@ -1,86 +1,102 @@
package tun
import (
"encoding/binary"
"math/bits"
)
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
tmp := make([]byte, 8)
binary.NativeEndian.PutUint64(tmp, initial)
ac := binary.BigEndian.Uint64(tmp)
var carry uint64
ac := initial
for len(b) >= 128 {
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
b = b[128:]
}
if len(b) >= 64 {
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
b = b[64:]
}
if len(b) >= 32 {
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
b = b[32:]
}
if len(b) >= 16 {
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
b = b[16:]
}
if len(b) >= 8 {
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
b = b[8:]
}
if len(b) >= 4 {
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
ac += carry
ac += uint64(binary.BigEndian.Uint32(b))
b = b[4:]
}
if len(b) >= 2 {
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
ac += carry
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
ac, carry = bits.Add64(ac, uint64(tmp), 0)
ac += carry
ac += uint64(b[0]) << 8
}
binary.NativeEndian.PutUint64(tmp, ac)
return binary.BigEndian.Uint64(tmp)
return ac
}
func checksum(b []byte, initial uint64) uint16 {

View file

@ -1,74 +1,11 @@
package tun
import (
"encoding/binary"
"fmt"
"math/rand"
"testing"
"golang.org/x/sys/unix"
)
func checksumRef(b []byte, initial uint16) uint16 {
ac := uint64(initial)
for len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}
for (ac >> 16) > 0 {
ac = (ac >> 16) + (ac & 0xffff)
}
return uint16(ac)
}
func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
sum := checksumRef(srcAddr, 0)
sum = checksumRef(dstAddr, sum)
sum = checksumRef([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumRef(tmp, sum)
}
func TestChecksum(t *testing.T) {
for length := 0; length <= 9001; length++ {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
csum := checksum(buf, 0x1234)
csumRef := checksumRef(buf, 0x1234)
if csum != csumRef {
t.Error("Expected checksum", csumRef, "got", csum)
}
}
}
func TestPseudoHeaderChecksum(t *testing.T) {
for _, addrLen := range []int{4, 16} {
for length := 0; length <= 9001; length++ {
srcAddr := make([]byte, addrLen)
dstAddr := make([]byte, addrLen)
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(srcAddr)
rng.Read(dstAddr)
rng.Read(buf)
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csum := checksum(buf, phSum)
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csumRef := checksumRef(buf, phSumRef)
if csum != csumRef {
t.Error("Expected checksumRef", csumRef, "got", csum)
}
}
}
}
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main

Some files were not shown because too many files have changed in this diff Show more