diff --git a/Dockerfile b/Dockerfile index 12159be..f165899 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.24 as awg +FROM golang:1.24.4 as awg COPY . /awg WORKDIR /awg RUN go mod download && \ @@ -7,6 +7,7 @@ RUN go mod download && \ FROM alpine:3.19 ARG AWGTOOLS_RELEASE="1.0.20241018" + RUN apk --no-cache add iproute2 iptables bash && \ cd /usr/bin/ && \ wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \ diff --git a/README.md b/README.md index 853d318..428b752 100644 --- a/README.md +++ b/README.md @@ -50,3 +50,4 @@ $ git clone https://github.com/amnezia-vpn/amneziawg-go $ cd amneziawg-go $ make ``` + diff --git a/conn/bind_std.go b/conn/bind_std.go index 312a538..6908ba8 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 6cfa099..1a0e021 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 42b0bb7..25b5eab 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package bindtest diff --git a/conn/boundif_android.go b/conn/boundif_android.go index dd3ca5b..be69b2a 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/conn.go b/conn/conn.go index a1f57d2..1304657 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package conn implements WireGuard's network connections. diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee..618d02b 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns.go b/conn/controlfns.go index 4f7d90f..27421bd 100644 --- a/conn/controlfns.go +++ b/conn/controlfns.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe..f0deefa 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn @@ -13,6 +13,35 @@ import ( "golang.org/x/sys/unix" ) +// Taken from go/src/internal/syscall/unix/kernel_version_linux.go +func kernelVersion() (major, minor int) { + var uname unix.Utsname + if err := unix.Uname(&uname); err != nil { + return + } + + var ( + values [2]int + value, vi int + ) + for _, c := range uname.Release { + if '0' <= c && c <= '9' { + value = (value * 10) + int(c-'0') + } else { + // Note that we're assuming N.N.N here. + // If we see anything else, we are likely to mis-parse it. + values[vi] = value + vi++ + if vi >= len(values) { + break + } + value = 0 + } + } + + return values[0], values[1] +} + func init() { controlFns = append(controlFns, @@ -57,5 +86,24 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + // Kernels below 5.12 are missing 98184612aca0 ("net: + // udp: Add support for getsockopt(..., ..., UDP_GRO, + // ..., ...);"), which means we can't read this back + // later. We could pipe the return value through to + // the rest of the code, but UDP_GRO is kind of buggy + // anyway, so just gate this here. + major, minor := kernelVersion() + if major < 5 || (major == 5 && minor < 12) { + return nil + } + + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, ) } diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index 91692c0..b2e7570 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go index c3bdf7d..5e38305 100644 --- a/conn/controlfns_windows.go +++ b/conn/controlfns_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/default.go b/conn/default.go index b6f761b..2ce1579 100644 --- a/conn/default.go +++ b/conn/default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/errors_default.go b/conn/errors_default.go index f1e5b90..3c9b223 100644 --- a/conn/errors_default.go +++ b/conn/errors_default.go @@ -2,11 +2,11 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn -func errShouldDisableUDPGSO(err error) bool { +func errShouldDisableUDPGSO(_ error) bool { return false } diff --git a/conn/errors_linux.go b/conn/errors_linux.go index 7548a8a..9ed7d76 100644 --- a/conn/errors_linux.go +++ b/conn/errors_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/features_default.go b/conn/features_default.go index d53ff5f..9fc5088 100644 --- a/conn/features_default.go +++ b/conn/features_default.go @@ -3,13 +3,13 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn import "net" -func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { +func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) { return } diff --git a/conn/features_linux.go b/conn/features_linux.go index a6de8c1..936029e 100644 --- a/conn/features_linux.go +++ b/conn/features_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/gso_default.go b/conn/gso_default.go index 57780db..a9a3e80 100644 --- a/conn/gso_default.go +++ b/conn/gso_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/gso_linux.go b/conn/gso_linux.go index 8596b29..4ee31fa 100644 --- a/conn/gso_linux.go +++ b/conn/gso_linux.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/mark_default.go b/conn/mark_default.go index 3102384..72b266e 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/mark_unix.go b/conn/mark_unix.go index d9e46ee..d0580d5 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/sticky_default.go b/conn/sticky_default.go index 0b21386..15b65af 100644 --- a/conn/sticky_default.go +++ b/conn/sticky_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 8e206e9..adfedc1 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index d2bd584..1b1ee68 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go index d1037bb..c396658 100644 --- a/conn/winrio/rio_windows.go +++ b/conn/winrio/rio_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package winrio diff --git a/device/allowedips.go b/device/allowedips.go index fa46f97..d15373c 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) } } +func (node *trieEntry) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() +} + +func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + var node *trieEntry + var exact bool + + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits())) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits())) + } else { + panic(errors.New("removing unknown address type")) + } + if !exact || node == nil || peer != node.peer { + return + } + node.remove() +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() @@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue - } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() + elem.Value.(*trieEntry).remove() } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 07065c3..b863696 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) { var peers []*Peer var allowedIPs AllowedIPs - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) @@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) { for n := 0; n < NumberOfAddresses; n++ { var addr4 [4]byte - rand.Read(addr4[:]) + rng.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte - rand.Read(addr6[:]) + rng.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) @@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) { for p = 0; ; p++ { for n := 0; n < NumberOfTests; n++ { var addr4 [4]byte - rand.Read(addr4[:]) + rng.Read(addr4[:]) peer1 := slow4.Lookup(addr4[:]) peer2 := allowedIPs.Lookup(addr4[:]) if peer1 != peer2 { @@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) { } var addr6 [16]byte - rand.Read(addr6[:]) + rng.Read(addr6[:]) peer1 = slow6.Lookup(addr6[:]) peer2 = allowedIPs.Lookup(addr6[:]) if peer1 != peer2 { diff --git a/device/allowedips_test.go b/device/allowedips_test.go index cde068e..a4b08a3 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) { var trie *trieEntry var peers []*Peer root := parentIndirection{&trie, 2} - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) const AddressLength = 4 @@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber + rng.Read(addr[:]) + cidr := uint8(rng.Uint32() % (AddressLength * 8)) + index := rng.Int() % peerNumber root.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) + rng.Read(addr[:]) trie.lookup(addr[:]) } } @@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { @@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 24) + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 32) + assertEQ(a, 192, 0, 0, 1) + remove(nil, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(b, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 24) + assertNEQ(a, 192, 0, 0, 1) + remove(a, 1, 0, 0, 0, 32) + assertNEQ(a, 1, 0, 0, 0) } /* Test ported from kernel implementation: @@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d uint32) { var addr []byte addr = append(addr, expand(a)...) @@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) { } } + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := allowedIPs.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) @@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) { assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) + + insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) } diff --git a/device/awg/awg.go b/device/awg/awg.go new file mode 100644 index 0000000..fd5a96d --- /dev/null +++ b/device/awg/awg.go @@ -0,0 +1,144 @@ +package awg + +import ( + "bytes" + "fmt" + "slices" + "strconv" + "strings" + "sync" + + "github.com/tevino/abool" +) + +type aSecCfgType struct { + IsSet bool + JunkPacketCount int + JunkPacketMinSize int + JunkPacketMaxSize int + InitHeaderJunkSize int + ResponseHeaderJunkSize int + CookieReplyHeaderJunkSize int + TransportHeaderJunkSize int + InitPacketMagicHeader uint32 + ResponsePacketMagicHeader uint32 + UnderloadPacketMagicHeader uint32 + TransportPacketMagicHeader uint32 + // InitPacketMagicHeader Limit + // ResponsePacketMagicHeader Limit + // UnderloadPacketMagicHeader Limit + // TransportPacketMagicHeader Limit +} + +type Limit struct { + Min uint32 + Max uint32 + HeaderType uint32 +} + +func NewLimit(min, max, headerType uint32) (Limit, error) { + if min > max { + return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max) + } + + return Limit{ + Min: min, + Max: max, + HeaderType: headerType, + }, nil +} + +func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) { + // tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType) + // var min, max, headerType uint32 + // _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType) + // if err != nil { + // return Limit{}, fmt.Errorf("invalid magic header format: %s", value) + // } + + limits := strings.Split(value, "-") + if len(limits) != 2 { + return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value) + } + + min, err := strconv.ParseUint(limits[0], 10, 32) + if err != nil { + return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err) + } + + max, err := strconv.ParseUint(limits[1], 10, 32) + if err != nil { + return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err) + } + + limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType) + if err != nil { + return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err) + } + + return limit, nil +} + +type Limits []Limit + +func NewLimits(limits []Limit) Limits { + slices.SortFunc(limits, func(a, b Limit) int { + if a.Min < b.Min { + return -1 + } else if a.Min > b.Min { + return 1 + } + return 0 + }) + + return Limits(limits) +} + +type Protocol struct { + IsASecOn abool.AtomicBool + // TODO: revision the need of the mutex + ASecMux sync.RWMutex + ASecCfg aSecCfgType + JunkCreator junkCreator + + HandshakeHandler SpecialHandshakeHandler +} + +func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize) +} + +func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize) +} + +func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize) +} + +func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) +} + +func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) { + extraSize := 0 + if len(optExtraSize) == 1 { + extraSize = optExtraSize[0] + } + + var junk []byte + protocol.ASecMux.RLock() + if junkSize != 0 { + buf := make([]byte, 0, junkSize+extraSize) + writer := bytes.NewBuffer(buf[:0]) + err := protocol.JunkCreator.AppendJunk(writer, junkSize) + if err != nil { + protocol.ASecMux.RUnlock() + return nil, err + } + junk = writer.Bytes() + } + protocol.ASecMux.RUnlock() + + return junk, nil +} diff --git a/device/awg/internal/mock.go b/device/awg/internal/mock.go new file mode 100644 index 0000000..a2e1c95 --- /dev/null +++ b/device/awg/internal/mock.go @@ -0,0 +1,37 @@ +package internal + +type mockGenerator struct { + size int +} + +func NewMockGenerator(size int) mockGenerator { + return mockGenerator{size: size} +} + +func (m mockGenerator) Generate() []byte { + return make([]byte, m.size) +} + +func (m mockGenerator) Size() int { + return m.size +} + +func (m mockGenerator) Name() string { + return "mock" +} + +type mockByteGenerator struct { + data []byte +} + +func NewMockByteGenerator(data []byte) mockByteGenerator { + return mockByteGenerator{data: data} +} + +func (bg mockByteGenerator) Generate() []byte { + return bg.data +} + +func (bg mockByteGenerator) Size() int { + return len(bg.data) +} diff --git a/device/junk_creator.go b/device/awg/junk_creator.go similarity index 52% rename from device/junk_creator.go rename to device/awg/junk_creator.go index 3a2d3b4..91fd253 100644 --- a/device/junk_creator.go +++ b/device/awg/junk_creator.go @@ -1,4 +1,4 @@ -package device +package awg import ( "bytes" @@ -8,61 +8,62 @@ import ( ) type junkCreator struct { - device *Device + aSecCfg aSecCfgType cha8Rand *v2.ChaCha8 } -func NewJunkCreator(d *Device) (junkCreator, error) { +// TODO: refactor param to only pass the junk related params +func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { buf := make([]byte, 32) _, err := crand.Read(buf) if err != nil { return junkCreator{}, err } - return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil + return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil } // Should be called with aSecMux RLocked -func (jc *junkCreator) createJunkPackets() ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { - return nil, nil +func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error { + if jc.aSecCfg.JunkPacketCount == 0 { + return nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + for range jc.aSecCfg.JunkPacketCount { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { - return nil, fmt.Errorf("Failed to create junk packet: %v", err) + return fmt.Errorf("create junk packet: %v", err) } - junks = append(junks, junk) + *junks = append(*junks, junk) } - return junks, nil + return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.aSecCfg.JunkPacketMinSize } // Should be called with aSecMux RLocked -func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { +func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { headerJunk, err := jc.randomJunkWithSize(size) if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) + return fmt.Errorf("create header junk: %v", err) } _, err = writer.Write(headerJunk) if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) + return fmt.Errorf("write header junk: %v", err) } return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + // TODO: use a memory pool to allocate junk := make([]byte, size) _, err := jc.cha8Rand.Read(junk) return junk, err diff --git a/device/junk_creator_test.go b/device/awg/junk_creator_test.go similarity index 61% rename from device/junk_creator_test.go rename to device/awg/junk_creator_test.go index d3cf2b3..424f104 100644 --- a/device/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -1,36 +1,27 @@ -package device +package awg import ( "bytes" "fmt" "testing" - - "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" - "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" ) func setUpJunkCreator(t *testing.T) (junkCreator, error) { - cfg, _ := genASecurityConfigs(t) - tun := tuntest.NewChannelTUN() - binds := bindtest.NewChannelBinds() - level := LogLevelVerbose - dev := NewDevice( - tun.TUN(), - binds[0], - NewLogger(level, ""), - ) - - if err := dev.IpcSet(cfg[0]); err != nil { - t.Errorf("failed to configure device %v", err) - dev.Close() - return junkCreator{}, err - } - - jc, err := NewJunkCreator(dev) + jc, err := NewJunkCreator(aSecCfgType{ + IsSet: true, + JunkPacketCount: 5, + JunkPacketMinSize: 500, + JunkPacketMaxSize: 1000, + InitHeaderJunkSize: 30, + ResponseHeaderJunkSize: 40, + InitPacketMagicHeader: 123456, + ResponsePacketMagicHeader: 67543, + UnderloadPacketMagicHeader: 32345, + TransportPacketMagicHeader: 123123, + }) if err != nil { t.Errorf("failed to create junk creator %v", err) - dev.Close() return junkCreator{}, err } @@ -42,8 +33,9 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { - got, err := jc.createJunkPackets() + t.Run("valid", func(t *testing.T) { + got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) + err := jc.CreateJunkPackets(&got) if err != nil { t.Errorf( "junkCreator.createJunkPackets() = %v; failed", @@ -68,7 +60,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { } func Test_junkCreator_randomJunkWithSize(t *testing.T) { - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { jc, err := setUpJunkCreator(t) if err != nil { return @@ -78,7 +70,6 @@ func Test_junkCreator_randomJunkWithSize(t *testing.T) { fmt.Printf("%v\n%v\n", r1, r2) if bytes.Equal(r1, r2) { t.Errorf("same junks %v", err) - jc.device.Close() return } }) @@ -90,14 +81,14 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { return } for range [30]struct{}{} { - t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + t.Run("valid", func(t *testing.T) { + if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || + got > jc.aSecCfg.JunkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.aSecCfg.JunkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize, ) } }) @@ -109,13 +100,13 @@ func Test_junkCreator_appendJunk(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { s := "apple" buffer := bytes.NewBuffer([]byte(s)) - err := jc.appendJunk(buffer, 30) + err := jc.AppendJunk(buffer, 30) if err != nil && buffer.Len() != len(s)+30 { - t.Errorf("appendWithJunk() size don't match") + t.Error("appendWithJunk() size don't match") } read := make([]byte, 50) buffer.Read(read) diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go new file mode 100644 index 0000000..e582d97 --- /dev/null +++ b/device/awg/special_handshake_handler.go @@ -0,0 +1,73 @@ +package awg + +import ( + "errors" + "time" + + "github.com/tevino/abool" + "go.uber.org/atomic" +) + +// TODO: atomic?/ and better way to use this +var PacketCounter *atomic.Uint64 = atomic.NewUint64(0) + +// TODO +var WaitResponse = struct { + Channel chan struct{} + ShouldWait *abool.AtomicBool +}{ + make(chan struct{}, 1), + abool.New(), +} + +type SpecialHandshakeHandler struct { + isFirstDone bool + SpecialJunk TagJunkPacketGenerators + ControlledJunk TagJunkPacketGenerators + + nextItime time.Time + ITimeout time.Duration // seconds + + IsSet bool +} + +func (handler *SpecialHandshakeHandler) Validate() error { + var errs []error + if err := handler.SpecialJunk.Validate(); err != nil { + errs = append(errs, err) + } + if err := handler.ControlledJunk.Validate(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { + if !handler.SpecialJunk.IsDefined() { + return nil + } + + // TODO: create tests + if !handler.isFirstDone { + handler.isFirstDone = true + } else if !handler.isTimeToSendSpecial() { + return nil + } + + rv := handler.SpecialJunk.GeneratePackets() + handler.nextItime = time.Now().Add(handler.ITimeout) + + return rv +} + +func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool { + return time.Now().After(handler.nextItime) +} + +func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte { + if !handler.ControlledJunk.IsDefined() { + return nil + } + + return handler.ControlledJunk.GeneratePackets() +} diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go new file mode 100644 index 0000000..65d8004 --- /dev/null +++ b/device/awg/tag_generator.go @@ -0,0 +1,190 @@ +package awg + +import ( + crand "crypto/rand" + "encoding/binary" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + v2 "math/rand/v2" + // "go.uber.org/atomic" +) + +type Generator interface { + Generate() []byte + Size() int +} + +type newGenerator func(string) (Generator, error) + +type BytesGenerator struct { + value []byte + size int +} + +func (bg *BytesGenerator) Generate() []byte { + return bg.value +} + +func (bg *BytesGenerator) Size() int { + return bg.size +} + +func newBytesGenerator(param string) (Generator, error) { + hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X") + if !hasPrefix { + return nil, fmt.Errorf("not correct hex: %s", param) + } + + hex, err := hexToBytes(param) + if err != nil { + return nil, fmt.Errorf("hexToBytes: %w", err) + } + + return &BytesGenerator{value: hex, size: len(hex)}, nil +} + +func hexToBytes(hexStr string) ([]byte, error) { + hexStr = strings.TrimPrefix(hexStr, "0x") + hexStr = strings.TrimPrefix(hexStr, "0X") + + // Ensure even length (pad with leading zero if needed) + if len(hexStr)%2 != 0 { + hexStr = "0" + hexStr + } + + return hex.DecodeString(hexStr) +} + +type RandomPacketGenerator struct { + cha8Rand *v2.ChaCha8 + size int +} + +func (rpg *RandomPacketGenerator) Generate() []byte { + junk := make([]byte, rpg.size) + rpg.cha8Rand.Read(junk) + return junk +} + +func (rpg *RandomPacketGenerator) Size() int { + return rpg.size +} + +func newRandomPacketGenerator(param string) (Generator, error) { + size, err := strconv.Atoi(param) + if err != nil { + return nil, fmt.Errorf("random packet parse int: %w", err) + } + + if size > 1000 { + return nil, fmt.Errorf("random packet size must be less than 1000") + } + + buf := make([]byte, 32) + _, err = crand.Read(buf) + if err != nil { + return nil, fmt.Errorf("random packet crand read: %w", err) + } + + return &RandomPacketGenerator{ + cha8Rand: v2.NewChaCha8([32]byte(buf)), + size: size, + }, nil +} + +type TimestampGenerator struct { +} + +func (tg *TimestampGenerator) Generate() []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix())) + return buf +} + +func (tg *TimestampGenerator) Size() int { + return 8 +} + +func newTimestampGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("timestamp param needs to be empty: %s", param) + } + + return &TimestampGenerator{}, nil +} + +type WaitTimeoutGenerator struct { + waitTimeout time.Duration +} + +func (wtg *WaitTimeoutGenerator) Generate() []byte { + time.Sleep(wtg.waitTimeout) + return []byte{} +} + +func (wtg *WaitTimeoutGenerator) Size() int { + return 0 +} + +func newWaitTimeoutGenerator(param string) (Generator, error) { + timeout, err := strconv.Atoi(param) + if err != nil { + return nil, fmt.Errorf("timeout parse int: %w", err) + } + + if timeout > 5000 { + return nil, fmt.Errorf("timeout must be less than 5000ms") + } + + return &WaitTimeoutGenerator{ + waitTimeout: time.Duration(timeout) * time.Millisecond, + }, nil +} + +type PacketCounterGenerator struct { +} + +func (c *PacketCounterGenerator) Generate() []byte { + buf := make([]byte, 8) + // TODO: better way to handle counter tag + binary.BigEndian.PutUint64(buf, PacketCounter.Load()) + return buf +} + +func (c *PacketCounterGenerator) Size() int { + return 8 +} + +func newPacketCounterGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("packet counter param needs to be empty: %s", param) + } + + return &PacketCounterGenerator{}, nil +} + +type WaitResponseGenerator struct { +} + +func (c *WaitResponseGenerator) Generate() []byte { + WaitResponse.ShouldWait.Set() + <-WaitResponse.Channel + WaitResponse.ShouldWait.UnSet() + return []byte{} +} + +func (c *WaitResponseGenerator) Size() int { + return 0 +} + +func newWaitResponseGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("wait response param needs to be empty: %s", param) + } + + return &WaitResponseGenerator{}, nil +} diff --git a/device/awg/tag_generator_test.go b/device/awg/tag_generator_test.go new file mode 100644 index 0000000..4950b33 --- /dev/null +++ b/device/awg/tag_generator_test.go @@ -0,0 +1,189 @@ +package awg + +import ( + "encoding/binary" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_newBytesGenerator(t *testing.T) { + type args struct { + param string + } + tests := []struct { + name string + args args + want []byte + wantErr error + }{ + { + name: "empty", + args: args{ + param: "", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "wrong start", + args: args{ + param: "123456", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "not only hex value with X", + args: args{ + param: "0X12345q", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "not only hex value with x", + args: args{ + param: "0x12345q", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "valid hex", + args: args{ + param: "0xf6ab3267fa", + }, + want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa}, + }, + { + name: "valid hex with odd length", + args: args{ + param: "0xfab3267fa", + }, + want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newBytesGenerator(tt.args.param) + + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + require.Nil(t, got) + return + } + + require.Nil(t, err) + require.NotNil(t, got) + + gotValues := got.Generate() + require.Equal(t, tt.want, gotValues) + }) + } +} + +func Test_newRandomPacketGenerator(t *testing.T) { + type args struct { + param string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "empty", + args: args{ + param: "", + }, + wantErr: fmt.Errorf("parse int"), + }, + { + name: "not an int", + args: args{ + param: "x", + }, + wantErr: fmt.Errorf("parse int"), + }, + { + name: "too large", + args: args{ + param: "1001", + }, + wantErr: fmt.Errorf("random packet size must be less than 1000"), + }, + { + name: "valid", + args: args{ + param: "12", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRandomPacketGenerator(tt.args.param) + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + require.Nil(t, got) + return + } + + require.Nil(t, err) + require.NotNil(t, got) + first := got.Generate() + + second := got.Generate() + require.NotEqual(t, first, second) + }) + } +} + +func TestPacketCounterGenerator(t *testing.T) { + tests := []struct { + name string + param string + wantErr bool + }{ + { + name: "Valid empty param", + param: "", + wantErr: false, + }, + { + name: "Invalid non-empty param", + param: "anything", + wantErr: true, + }, + } + + for _, tc := range tests { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + gen, err := newPacketCounterGenerator(tc.param) + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, 8, gen.Size()) + + // Reset counter to known value for test + initialCount := uint64(42) + PacketCounter.Store(initialCount) + + output := gen.Generate() + require.Equal(t, 8, len(output)) + + // Verify counter value in output + counterValue := binary.BigEndian.Uint64(output) + require.Equal(t, initialCount, counterValue) + + // Increment counter and verify change + PacketCounter.Add(1) + output = gen.Generate() + counterValue = binary.BigEndian.Uint64(output) + require.Equal(t, initialCount+1, counterValue) + }) + } +} diff --git a/device/awg/tag_junk_packet_generator.go b/device/awg/tag_junk_packet_generator.go new file mode 100644 index 0000000..fdbebc8 --- /dev/null +++ b/device/awg/tag_junk_packet_generator.go @@ -0,0 +1,59 @@ +package awg + +import ( + "fmt" + "strconv" +) + +type TagJunkPacketGenerator struct { + name string + tagValue string + + packetSize int + generators []Generator +} + +func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator { + return TagJunkPacketGenerator{ + name: name, + tagValue: tagValue, + generators: make([]Generator, 0, size), + } +} + +func (tg *TagJunkPacketGenerator) append(generator Generator) { + tg.generators = append(tg.generators, generator) + tg.packetSize += generator.Size() +} + +func (tg *TagJunkPacketGenerator) generatePacket() []byte { + packet := make([]byte, 0, tg.packetSize) + for _, generator := range tg.generators { + packet = append(packet, generator.Generate()...) + } + + return packet +} + +func (tg *TagJunkPacketGenerator) Name() string { + return tg.name +} + +func (tg *TagJunkPacketGenerator) nameIndex() (int, error) { + if len(tg.name) != 2 { + return 0, fmt.Errorf("name must be 2 character long: %s", tg.name) + } + + index, err := strconv.Atoi(tg.name[1:2]) + if err != nil { + return 0, fmt.Errorf("name 2 char should be an int %w", err) + } + return index, nil +} + +func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields { + return IpcFields{ + Key: tg.name, + Value: tg.tagValue, + } +} diff --git a/device/awg/tag_junk_packet_generator_test.go b/device/awg/tag_junk_packet_generator_test.go new file mode 100644 index 0000000..309d425 --- /dev/null +++ b/device/awg/tag_junk_packet_generator_test.go @@ -0,0 +1,210 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestNewTagJunkGenerator(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + genName string + size int + expected TagJunkPacketGenerator + }{ + { + name: "Create new generator with empty name", + genName: "", + size: 0, + expected: TagJunkPacketGenerator{ + name: "", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with valid name", + genName: "T1", + size: 0, + expected: TagJunkPacketGenerator{ + name: "T1", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with non-zero size", + genName: "T2", + size: 5, + expected: TagJunkPacketGenerator{ + name: "T2", + packetSize: 0, + generators: make([]Generator, 5), + }, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := newTagJunkPacketGenerator(tc.genName, "", tc.size) + require.Equal(t, tc.expected.name, result.name) + require.Equal(t, tc.expected.packetSize, result.packetSize) + require.Equal(t, cap(result.generators), len(tc.expected.generators)) + }) + } +} + +func TestTagJunkGeneratorAppend(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + initialState TagJunkPacketGenerator + mockSize int + expectedLength int + expectedSize int + }{ + { + name: "Append to empty generator", + initialState: newTagJunkPacketGenerator("T1", "", 0), + mockSize: 5, + expectedLength: 1, + expectedSize: 5, + }, + { + name: "Append to non-empty generator", + initialState: TagJunkPacketGenerator{ + name: "T2", + packetSize: 10, + generators: make([]Generator, 2), + }, + mockSize: 7, + expectedLength: 3, // 2 existing + 1 new + expectedSize: 17, // 10 + 7 + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.initialState + mockGen := internal.NewMockGenerator(tc.mockSize) + + tg.append(mockGen) + + require.Equal(t, tc.expectedLength, len(tg.generators)) + require.Equal(t, tc.expectedSize, tg.packetSize) + }) + } +} + +func TestTagJunkGeneratorGenerate(t *testing.T) { + t.Parallel() + + // Create mock generators for testing + mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02}) + mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05}) + + testCases := []struct { + name string + setupGenerator func() TagJunkPacketGenerator + expected []byte + }{ + { + name: "Generate with empty generators", + setupGenerator: func() TagJunkPacketGenerator { + return newTagJunkPacketGenerator("T1", "", 0) + }, + expected: []byte{}, + }, + { + name: "Generate with single generator", + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T2", "", 0) + tg.append(mockGen1) + return tg + }, + expected: []byte{0x01, 0x02}, + }, + { + name: "Generate with multiple generators", + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T3", "", 0) + tg.append(mockGen1) + tg.append(mockGen2) + return tg + }, + expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.setupGenerator() + result := tg.generatePacket() + + require.Equal(t, tc.expected, result) + }) + } +} + +func TestTagJunkGeneratorNameIndex(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + generatorName string + expectedIndex int + expectError bool + }{ + { + name: "Valid name with digit", + generatorName: "T5", + expectedIndex: 5, + expectError: false, + }, + { + name: "Invalid name - too short", + generatorName: "T", + expectError: true, + }, + { + name: "Invalid name - too long", + generatorName: "T55", + expectError: true, + }, + { + name: "Invalid name - non-digit second character", + generatorName: "TX", + expectError: true, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := TagJunkPacketGenerator{name: tc.generatorName} + index, err := tg.nameIndex() + + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, index) + } + }) + } +} diff --git a/device/awg/tag_junk_packet_generators.go b/device/awg/tag_junk_packet_generators.go new file mode 100644 index 0000000..9921eb0 --- /dev/null +++ b/device/awg/tag_junk_packet_generators.go @@ -0,0 +1,66 @@ +package awg + +import "fmt" + +type TagJunkPacketGenerators struct { + tagGenerators []TagJunkPacketGenerator + length int + DefaultJunkCount int // Jc +} + +func (generators *TagJunkPacketGenerators) AppendGenerator( + generator TagJunkPacketGenerator, +) { + generators.tagGenerators = append(generators.tagGenerators, generator) + generators.length++ +} + +func (generators *TagJunkPacketGenerators) IsDefined() bool { + return len(generators.tagGenerators) > 0 +} + +// validate that packets were defined consecutively +func (generators *TagJunkPacketGenerators) Validate() error { + seen := make([]bool, len(generators.tagGenerators)) + for _, generator := range generators.tagGenerators { + index, err := generator.nameIndex() + if index > len(generators.tagGenerators) { + return fmt.Errorf("junk packet index should be consecutive") + } + if err != nil { + return fmt.Errorf("name index: %w", err) + } else { + seen[index-1] = true + } + } + + for _, found := range seen { + if !found { + return fmt.Errorf("junk packet index should be consecutive") + } + } + + return nil +} + +func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte { + var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount) + + for i, tagGenerator := range generators.tagGenerators { + rv = append(rv, make([]byte, tagGenerator.packetSize)) + copy(rv[i], tagGenerator.generatePacket()) + PacketCounter.Inc() + } + PacketCounter.Add(uint64(generators.DefaultJunkCount)) + + return rv +} + +func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields { + rv := make([]IpcFields, 0, len(tg.tagGenerators)) + for _, generator := range tg.tagGenerators { + rv = append(rv, generator.IpcGetFields()) + } + + return rv +} diff --git a/device/awg/tag_junk_packet_generators_test.go b/device/awg/tag_junk_packet_generators_test.go new file mode 100644 index 0000000..6b1fd47 --- /dev/null +++ b/device/awg/tag_junk_packet_generators_test.go @@ -0,0 +1,149 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { + tests := []struct { + name string + generator TagJunkPacketGenerator + }{ + { + name: "append single generator", + generator: newTagJunkPacketGenerator("t1", "", 10), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + + // Initial length should be 0 + require.Equal(t, 0, generators.length) + require.Empty(t, generators.tagGenerators) + + // After append, length should be 1 and generator should be added + generators.AppendGenerator(tt.generator) + require.Equal(t, 1, generators.length) + require.Len(t, generators.tagGenerators, 1) + require.Equal(t, tt.generator, generators.tagGenerators[0]) + }) + } +} + +func TestTagJunkGeneratorHandlerValidate(t *testing.T) { + tests := []struct { + name string + generators []TagJunkPacketGenerator + wantErr bool + errMsg string + }{ + { + name: "bad start", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t3", "", 10), + newTagJunkPacketGenerator("t4", "", 10), + }, + wantErr: true, + errMsg: "junk packet index should be consecutive", + }, + { + name: "non-consecutive indices", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", "", 10), + newTagJunkPacketGenerator("t3", "", 10), // Missing t2 + }, + wantErr: true, + errMsg: "junk packet index should be consecutive", + }, + { + name: "consecutive indices", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", "", 10), + newTagJunkPacketGenerator("t2", "", 10), + newTagJunkPacketGenerator("t3", "", 10), + newTagJunkPacketGenerator("t4", "", 10), + newTagJunkPacketGenerator("t5", "", 10), + }, + }, + { + name: "nameIndex error", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("error", "", 10), + }, + wantErr: true, + errMsg: "name must be 2 character long", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + for _, gen := range tt.generators { + generators.AppendGenerator(gen) + } + + err := generators.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + return + } + require.NoError(t, err) + }) + } +} + +func TestTagJunkGeneratorHandlerGenerate(t *testing.T) { + mockByte1 := []byte{0x01, 0x02} + mockByte2 := []byte{0x03, 0x04, 0x05} + mockGen1 := internal.NewMockByteGenerator(mockByte1) + mockGen2 := internal.NewMockByteGenerator(mockByte2) + + tests := []struct { + name string + setupGenerator func() []TagJunkPacketGenerator + expected [][]byte + }{ + { + name: "generate with no default junk", + setupGenerator: func() []TagJunkPacketGenerator { + tg1 := newTagJunkPacketGenerator("t1", "", 0) + tg1.append(mockGen1) + tg1.append(mockGen2) + tg2 := newTagJunkPacketGenerator("t2", "", 0) + tg2.append(mockGen2) + tg2.append(mockGen1) + + return []TagJunkPacketGenerator{tg1, tg2} + }, + expected: [][]byte{ + append(mockByte1, mockByte2...), + append(mockByte2, mockByte1...), + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + tagGenerators := tt.setupGenerator() + for _, gen := range tagGenerators { + generators.AppendGenerator(gen) + } + + result := generators.GeneratePackets() + require.Equal(t, result, tt.expected) + }) + } +} diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go new file mode 100644 index 0000000..2b09226 --- /dev/null +++ b/device/awg/tag_parser.go @@ -0,0 +1,112 @@ +package awg + +import ( + "fmt" + "maps" + "regexp" + "strings" +) + +type IpcFields struct{ Key, Value string } + +type EnumTag string + +const ( + BytesEnumTag EnumTag = "b" + CounterEnumTag EnumTag = "c" + TimestampEnumTag EnumTag = "t" + RandomBytesEnumTag EnumTag = "r" + WaitTimeoutEnumTag EnumTag = "wt" + WaitResponseEnumTag EnumTag = "wr" +) + +var generatorCreator = map[EnumTag]newGenerator{ + BytesEnumTag: newBytesGenerator, + CounterEnumTag: newPacketCounterGenerator, + TimestampEnumTag: newTimestampGenerator, + RandomBytesEnumTag: newRandomPacketGenerator, + WaitTimeoutEnumTag: newWaitTimeoutGenerator, + // WaitResponseEnumTag: newWaitResponseGenerator, +} + +// helper map to determine enumTags are unique +var uniqueTags = map[EnumTag]bool{ + CounterEnumTag: false, + TimestampEnumTag: false, +} + +type Tag struct { + Name EnumTag + Param string +} + +func parseTag(input string) (Tag, error) { + // Regular expression to match + re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`) + + match := re.FindStringSubmatch(input) + tag := Tag{ + Name: EnumTag(match[1]), + } + if len(match) > 2 && match[2] != "" { + tag.Param = strings.TrimSpace(match[2]) + } + + return tag, nil +} + +func Parse(name, input string) (TagJunkPacketGenerator, error) { + inputSlice := strings.Split(input, "<") + if len(inputSlice) <= 1 { + return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input) + } + + uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) + maps.Copy(uniqueTagCheck, uniqueTags) + + // skip byproduct of split + inputSlice = inputSlice[1:] + rv := newTagJunkPacketGenerator(name, input, len(inputSlice)) + for _, inputParam := range inputSlice { + if len(inputParam) <= 1 { + return TagJunkPacketGenerator{}, fmt.Errorf( + "empty tag in input: %s", + inputSlice, + ) + } else if strings.Count(inputParam, ">") != 1 { + return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input) + } + + tag, _ := parseTag(inputParam) + creator, ok := generatorCreator[tag.Name] + if !ok { + return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) + } + if present, ok := uniqueTagCheck[tag.Name]; ok { + if present { + return TagJunkPacketGenerator{}, fmt.Errorf( + "tag %s needs to be unique", + tag.Name, + ) + } + uniqueTagCheck[tag.Name] = true + } + generator, err := creator(tag.Param) + if err != nil { + return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err) + } + + // TODO: handle counter tag + // if tag.Name == CounterEnumTag { + // packetCounter, ok := generator.(*PacketCounterGenerator) + // if !ok { + // log.Fatalf("packet counter generator expected, got %T", generator) + // } + // PacketCounter = packetCounter.counter + // } + + rv.append(generator) + } + + return rv, nil +} diff --git a/device/awg/tag_parser_test.go b/device/awg/tag_parser_test.go new file mode 100644 index 0000000..8f828ec --- /dev/null +++ b/device/awg/tag_parser_test.go @@ -0,0 +1,77 @@ +package awg + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + type args struct { + name string + input string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "invalid name", + args: args{name: "apple", input: ""}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "empty", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "extra >", + args: args{name: "i1", input: ">"}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "extra <", + args: args{name: "i1", input: "<"}, + wantErr: fmt.Errorf("empty tag in input"), + }, + { + name: "empty <>", + args: args{name: "i1", input: "<>"}, + wantErr: fmt.Errorf("empty tag in input"), + }, + { + name: "invalid tag", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("invalid tag"), + }, + { + name: "counter uniqueness violation", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "timestamp uniqueness violation", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "valid", + args: args{input: ""}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Parse(tt.args.name, tt.args.input) + + // TODO: ErrorAs doesn't work as you think + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + return + } + require.Nil(t, err) + }) + } +} diff --git a/device/bind_test.go b/device/bind_test.go index 34d1c4a..24dec1f 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/channels.go b/device/channels.go index e526f6b..be15d1c 100644 --- a/device/channels.go +++ b/device/channels.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/constants.go b/device/constants.go index 59854a1..41da618 100644 --- a/device/constants.go +++ b/device/constants.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/cookie.go b/device/cookie.go index 876f05d..a093c8b 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/cookie_test.go b/device/cookie_test.go index 4f1e50a..c937290 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/device.go b/device/device.go index 1be15d0..1829352 100644 --- a/device/device.go +++ b/device/device.go @@ -1,24 +1,60 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( + "errors" "runtime" "sync" "sync/atomic" "time" "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/tevino/abool/v2" ) +type Version uint8 + +const ( + VersionDefault Version = iota + VersionAwg + VersionAwgSpecialHandshake +) + +// TODO: +type AtomicVersion struct { + value atomic.Uint32 +} + +func NewAtomicVersion(v Version) *AtomicVersion { + av := &AtomicVersion{} + av.Store(v) + return av +} + +func (av *AtomicVersion) Load() Version { + return Version(av.value.Load()) +} + +func (av *AtomicVersion) Store(v Version) { + av.value.Store(uint32(v)) +} + +func (av *AtomicVersion) CompareAndSwap(old, new Version) bool { + return av.value.CompareAndSwap(uint32(old), uint32(new)) +} + +func (av *AtomicVersion) Swap(new Version) Version { + return Version(av.value.Swap(uint32(new))) +} + type Device struct { state struct { // state holds the device's state. It is accessed atomically. @@ -92,23 +128,8 @@ type Device struct { closed chan struct{} log *Logger - isASecOn abool.AtomicBool - aSecMux sync.RWMutex - aSecCfg aSecCfgType - junkCreator junkCreator -} - -type aSecCfgType struct { - isSet bool - junkPacketCount int - junkPacketMinSize int - junkPacketMaxSize int - initPacketJunkSize int - responsePacketJunkSize int - initPacketMagicHeader uint32 - responsePacketMagicHeader uint32 - underloadPacketMagicHeader uint32 - transportPacketMagicHeader uint32 + version Version + awg awg.Protocol } // deviceState represents the state of a Device. @@ -557,251 +578,261 @@ func (device *Device) BindClose() error { device.net.Unlock() return err } -func (device *Device) isAdvancedSecurityOn() bool { - return device.isASecOn.IsSet() +func (device *Device) isAWG() bool { + return device.version >= VersionAwg } func (device *Device) resetProtocol() { // restore default message type values - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 + MessageInitiationType = DefaultMessageInitiationType + MessageResponseType = DefaultMessageResponseType + MessageCookieReplyType = DefaultMessageCookieReplyType + MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { - - if !tempASecCfg.isSet { - return err +func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { + if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { + return nil } + var errs []error + isASecOn := false - device.aSecMux.Lock() - if tempASecCfg.junkPacketCount < 0 { - err = ipcErrorf( + device.awg.ASecMux.Lock() + if tempAwg.ASecCfg.JunkPacketCount < 0 { + errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", + ), ) } - device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount - if tempASecCfg.junkPacketCount != 0 { + device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount + if tempAwg.ASecCfg.JunkPacketCount != 0 { isASecOn = true } - device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize - if tempASecCfg.junkPacketMinSize != 0 { + device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize + if tempAwg.ASecCfg.JunkPacketMinSize != 0 { isASecOn = true } - if device.aSecCfg.junkPacketCount > 0 && - tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + if device.awg.ASecCfg.JunkPacketCount > 0 && + tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { - tempASecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work } - if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.aSecCfg.junkPacketMinSize = 0 - device.aSecCfg.junkPacketMaxSize = 1 - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempASecCfg.junkPacketMaxSize, - MaxSegmentSize, - ) - } - } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, - ) - } + if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { + device.awg.ASecCfg.JunkPacketMinSize = 0 + device.awg.ASecCfg.JunkPacketMaxSize = 1 + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + MaxSegmentSize, + )) + } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + tempAwg.ASecCfg.JunkPacketMinSize, + )) } else { - device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize } - if tempASecCfg.junkPacketMaxSize != 0 { + if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.initPacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.initPacketJunkSize, - MaxSegmentSize, - ) - } + newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize + + if newInitSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.InitHeaderJunkSize, + MaxSegmentSize, + ), + ) } else { - device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize } - if tempASecCfg.initPacketJunkSize != 0 { + if tempAwg.ASecCfg.InitHeaderJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.responsePacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.responsePacketJunkSize, - MaxSegmentSize, - ) - } + newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize + + if newResponseSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.ResponseHeaderJunkSize, + MaxSegmentSize, + ), + ) } else { - device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize } - if tempASecCfg.responsePacketJunkSize != 0 { + if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 { isASecOn = true } - if tempASecCfg.initPacketMagicHeader > 4 { + newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize + + if newCookieSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.CookieReplyHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize + } + + if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + isASecOn = true + } + + newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize + + if newTransportSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.TransportHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize + } + + if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 { + isASecOn = true + } + + if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader - MessageInitiationType = device.aSecCfg.initPacketMagicHeader + device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = 1 + MessageInitiationType = DefaultMessageInitiationType } - if tempASecCfg.responsePacketMagicHeader > 4 { + if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader - MessageResponseType = device.aSecCfg.responsePacketMagicHeader + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = 2 + MessageResponseType = DefaultMessageResponseType } - if tempASecCfg.underloadPacketMagicHeader > 4 { + if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = 3 + MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempASecCfg.transportPacketMagicHeader > 4 { + if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader - MessageTransportType = device.aSecCfg.transportPacketMagicHeader + device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader + MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = 4 + MessageTransportType = DefaultMessageTransportType } - isSameMap := map[uint32]bool{} - isSameMap[MessageInitiationType] = true - isSameMap[MessageResponseType] = true - isSameMap[MessageCookieReplyType] = true - isSameMap[MessageTransportType] = true + isSameHeaderMap := map[uint32]struct{}{ + MessageInitiationType: {}, + MessageResponseType: {}, + MessageCookieReplyType: {}, + MessageTransportType: {}, + } // size will be different if same values - if len(isSameMap) != 4 { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - ) + if len(isSameHeaderMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, + MessageInitiationType, + MessageResponseType, + MessageCookieReplyType, + MessageTransportType, + ), + ) + } + + isSameSizeMap := map[int]struct{}{ + newInitSize: {}, + newResponseSize: {}, + newCookieSize: {}, + newTransportSize: {}, + } + + if len(isSameSizeMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`, + newInitSize, + newResponseSize, + newCookieSize, + newTransportSize, + ), + ) + } else { + msgTypeToJunkSize = map[uint32]int{ + MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize, + MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize, + MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize, + MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize, + } + + packetSizeToMsgType = map[int]uint32{ + newInitSize: MessageInitiationType, + newResponseSize: MessageResponseType, + newCookieSize: MessageCookieReplyType, + newTransportSize: MessageTransportType, } } - newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize + device.awg.IsASecOn.SetTo(isASecOn) + var err error + device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) + if err != nil { + errs = append(errs, err) + } - if newInitSize == newResponseSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ; %w`, - newInitSize, - newResponseSize, - err, - ) + if tempAwg.HandshakeHandler.IsSet { + if err := tempAwg.HandshakeHandler.Validate(); err != nil { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ`, - newInitSize, - newResponseSize, - ) + device.awg.HandshakeHandler = tempAwg.HandshakeHandler + device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.version = VersionAwgSpecialHandshake } } else { - packetSizeToMsgType = map[int]uint32{ - newInitSize: MessageInitiationType, - newResponseSize: MessageResponseType, - MessageCookieReplySize: MessageCookieReplyType, - MessageTransportSize: MessageTransportType, - } - - msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, - MessageCookieReplyType: 0, - MessageTransportType: 0, - } + device.version = VersionAwg } - device.isASecOn.SetTo(isASecOn) - device.junkCreator, err = NewJunkCreator(device) - device.aSecMux.Unlock() + device.awg.ASecMux.Unlock() - return err + return errors.Join(errs...) } diff --git a/device/device_test.go b/device/device_test.go index d03610f..5824cf9 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,25 +1,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( "bytes" + "context" "encoding/hex" "fmt" "io" "math/rand" "net/netip" "os" + "os/signal" "runtime" "runtime/pprof" "sync" - "sync/atomic" "testing" "time" + "go.uber.org/atomic" + "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" "github.com/amnezia-vpn/amneziawg-go/tun" @@ -50,7 +53,7 @@ func uapiCfg(cfg ...string) string { // genConfigs generates a pair of configs that connect to each other. // The configs use distinct, probably-usable ports. -func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { +func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) { var key1, key2 NoisePrivateKey _, err := rand.Read(key1[:]) if err != nil { @@ -62,7 +65,8 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { } pub1, pub2 := key1.publicKey(), key2.publicKey() - cfgs[0] = uapiCfg( + args0 := append([]string(nil), cfg...) + args0 = append(args0, []string{ "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", @@ -70,12 +74,16 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.2/32", - ) + }...) + cfgs[0] = uapiCfg(args0...) + endpointCfgs[0] = uapiCfg( "public_key", hex.EncodeToString(pub2[:]), "endpoint", "127.0.0.1:%d", ) - cfgs[1] = uapiCfg( + + args1 := append([]string(nil), cfg...) + args1 = append(args1, []string{ "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", @@ -83,66 +91,9 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.1/32", - ) - endpointCfgs[1] = uapiCfg( - "public_key", hex.EncodeToString(pub1[:]), - "endpoint", "127.0.0.1:%d", - ) - return -} + }...) -func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { - var key1, key2 NoisePrivateKey - _, err := rand.Read(key1[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - _, err = rand.Read(key2[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - pub1, pub2 := key1.publicKey(), key2.publicKey() - - cfgs[0] = uapiCfg( - "private_key", hex.EncodeToString(key1[:]), - "listen_port", "0", - "replace_peers", "true", - "jc", "5", - "jmin", "500", - "jmax", "1000", - "s1", "30", - "s2", "40", - "h1", "123456", - "h2", "67543", - "h4", "32345", - "h3", "123123", - "public_key", hex.EncodeToString(pub2[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.2/32", - ) - endpointCfgs[0] = uapiCfg( - "public_key", hex.EncodeToString(pub2[:]), - "endpoint", "127.0.0.1:%d", - ) - cfgs[1] = uapiCfg( - "private_key", hex.EncodeToString(key2[:]), - "listen_port", "0", - "replace_peers", "true", - "jc", "5", - "jmin", "500", - "jmax", "1000", - "s1", "30", - "s2", "40", - "h1", "123456", - "h2", "67543", - "h4", "32345", - "h3", "123123", - "public_key", hex.EncodeToString(pub1[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.1/32", - ) + cfgs[1] = uapiCfg(args1...) endpointCfgs[1] = uapiCfg( "public_key", hex.EncodeToString(pub1[:]), "endpoint", "127.0.0.1:%d", @@ -185,9 +136,10 @@ func (pair *testPair) Send( // pong is the new ping p0, p1 = p1, p0 } + msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg - timer := time.NewTimer(5 * time.Second) + timer := time.NewTimer(6 * time.Second) defer timer.Stop() var err error select { @@ -214,14 +166,12 @@ func (pair *testPair) Send( // genTestPair creates a testPair. func genTestPair( tb testing.TB, - realSocket, withASecurity bool, + realSocket bool, + extraCfg ...string, ) (pair testPair) { var cfg, endpointCfg [2]string - if withASecurity { - cfg, endpointCfg = genASecurityConfigs(tb) - } else { - cfg, endpointCfg = genConfigs(tb) - } + cfg, endpointCfg = genConfigs(tb, extraCfg...) + var binds [2]conn.Bind if realSocket { binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() @@ -265,7 +215,7 @@ func genTestPair( func TestTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t, true, false) + pair := genTestPair(t, true) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -274,9 +224,23 @@ func TestTwoDevicePing(t *testing.T) { }) } -func TestASecurityTwoDevicePing(t *testing.T) { +// Run test with -race=false to avoid the race for setting the default msgTypes 2 times +func TestAWGDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t, true, true) + + pair := genTestPair(t, true, + "jc", "5", + "jmin", "500", + "jmax", "1000", + "s1", "30", + "s2", "40", + "s3", "50", + "s4", "5", + "h1", "123456", + "h2", "67543", + "h3", "123123", + "h4", "32345", + ) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -285,13 +249,58 @@ func TestASecurityTwoDevicePing(t *testing.T) { }) } +// Needs to be stopped with Ctrl-C +func TestAWGHandshakeDevicePing(t *testing.T) { + t.Skip("This test is intended to be run manually, not as part of the test suite.") + + signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + isRunning := atomic.NewBool(true) + go func() { + <-signalContext.Done() + fmt.Println("Waiting to finish") + isRunning.Store(false) + }() + + goroutineLeakCheck(t) + pair := genTestPair(t, true, + "i1", "", + "i2", "", + "j1", "", + "j2", "", + "j3", "", + "itime", "60", + // "jc", "1", + // "jmin", "500", + // "jmax", "1000", + // "s1", "30", + // "s2", "40", + // "h1", "123456", + // "h2", "67543", + // "h4", "32345", + // "h3", "123123", + ) + t.Run("ping 1.0.0.1", func(t *testing.T) { + for isRunning.Load() { + pair.Send(t, Ping, nil) + time.Sleep(2 * time.Second) + } + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + for isRunning.Load() { + pair.Send(t, Pong, nil) + time.Sleep(2 * time.Second) + } + }) +} + func TestUpDown(t *testing.T) { goroutineLeakCheck(t) const itrials = 50 const otrials = 10 for n := 0; n < otrials; n++ { - pair := genTestPair(t, false, false) + pair := genTestPair(t, false) for i := range pair { for k := range pair[i].dev.peers.keyMap { pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) @@ -325,7 +334,7 @@ func TestUpDown(t *testing.T) { // TestConcurrencySafety does other things concurrently with tunnel use. // It is intended to be used with the race detector to catch data races. func TestConcurrencySafety(t *testing.T) { - pair := genTestPair(t, true, false) + pair := genTestPair(t, true) done := make(chan struct{}) const warmupIters = 10 @@ -406,7 +415,7 @@ func TestConcurrencySafety(t *testing.T) { } func BenchmarkLatency(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -420,7 +429,7 @@ func BenchmarkLatency(b *testing.B) { } func BenchmarkThroughput(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -464,7 +473,7 @@ func BenchmarkThroughput(b *testing.B) { } func BenchmarkUAPIGet(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ReportAllocs() diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 93a4998..85482d8 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/indextable.go b/device/indextable.go index 00ade7d..2460fa6 100644 --- a/device/indextable.go +++ b/device/indextable.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/ip.go b/device/ip.go index eaf2363..f558744 100644 --- a/device/ip.go +++ b/device/ip.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/kdf_test.go b/device/kdf_test.go index f9c76d6..325db59 100644 --- a/device/kdf_test.go +++ b/device/kdf_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/keypair.go b/device/keypair.go index cc2941a..05bce68 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/logger.go b/device/logger.go index 22b0df0..a2adea3 100644 --- a/device/logger.go +++ b/device/logger.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/mobilequirks.go b/device/mobilequirks.go index 0a0080e..af4be31 100644 --- a/device/mobilequirks.go +++ b/device/mobilequirks.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/noise-helpers.go b/device/noise-helpers.go index c2f356b..35dd907 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1289249..f637b24 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( @@ -75,9 +82,10 @@ const ( MessageTransportOffsetContent = 16 ) -var packetSizeToMsgType map[int]uint32 - -var msgTypeToJunkSize map[uint32]int +var ( + packetSizeToMsgType map[int]uint32 + msgTypeToJunkSize map[uint32]int +) /* Type is an 8-bit field, followed by 3 nul bytes, * by marshalling the messages in little-endian byteorder @@ -197,12 +205,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.ASecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -256,12 +264,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -376,9 +384,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.ASecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -428,12 +436,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() // lookup handshake by receiver diff --git a/device/noise-types.go b/device/noise-types.go index e850359..41c944e 100644 --- a/device/noise-types.go +++ b/device/noise-types.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/noise_test.go b/device/noise_test.go index 075b6d3..8f72f29 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/peer.go b/device/peer.go index 5bc8ca4..e8a5168 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -13,6 +13,7 @@ import ( "time" "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device/awg" ) type Peer struct { @@ -113,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error { + err := peer.SendBuffers(buffers) + if err == nil { + awg.PacketCounter.Add(uint64(len(buffers))) + return nil + } + + return err +} + func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() diff --git a/device/pools.go b/device/pools.go index 94f3dc7..2c18f41 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,20 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( "sync" - "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count atomic.Uint32 + count uint32 // Get calls not yet Put back max uint32 } @@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for p.count.Load() >= p.max { + for p.count >= p.max { p.cond.Wait() } - p.count.Add(1) + p.count++ p.lock.Unlock() } return p.pool.Get() @@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - p.count.Add(^uint32(0)) + p.lock.Lock() + defer p.lock.Unlock() + p.count-- p.cond.Signal() } diff --git a/device/pools_test.go b/device/pools_test.go index 82d7493..8381d5a 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) { wg.Add(workers) var max atomic.Uint32 updateMax := func() { - count := p.count.Load() + p.lock.Lock() + count := p.count + p.lock.Unlock() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 1bff95a..741fcf3 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index 0061b63..f19e9b1 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index acd3cec..632e29d 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go index 1eee32b..9a296d6 100644 --- a/device/queueconstants_windows.go +++ b/device/queueconstants_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go index bb5c450..14b3284 100644 --- a/device/race_disabled_test.go +++ b/device/race_disabled_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go index 4e9daea..f1ea5cf 100644 --- a/device/race_enabled_test.go +++ b/device/race_enabled_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/receive.go b/device/receive.go index 66c1a32..6daba0d 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming( } // check size of packet - packet := bufsArrs[i][:size] var msgType uint32 - if device.isAdvancedSecurityOn() { + if device.isAWG() { + // TODO: + // if awg.WaitResponse.ShouldWait.IsSet() { + // awg.WaitResponse.Channel <- struct{}{} + // } + if assumedMsgType, ok := packetSizeToMsgType[size]; ok { junkSize := msgTypeToJunkSize[assumedMsgType] // transport size can align with other header types; @@ -149,19 +153,29 @@ func (device *Device) RoutineReceiveIncoming( if msgType == assumedMsgType { packet = packet[junkSize:] } else { - device.log.Verbosef("Transport packet lined up with another msg type") + device.log.Verbosef("transport packet lined up with another msg type") msgType = binary.LittleEndian.Uint32(packet[:4]) } } else { - msgType = binary.LittleEndian.Uint32(packet[:4]) + transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize + msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4]) if msgType != MessageTransportType { - device.log.Verbosef("ASec: Received message with unknown type") + // probably a junk packet + device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) continue } + + // remove junk from bufsArrs by shifting the packet + // this buffer is also used for decryption, so it needs to be corrected + copy(bufsArrs[i][:size], packet[transportJunkSize:]) + size -= transportJunkSize + // need to reinitialize packet as well + packet = packet[:size] } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) } + switch msgType { // check if transport @@ -245,7 +259,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -304,7 +318,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle cookie fields and ratelimiting @@ -456,7 +470,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 7eca099..04ca2ad 100644 --- a/device/send.go +++ b/device/send.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -124,12 +124,30 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } var sendBuffer [][]byte + // so only packet processed for cookie generation var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets() - peer.device.aSecMux.RUnlock() + if peer.device.version >= VersionAwg { + var junks [][]byte + if peer.device.version == VersionAwgSpecialHandshake { + peer.device.awg.ASecMux.RLock() + // set junks depending on packet type + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() + if junks == nil { + junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk() + if junks != nil { + peer.device.log.Verbosef("%v - Controlled junks sent", peer) + } + } else { + peer.device.log.Verbosef("%v - Special junks sent", peer) + } + peer.device.awg.ASecMux.RUnlock() + } else { + junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount) + } + peer.device.awg.ASecMux.RLock() + err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks) + peer.device.awg.ASecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -145,19 +163,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) - if err != nil { - peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() - return err - } - junkedHeader = writer.Bytes() + junkedHeader, err = peer.device.awg.CreateInitHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } - peer.device.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -172,7 +182,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { sendBuffer = append(sendBuffer, junkedHeader) - err = peer.SendBuffers(sendBuffer) + err = peer.SendAndCountBuffers(sendBuffer) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -193,22 +203,13 @@ func (peer *Peer) SendHandshakeResponse() error { peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) - if err != nil { - peer.device.aSecMux.RUnlock() - peer.device.log.Errorf("%v - %v", peer, err) - return err - } - junkedHeader = writer.Bytes() - } - peer.device.aSecMux.RUnlock() + + junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } + var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -228,7 +229,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{junkedHeader}) + err = peer.SendAndCountBuffers([][]byte{junkedHeader}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -251,11 +252,19 @@ func (device *Device) SendHandshakeCookie( return err } + junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk() + if err != nil { + device.log.Errorf("%v - %v", device, err) + return err + } + var buf [MessageCookieReplySize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) + + junkedHeader = append(junkedHeader, writer.Bytes()...) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint) return nil } @@ -568,6 +577,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } + device.PutOutboundElementsContainer(elemsContainer) continue } dataSent := false @@ -575,6 +585,14 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true + + junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet)) + if err != nil { + device.log.Errorf("%v - %v", device, err) + continue + } + + elem.packet = append(junkedHeader, elem.packet...) } bufs = append(bufs, elem.packet) } @@ -582,10 +600,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendAndCountBuffers(bufs) if dataSent { peer.timersDataSent() } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) diff --git a/device/sticky_default.go b/device/sticky_default.go index da776e8..1751927 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -7,6 +7,6 @@ import ( "github.com/amnezia-vpn/amneziawg-go/rwcancel" ) -func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { +func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) { return nil, nil } diff --git a/device/sticky_linux.go b/device/sticky_linux.go index 63164a7..2edb628 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port @@ -9,7 +9,7 @@ * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. + * So this code remains platform dependent. */ package device @@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er return netlinkCancel, nil } -func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { +func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { type peerEndpointPtr struct { peer *Peer endpoint *conn.Endpoint diff --git a/device/timers.go b/device/timers.go index d4a4ed4..32519aa 100644 --- a/device/timers.go +++ b/device/timers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ diff --git a/device/tun.go b/device/tun.go index 600a5e5..42178b2 100644 --- a/device/tun.go +++ b/device/tun.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/uapi.go b/device/uapi.go index 777bdda..e9f962a 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -97,33 +98,51 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.isAWG() { + if device.awg.ASecCfg.JunkPacketCount != 0 { + sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.ASecCfg.JunkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.ASecCfg.JunkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.ASecCfg.InitHeaderJunkSize != 0 { + sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 { + sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { + sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.ASecCfg.InitPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) + } + if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) + } + if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) + } + + specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() + for _, field := range specialJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } + controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields() + for _, field := range controlledJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } + if device.awg.HandshakeHandler.ITimeout != 0 { + sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second) } } @@ -180,13 +199,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwg := awg.Protocol{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -217,7 +236,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwg) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +244,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -237,7 +256,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error { switch key { case "private_key": var sk NoisePrivateKey @@ -278,7 +297,11 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "replace_peers": if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set replace_peers, invalid value: %v", + value, + ) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() @@ -286,80 +309,138 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketCount = junkPacketCount + tempAwg.ASecCfg.IsSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize + tempAwg.ASecCfg.IsSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize + tempAwg.ASecCfg.IsSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s3": + cookieReplyPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size") + tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s4": + transportPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating transport_packet_junk_size") + tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) + } + tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true + case "i1", "i2", "i3", "i4", "i5": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty %s", key) + return nil } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + generators, err := awg.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true + case "j1", "j2", "j3": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty %s", key) + return nil + } + + generators, err := awg.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + + tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true + case "itime": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty itime") + return nil + } + + itime, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err) + } + device.log.Verbosef("UAPI: Updating itime") + + tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second + tempAwg.HandshakeHandler.IsSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } @@ -432,7 +513,11 @@ func (device *Device) handlePeerLine( case "update_only": // allow disabling of creation if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set update only, invalid value: %v", + value, + ) } if peer.created && !peer.dummy { device.RemovePeer(peer.handshake.remoteStatic) @@ -478,7 +563,11 @@ func (device *Device) handlePeerLine( secs, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set persistent keepalive interval: %w", + err, + ) } old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) @@ -489,7 +578,11 @@ func (device *Device) handlePeerLine( case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to replace allowedips, invalid value: %v", + value, + ) } if peer.dummy { return nil @@ -497,7 +590,14 @@ func (device *Device) handlePeerLine( device.allowedips.RemoveByPeer(peer.Peer) case "allowed_ip": - device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + add := true + verb := "Adding" + if len(value) > 0 && value[0] == '-' { + add = false + verb = "Removing" + value = value[1:] + } + device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb) prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) @@ -505,7 +605,11 @@ func (device *Device) handlePeerLine( if peer.dummy { return nil } - device.allowedips.Insert(prefix, peer.Peer) + if add { + device.allowedips.Insert(prefix, peer.Peer) + } else { + device.allowedips.Remove(prefix, peer.Peer) + } case "protocol_version": if value != "1" { @@ -557,7 +661,11 @@ func (device *Device) IpcHandle(socket net.Conn) { return } if nextByte != '\n' { - err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + err = ipcErrorf( + ipc.IpcErrorInvalid, + "trailing character in UAPI get: %q", + nextByte, + ) break } err = device.IpcGetOperation(buffered.Writer) diff --git a/format_test.go b/format_test.go index 6f6cab7..4d02c48 100644 --- a/format_test.go +++ b/format_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/go.mod b/go.mod index 608969f..5e5f34d 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,23 @@ module github.com/amnezia-vpn/amneziawg-go -go 1.24 +go 1.24.4 require ( + github.com/stretchr/testify v1.10.0 + github.com/tevino/abool v1.2.0 github.com/tevino/abool/v2 v2.1.0 - golang.org/x/crypto v0.36.0 - golang.org/x/net v0.37.0 - golang.org/x/sys v0.31.0 + go.uber.org/atomic v1.11.0 + golang.org/x/crypto v0.39.0 + golang.org/x/net v0.41.0 + golang.org/x/sys v0.33.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 + gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.3 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/time v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 497f949..6b8f36b 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,40 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= +github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE= +gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk= +gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk= +gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go index ddcaf27..fd433a5 100644 --- a/ipc/uapi_bsd.go +++ b/ipc/uapi_bsd.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go index 9738aea..058e8e7 100644 --- a/ipc/uapi_linux.go +++ b/ipc/uapi_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go index 0da452a..79604ee 100644 --- a/ipc/uapi_unix.go +++ b/ipc/uapi_unix.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_wasm.go b/ipc/uapi_wasm.go index fa84684..50ac091 100644 --- a/ipc/uapi_wasm.go +++ b/ipc/uapi_wasm.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index 31d2a63..321fe60 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/main.go b/main.go index 5a3dfef..f8fded9 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/main_windows.go b/main_windows.go index bbfa690..d3e2fe6 100644 --- a/main_windows.go +++ b/main_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index f7d05ef..ac69e3a 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ratelimiter diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index 0bfa3af..71140da 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ratelimiter diff --git a/replay/replay.go b/replay/replay.go index 8b99e23..46e224d 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. diff --git a/replay/replay_test.go b/replay/replay_test.go index 9a9e4a8..8378ec3 100644 --- a/replay/replay_test.go +++ b/replay/replay_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package replay diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index e397c0e..4372453 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package rwcancel implements cancelable read/write operations on @@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool { func (rw *RWCancel) ReadyWrite() bool { closeFd := int32(rw.closingReader.Fd()) - pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} + pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}} var err error for { _, err = unix.Poll(pollFds, -1) diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go index 8f10b39..e1a97a5 100644 --- a/tai64n/tai64n.go +++ b/tai64n/tai64n.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tai64n diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go index c70fc1a..d0b4425 100644 --- a/tai64n/tai64n_test.go +++ b/tai64n/tai64n_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tai64n diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go index 67a785e..e3252b2 100644 --- a/tun/alignment_windows_test.go +++ b/tun/alignment_windows_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/checksum.go b/tun/checksum.go index 29a8fc8..b489c56 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -1,102 +1,86 @@ package tun -import "encoding/binary" +import ( + "encoding/binary" + "math/bits" +) // TODO: Explore SIMD and/or other assembly optimizations. -// TODO: Test native endian loads. See RFC 1071 section 2 part B. func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial + tmp := make([]byte, 8) + binary.NativeEndian.PutUint64(tmp, initial) + ac := binary.BigEndian.Uint64(tmp) + var carry uint64 for len(b) >= 128 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) - ac += uint64(binary.BigEndian.Uint32(b[64:68])) - ac += uint64(binary.BigEndian.Uint32(b[68:72])) - ac += uint64(binary.BigEndian.Uint32(b[72:76])) - ac += uint64(binary.BigEndian.Uint32(b[76:80])) - ac += uint64(binary.BigEndian.Uint32(b[80:84])) - ac += uint64(binary.BigEndian.Uint32(b[84:88])) - ac += uint64(binary.BigEndian.Uint32(b[88:92])) - ac += uint64(binary.BigEndian.Uint32(b[92:96])) - ac += uint64(binary.BigEndian.Uint32(b[96:100])) - ac += uint64(binary.BigEndian.Uint32(b[100:104])) - ac += uint64(binary.BigEndian.Uint32(b[104:108])) - ac += uint64(binary.BigEndian.Uint32(b[108:112])) - ac += uint64(binary.BigEndian.Uint32(b[112:116])) - ac += uint64(binary.BigEndian.Uint32(b[116:120])) - ac += uint64(binary.BigEndian.Uint32(b[120:124])) - ac += uint64(binary.BigEndian.Uint32(b[124:128])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry) + ac += carry b = b[128:] } if len(b) >= 64 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac += carry b = b[64:] } if len(b) >= 32 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac += carry b = b[32:] } if len(b) >= 16 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac += carry b = b[16:] } if len(b) >= 8 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac += carry b = b[8:] } if len(b) >= 4 { - ac += uint64(binary.BigEndian.Uint32(b)) + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0) + ac += carry b = b[4:] } if len(b) >= 2 { - ac += uint64(binary.BigEndian.Uint16(b)) + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0) + ac += carry b = b[2:] } if len(b) == 1 { - ac += uint64(b[0]) << 8 + tmp := binary.NativeEndian.Uint16([]byte{b[0], 0}) + ac, carry = bits.Add64(ac, uint64(tmp), 0) + ac += carry } - return ac + binary.NativeEndian.PutUint64(tmp, ac) + return binary.BigEndian.Uint64(tmp) } func checksum(b []byte, initial uint64) uint16 { diff --git a/tun/checksum_test.go b/tun/checksum_test.go index c1ccff5..4ea9b8b 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -1,11 +1,74 @@ package tun import ( + "encoding/binary" "fmt" "math/rand" "testing" + + "golang.org/x/sys/unix" ) +func checksumRef(b []byte, initial uint16) uint16 { + ac := uint64(initial) + + for len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(ac) +} + +func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + sum := checksumRef(srcAddr, 0) + sum = checksumRef(dstAddr, sum) + sum = checksumRef([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumRef(tmp, sum) +} + +func TestChecksum(t *testing.T) { + for length := 0; length <= 9001; length++ { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + csum := checksum(buf, 0x1234) + csumRef := checksumRef(buf, 0x1234) + if csum != csumRef { + t.Error("Expected checksum", csumRef, "got", csum) + } + } +} + +func TestPseudoHeaderChecksum(t *testing.T) { + for _, addrLen := range []int{4, 16} { + for length := 0; length <= 9001; length++ { + srcAddr := make([]byte, addrLen) + dstAddr := make([]byte, addrLen) + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(srcAddr) + rng.Read(dstAddr) + rng.Read(buf) + phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csum := checksum(buf, phSum) + phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csumRef := checksumRef(buf, phSumRef) + if csum != csumRef { + t.Error("Expected checksumRef", csumRef, "got", csum) + } + } + } +} + func BenchmarkChecksum(b *testing.B) { lengths := []int{ 64, diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index 4c4ea12..8b12ecc 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index 09929e0..80cd036 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go index d7897b2..b243b5c 100644 --- a/tun/netstack/examples/ping_client.go +++ b/tun/netstack/examples/ping_client.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 2275173..48a428b 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package netstack @@ -43,6 +43,7 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event + notifyHandle *channel.NotificationHandle incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr @@ -70,7 +71,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - dev.ep.AddNotify(dev) + dev.notifyHandle = dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) @@ -155,7 +156,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { func (tun *netTun) WriteNotify() { pkt := tun.ep.Read() - if pkt.IsNil() { + if pkt == nil { return } @@ -167,13 +168,14 @@ func (tun *netTun) WriteNotify() { func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) + tun.stack.Close() + tun.ep.RemoveNotify(tun.notifyHandle) + tun.ep.Close() if tun.events != nil { close(tun.events) } - tun.ep.Close() - if tun.incomingPacket != nil { close(tun.incomingPacket) } diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 89cf024..b61654b 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index a68cd98..c04e003 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/operateonfd.go b/tun/operateonfd.go index f1beb6d..343f754 100644 --- a/tun/operateonfd.go +++ b/tun/operateonfd.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun.go b/tun/tun.go index 0ae53d0..336d642 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0b..341afe3 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -1,19 +1,17 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun import ( - "errors" "fmt" "io" "net" "os" "sync" "syscall" - "time" "unsafe" "golang.org/x/sys/unix" @@ -30,18 +28,6 @@ type NativeTun struct { closeOnce sync.Once } -func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { - for i := 0; i < 20; i++ { - iface, err = net.InterfaceByIndex(index) - if err != nil && errors.Is(err, unix.ENOMEM) { - time.Sleep(time.Duration(i) * time.Second / 3) - continue - } - return iface, err - } - return nil, err -} - func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool @@ -62,26 +48,22 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { return } - if n < 14 { + if n < 28 { continue } - if data[3 /* type */] != unix.RTM_IFINFO { + if data[3 /* ifm_type */] != unix.RTM_IFINFO { continue } - ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) + ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifm_index */]))) if ifindex != tunIfindex { continue } - iface, err := retryInterfaceByIndex(ifindex) - if err != nil { - tun.errors <- err - return - } + flags := int(*(*uint32)(unsafe.Pointer(&data[8 /* ifm_flags */]))) // Up / Down event - up := (iface.Flags & net.FlagUp) != 0 + up := (flags & syscall.IFF_UP) != 0 if up != statusUp && up { tun.events <- EventUp } @@ -90,11 +72,13 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } statusUp = up + mtu := int(*(*uint32)(unsafe.Pointer(&data[24 /* ifm_data.ifi_mtu */]))) + // MTU changes - if iface.MTU != statusMTU { + if mtu != statusMTU { tun.events <- EventMTUUpdate } - statusMTU = iface.MTU + statusMTU = mtu } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd9..4adf3a1 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 011e56a..bc6e7c1 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b9..5aa9070 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 2af8e3e..de65fb4 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index f620e0a..0fa70b0 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tuntest diff --git a/version.go b/version.go index db75bb9..d5524e8 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package main -const Version = "0.0.20230223" +const Version = "0.0.20250522"