mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-29 16:22:49 +02:00
Compare commits
No commits in common. "master" and "v0.2.8" have entirely different histories.
113 changed files with 747 additions and 2544 deletions
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.24.4 as awg
|
||||
FROM golang:1.20 as awg
|
||||
COPY . /awg
|
||||
WORKDIR /awg
|
||||
RUN go mod download && \
|
||||
|
@ -6,8 +6,7 @@ RUN go mod download && \
|
|||
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG AWGTOOLS_RELEASE="1.0.20241018"
|
||||
|
||||
ARG AWGTOOLS_RELEASE="1.0.20240213"
|
||||
RUN apk --no-cache add iproute2 iptables bash && \
|
||||
cd /usr/bin/ && \
|
||||
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
||||
|
|
2
Makefile
2
Makefile
|
@ -9,7 +9,7 @@ MAKEFLAGS += --no-print-directory
|
|||
|
||||
generate-version-and-build:
|
||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||
echo "$$ver" > version.go && \
|
||||
|
|
|
@ -50,4 +50,3 @@ $ git clone https://github.com/amnezia-vpn/amneziawg-go
|
|||
$ cd amneziawg-go
|
||||
$ make
|
||||
```
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -298,6 +298,11 @@ func (s *StdNetBind) BatchSize() int {
|
|||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBind) GetOffloadInfo() string {
|
||||
return fmt.Sprintf("ipv4TxOffload: %v, ipv4RxOffload: %v\nipv6TxOffload: %v, ipv6RxOffload: %v",
|
||||
s.ipv4TxOffload, s.ipv4RxOffload, s.ipv6TxOffload, s.ipv6RxOffload)
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -328,6 +328,10 @@ func (bind *WinRingBind) BatchSize() int {
|
|||
return 1
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) GetOffloadInfo() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bindtest
|
||||
|
@ -91,6 +91,8 @@ func (c *ChannelBind) Close() error {
|
|||
|
||||
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||
|
||||
func (c *ChannelBind) GetOffloadInfo() string { return "" }
|
||||
|
||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package conn implements WireGuard's network connections.
|
||||
|
@ -55,6 +55,8 @@ type Bind interface {
|
|||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
|
||||
GetOffloadInfo() string
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -13,35 +13,6 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
|
||||
func kernelVersion() (major, minor int) {
|
||||
var uname unix.Utsname
|
||||
if err := unix.Uname(&uname); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
values [2]int
|
||||
value, vi int
|
||||
)
|
||||
for _, c := range uname.Release {
|
||||
if '0' <= c && c <= '9' {
|
||||
value = (value * 10) + int(c-'0')
|
||||
} else {
|
||||
// Note that we're assuming N.N.N here.
|
||||
// If we see anything else, we are likely to mis-parse it.
|
||||
values[vi] = value
|
||||
vi++
|
||||
if vi >= len(values) {
|
||||
break
|
||||
}
|
||||
value = 0
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], values[1]
|
||||
}
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
|
@ -86,24 +57,5 @@ func init() {
|
|||
}
|
||||
return err
|
||||
},
|
||||
|
||||
// Attempt to enable UDP_GRO
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
// Kernels below 5.12 are missing 98184612aca0 ("net:
|
||||
// udp: Add support for getsockopt(..., ..., UDP_GRO,
|
||||
// ..., ...);"), which means we can't read this back
|
||||
// later. We could pipe the return value through to
|
||||
// the rest of the code, but UDP_GRO is kind of buggy
|
||||
// anyway, so just gate this here.
|
||||
major, minor := kernelVersion()
|
||||
if major < 5 || (major == 5 && minor < 12) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
})
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func errShouldDisableUDPGSO(_ error) bool {
|
||||
func errShouldDisableUDPGSO(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -3,13 +3,13 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winrio
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -223,11 +223,19 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
|
|||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) remove() {
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
node := elem.Value.(*trieEntry)
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
return
|
||||
continue
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
|
@ -240,12 +248,12 @@ func (node *trieEntry) remove() {
|
|||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
continue
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
continue
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
|
@ -255,37 +263,6 @@ func (node *trieEntry) remove() {
|
|||
node.zeroizePointers()
|
||||
parent.zeroizePointers()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
var node *trieEntry
|
||||
var exact bool
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else {
|
||||
panic(errors.New("removing unknown address type"))
|
||||
}
|
||||
if !exact || node == nil || peer != node.peer {
|
||||
return
|
||||
}
|
||||
node.remove()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
elem.Value.(*trieEntry).remove()
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
var peers []*Peer
|
||||
var allowedIPs AllowedIPs
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rand.Seed(1)
|
||||
|
||||
for n := 0; n < NumberOfPeers; n++ {
|
||||
peers = append(peers, &Peer{})
|
||||
|
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
|
|||
|
||||
for n := 0; n < NumberOfAddresses; n++ {
|
||||
var addr4 [4]byte
|
||||
rng.Read(addr4[:])
|
||||
rand.Read(addr4[:])
|
||||
cidr := uint8(rand.Intn(32) + 1)
|
||||
index := rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||
|
||||
var addr6 [16]byte
|
||||
rng.Read(addr6[:])
|
||||
rand.Read(addr6[:])
|
||||
cidr = uint8(rand.Intn(128) + 1)
|
||||
index = rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||
|
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
for p = 0; ; p++ {
|
||||
for n := 0; n < NumberOfTests; n++ {
|
||||
var addr4 [4]byte
|
||||
rng.Read(addr4[:])
|
||||
rand.Read(addr4[:])
|
||||
peer1 := slow4.Lookup(addr4[:])
|
||||
peer2 := allowedIPs.Lookup(addr4[:])
|
||||
if peer1 != peer2 {
|
||||
|
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
}
|
||||
|
||||
var addr6 [16]byte
|
||||
rng.Read(addr6[:])
|
||||
rand.Read(addr6[:])
|
||||
peer1 = slow6.Lookup(addr6[:])
|
||||
peer2 = allowedIPs.Lookup(addr6[:])
|
||||
if peer1 != peer2 {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
||||
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||
var trie *trieEntry
|
||||
var peers []*Peer
|
||||
root := parentIndirection{&trie, 2}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rand.Seed(1)
|
||||
|
||||
const AddressLength = 4
|
||||
|
||||
|
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
|||
|
||||
for n := 0; n < addressNumber; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rng.Read(addr[:])
|
||||
cidr := uint8(rng.Uint32() % (AddressLength * 8))
|
||||
index := rng.Int() % peerNumber
|
||||
rand.Read(addr[:])
|
||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||
index := rand.Int() % peerNumber
|
||||
root.insert(addr[:], cidr, peers[index])
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rng.Read(addr[:])
|
||||
rand.Read(addr[:])
|
||||
trie.lookup(addr[:])
|
||||
}
|
||||
}
|
||||
|
@ -101,10 +101,6 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||
if p != peer {
|
||||
|
@ -180,21 +176,6 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.RemoveByPeer(a)
|
||||
|
||||
assertNEQ(a, 192, 168, 0, 1)
|
||||
|
||||
insert(a, 1, 0, 0, 0, 32)
|
||||
insert(a, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 1, 0, 0, 0)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 32)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(nil, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(b, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 24)
|
||||
assertNEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 1, 0, 0, 0, 32)
|
||||
assertNEQ(a, 1, 0, 0, 0)
|
||||
}
|
||||
|
||||
/* Test ported from kernel implementation:
|
||||
|
@ -230,15 +211,6 @@ func TestTrieIPv6(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
|
@ -251,18 +223,6 @@ func TestTrieIPv6(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
p := allowedIPs.Lookup(addr)
|
||||
if p == peer {
|
||||
t.Error("Assert NEQ failed")
|
||||
}
|
||||
}
|
||||
|
||||
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||
insert(e, 0, 0, 0, 0, 0)
|
||||
|
@ -284,21 +244,4 @@ func TestTrieIPv6(t *testing.T) {
|
|||
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||
|
||||
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
}
|
||||
|
|
|
@ -1,144 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
type aSecCfgType struct {
|
||||
IsSet bool
|
||||
JunkPacketCount int
|
||||
JunkPacketMinSize int
|
||||
JunkPacketMaxSize int
|
||||
InitHeaderJunkSize int
|
||||
ResponseHeaderJunkSize int
|
||||
CookieReplyHeaderJunkSize int
|
||||
TransportHeaderJunkSize int
|
||||
InitPacketMagicHeader uint32
|
||||
ResponsePacketMagicHeader uint32
|
||||
UnderloadPacketMagicHeader uint32
|
||||
TransportPacketMagicHeader uint32
|
||||
// InitPacketMagicHeader Limit
|
||||
// ResponsePacketMagicHeader Limit
|
||||
// UnderloadPacketMagicHeader Limit
|
||||
// TransportPacketMagicHeader Limit
|
||||
}
|
||||
|
||||
type Limit struct {
|
||||
Min uint32
|
||||
Max uint32
|
||||
HeaderType uint32
|
||||
}
|
||||
|
||||
func NewLimit(min, max, headerType uint32) (Limit, error) {
|
||||
if min > max {
|
||||
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
|
||||
}
|
||||
|
||||
return Limit{
|
||||
Min: min,
|
||||
Max: max,
|
||||
HeaderType: headerType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
|
||||
// tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType)
|
||||
// var min, max, headerType uint32
|
||||
// _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType)
|
||||
// if err != nil {
|
||||
// return Limit{}, fmt.Errorf("invalid magic header format: %s", value)
|
||||
// }
|
||||
|
||||
limits := strings.Split(value, "-")
|
||||
if len(limits) != 2 {
|
||||
return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
|
||||
}
|
||||
|
||||
min, err := strconv.ParseUint(limits[0], 10, 32)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
max, err := strconv.ParseUint(limits[1], 10, 32)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
type Limits []Limit
|
||||
|
||||
func NewLimits(limits []Limit) Limits {
|
||||
slices.SortFunc(limits, func(a, b Limit) int {
|
||||
if a.Min < b.Min {
|
||||
return -1
|
||||
} else if a.Min > b.Min {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
return Limits(limits)
|
||||
}
|
||||
|
||||
type Protocol struct {
|
||||
IsASecOn abool.AtomicBool
|
||||
// TODO: revision the need of the mutex
|
||||
ASecMux sync.RWMutex
|
||||
ASecCfg aSecCfgType
|
||||
JunkCreator junkCreator
|
||||
|
||||
HandshakeHandler SpecialHandshakeHandler
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) {
|
||||
extraSize := 0
|
||||
if len(optExtraSize) == 1 {
|
||||
extraSize = optExtraSize[0]
|
||||
}
|
||||
|
||||
var junk []byte
|
||||
protocol.ASecMux.RLock()
|
||||
if junkSize != 0 {
|
||||
buf := make([]byte, 0, junkSize+extraSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
|
||||
if err != nil {
|
||||
protocol.ASecMux.RUnlock()
|
||||
return nil, err
|
||||
}
|
||||
junk = writer.Bytes()
|
||||
}
|
||||
protocol.ASecMux.RUnlock()
|
||||
|
||||
return junk, nil
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package internal
|
||||
|
||||
type mockGenerator struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func NewMockGenerator(size int) mockGenerator {
|
||||
return mockGenerator{size: size}
|
||||
}
|
||||
|
||||
func (m mockGenerator) Generate() []byte {
|
||||
return make([]byte, m.size)
|
||||
}
|
||||
|
||||
func (m mockGenerator) Size() int {
|
||||
return m.size
|
||||
}
|
||||
|
||||
func (m mockGenerator) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
type mockByteGenerator struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewMockByteGenerator(data []byte) mockByteGenerator {
|
||||
return mockByteGenerator{data: data}
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Generate() []byte {
|
||||
return bg.data
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Size() int {
|
||||
return len(bg.data)
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
v2 "math/rand/v2"
|
||||
)
|
||||
|
||||
type junkCreator struct {
|
||||
aSecCfg aSecCfgType
|
||||
cha8Rand *v2.ChaCha8
|
||||
}
|
||||
|
||||
// TODO: refactor param to only pass the junk related params
|
||||
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
|
||||
buf := make([]byte, 32)
|
||||
_, err := crand.Read(buf)
|
||||
if err != nil {
|
||||
return junkCreator{}, err
|
||||
}
|
||||
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
|
||||
if jc.aSecCfg.JunkPacketCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for range jc.aSecCfg.JunkPacketCount {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk, err := jc.randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create junk packet: %v", err)
|
||||
}
|
||||
*junks = append(*junks, junk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomPacketSize() int {
|
||||
return int(
|
||||
jc.cha8Rand.Uint64()%uint64(
|
||||
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
|
||||
),
|
||||
) + jc.aSecCfg.JunkPacketMinSize
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk, err := jc.randomJunkWithSize(size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create header junk: %v", err)
|
||||
}
|
||||
_, err = writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
||||
// TODO: use a memory pool to allocate
|
||||
junk := make([]byte, size)
|
||||
_, err := jc.cha8Rand.Read(junk)
|
||||
return junk, err
|
||||
}
|
|
@ -1,115 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
||||
jc, err := NewJunkCreator(aSecCfgType{
|
||||
IsSet: true,
|
||||
JunkPacketCount: 5,
|
||||
JunkPacketMinSize: 500,
|
||||
JunkPacketMaxSize: 1000,
|
||||
InitHeaderJunkSize: 30,
|
||||
ResponseHeaderJunkSize: 40,
|
||||
InitPacketMagicHeader: 123456,
|
||||
ResponsePacketMagicHeader: 67543,
|
||||
UnderloadPacketMagicHeader: 32345,
|
||||
TransportPacketMagicHeader: 123123,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to create junk creator %v", err)
|
||||
return junkCreator{}, err
|
||||
}
|
||||
|
||||
return jc, nil
|
||||
}
|
||||
|
||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
|
||||
err := jc.CreateJunkPackets(&got)
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v; failed",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, junk := range got {
|
||||
key := string(junk)
|
||||
if seen[key] {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||
got,
|
||||
junk,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r1, _ := jc.randomJunkWithSize(10)
|
||||
r2, _ := jc.randomJunkWithSize(10)
|
||||
fmt.Printf("%v\n%v\n", r1, r2)
|
||||
if bytes.Equal(r1, r2) {
|
||||
t.Errorf("same junks %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for range [30]struct{}{} {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
|
||||
got > jc.aSecCfg.JunkPacketMaxSize {
|
||||
t.Errorf(
|
||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||
got,
|
||||
jc.aSecCfg.JunkPacketMinSize,
|
||||
jc.aSecCfg.JunkPacketMaxSize,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := jc.AppendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Error("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// TODO: atomic?/ and better way to use this
|
||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||
|
||||
// TODO
|
||||
var WaitResponse = struct {
|
||||
Channel chan struct{}
|
||||
ShouldWait *abool.AtomicBool
|
||||
}{
|
||||
make(chan struct{}, 1),
|
||||
abool.New(),
|
||||
}
|
||||
|
||||
type SpecialHandshakeHandler struct {
|
||||
isFirstDone bool
|
||||
SpecialJunk TagJunkPacketGenerators
|
||||
ControlledJunk TagJunkPacketGenerators
|
||||
|
||||
nextItime time.Time
|
||||
ITimeout time.Duration // seconds
|
||||
|
||||
IsSet bool
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
||||
var errs []error
|
||||
if err := handler.SpecialJunk.Validate(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if err := handler.ControlledJunk.Validate(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||
if !handler.SpecialJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: create tests
|
||||
if !handler.isFirstDone {
|
||||
handler.isFirstDone = true
|
||||
} else if !handler.isTimeToSendSpecial() {
|
||||
return nil
|
||||
}
|
||||
|
||||
rv := handler.SpecialJunk.GeneratePackets()
|
||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
|
||||
return time.Now().After(handler.nextItime)
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
|
||||
if !handler.ControlledJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return handler.ControlledJunk.GeneratePackets()
|
||||
}
|
|
@ -1,190 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v2 "math/rand/v2"
|
||||
// "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Generator interface {
|
||||
Generate() []byte
|
||||
Size() int
|
||||
}
|
||||
|
||||
type newGenerator func(string) (Generator, error)
|
||||
|
||||
type BytesGenerator struct {
|
||||
value []byte
|
||||
size int
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Generate() []byte {
|
||||
return bg.value
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Size() int {
|
||||
return bg.size
|
||||
}
|
||||
|
||||
func newBytesGenerator(param string) (Generator, error) {
|
||||
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
|
||||
if !hasPrefix {
|
||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||
}
|
||||
|
||||
hex, err := hexToBytes(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hexToBytes: %w", err)
|
||||
}
|
||||
|
||||
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
||||
}
|
||||
|
||||
func hexToBytes(hexStr string) ([]byte, error) {
|
||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||
|
||||
// Ensure even length (pad with leading zero if needed)
|
||||
if len(hexStr)%2 != 0 {
|
||||
hexStr = "0" + hexStr
|
||||
}
|
||||
|
||||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type RandomPacketGenerator struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
func newRandomPacketGenerator(param string) (Generator, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet parse int: %w", err)
|
||||
}
|
||||
|
||||
if size > 1000 {
|
||||
return nil, fmt.Errorf("random packet size must be less than 1000")
|
||||
}
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet crand read: %w", err)
|
||||
}
|
||||
|
||||
return &RandomPacketGenerator{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newTimestampGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &TimestampGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitTimeoutGenerator struct {
|
||||
waitTimeout time.Duration
|
||||
}
|
||||
|
||||
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
||||
time.Sleep(wtg.waitTimeout)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (wtg *WaitTimeoutGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||
timeout, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("timeout parse int: %w", err)
|
||||
}
|
||||
|
||||
if timeout > 5000 {
|
||||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||
}
|
||||
|
||||
return &WaitTimeoutGenerator{
|
||||
waitTimeout: time.Duration(timeout) * time.Millisecond,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
// TODO: better way to handle counter tag
|
||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newPacketCounterGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &PacketCounterGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitResponseGenerator struct {
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Generate() []byte {
|
||||
WaitResponse.ShouldWait.Set()
|
||||
<-WaitResponse.Channel
|
||||
WaitResponse.ShouldWait.UnSet()
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &WaitResponseGenerator{}, nil
|
||||
}
|
|
@ -1,189 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_newBytesGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "wrong start",
|
||||
args: args{
|
||||
param: "123456",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with X",
|
||||
args: args{
|
||||
param: "0X12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with x",
|
||||
args: args{
|
||||
param: "0x12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "valid hex",
|
||||
args: args{
|
||||
param: "0xf6ab3267fa",
|
||||
},
|
||||
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
{
|
||||
name: "valid hex with odd length",
|
||||
args: args{
|
||||
param: "0xfab3267fa",
|
||||
},
|
||||
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
gotValues := got.Generate()
|
||||
require.Equal(t, tt.want, gotValues)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newRandomPacketGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newRandomPacketGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketCounterGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
param string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid empty param",
|
||||
param: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid non-empty param",
|
||||
param: "anything",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen, err := newPacketCounterGenerator(tc.param)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8, gen.Size())
|
||||
|
||||
// Reset counter to known value for test
|
||||
initialCount := uint64(42)
|
||||
PacketCounter.Store(initialCount)
|
||||
|
||||
output := gen.Generate()
|
||||
require.Equal(t, 8, len(output))
|
||||
|
||||
// Verify counter value in output
|
||||
counterValue := binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount, counterValue)
|
||||
|
||||
// Increment counter and verify change
|
||||
PacketCounter.Add(1)
|
||||
output = gen.Generate()
|
||||
counterValue = binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount+1, counterValue)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TagJunkPacketGenerator struct {
|
||||
name string
|
||||
tagValue string
|
||||
|
||||
packetSize int
|
||||
generators []Generator
|
||||
}
|
||||
|
||||
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{
|
||||
name: name,
|
||||
tagValue: tagValue,
|
||||
generators: make([]Generator, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||
tg.generators = append(tg.generators, generator)
|
||||
tg.packetSize += generator.Size()
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
|
||||
packet := make([]byte, 0, tg.packetSize)
|
||||
for _, generator := range tg.generators {
|
||||
packet = append(packet, generator.Generate()...)
|
||||
}
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) Name() string {
|
||||
return tg.name
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
||||
if len(tg.name) != 2 {
|
||||
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(tg.name[1:2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("name 2 char should be an int %w", err)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||
return IpcFields{
|
||||
Key: tg.name,
|
||||
Value: tg.tagValue,
|
||||
}
|
||||
}
|
|
@ -1,210 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTagJunkGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
genName string
|
||||
size int
|
||||
expected TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "Create new generator with empty name",
|
||||
genName: "",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with valid name",
|
||||
genName: "T1",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T1",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with non-zero size",
|
||||
genName: "T2",
|
||||
size: 5,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 5),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
|
||||
require.Equal(t, tc.expected.name, result.name)
|
||||
require.Equal(t, tc.expected.packetSize, result.packetSize)
|
||||
require.Equal(t, cap(result.generators), len(tc.expected.generators))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState TagJunkPacketGenerator
|
||||
mockSize int
|
||||
expectedLength int
|
||||
expectedSize int
|
||||
}{
|
||||
{
|
||||
name: "Append to empty generator",
|
||||
initialState: newTagJunkPacketGenerator("T1", "", 0),
|
||||
mockSize: 5,
|
||||
expectedLength: 1,
|
||||
expectedSize: 5,
|
||||
},
|
||||
{
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 10,
|
||||
generators: make([]Generator, 2),
|
||||
},
|
||||
mockSize: 7,
|
||||
expectedLength: 3, // 2 existing + 1 new
|
||||
expectedSize: 17, // 10 + 7
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.initialState
|
||||
mockGen := internal.NewMockGenerator(tc.mockSize)
|
||||
|
||||
tg.append(mockGen)
|
||||
|
||||
require.Equal(t, tc.expectedLength, len(tg.generators))
|
||||
require.Equal(t, tc.expectedSize, tg.packetSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorGenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock generators for testing
|
||||
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
|
||||
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupGenerator func() TagJunkPacketGenerator
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "Generate with empty generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
return newTagJunkPacketGenerator("T1", "", 0)
|
||||
},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Generate with single generator",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T2", "", 0)
|
||||
tg.append(mockGen1)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02},
|
||||
},
|
||||
{
|
||||
name: "Generate with multiple generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T3", "", 0)
|
||||
tg.append(mockGen1)
|
||||
tg.append(mockGen2)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.setupGenerator()
|
||||
result := tg.generatePacket()
|
||||
|
||||
require.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorNameIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
generatorName string
|
||||
expectedIndex int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid name with digit",
|
||||
generatorName: "T5",
|
||||
expectedIndex: 5,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too short",
|
||||
generatorName: "T",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too long",
|
||||
generatorName: "T55",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - non-digit second character",
|
||||
generatorName: "TX",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := TagJunkPacketGenerator{name: tc.generatorName}
|
||||
index, err := tg.nameIndex()
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectedIndex, index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package awg
|
||||
|
||||
import "fmt"
|
||||
|
||||
type TagJunkPacketGenerators struct {
|
||||
tagGenerators []TagJunkPacketGenerator
|
||||
length int
|
||||
DefaultJunkCount int // Jc
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||
generator TagJunkPacketGenerator,
|
||||
) {
|
||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||
generators.length++
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) IsDefined() bool {
|
||||
return len(generators.tagGenerators) > 0
|
||||
}
|
||||
|
||||
// validate that packets were defined consecutively
|
||||
func (generators *TagJunkPacketGenerators) Validate() error {
|
||||
seen := make([]bool, len(generators.tagGenerators))
|
||||
for _, generator := range generators.tagGenerators {
|
||||
index, err := generator.nameIndex()
|
||||
if index > len(generators.tagGenerators) {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("name index: %w", err)
|
||||
} else {
|
||||
seen[index-1] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, found := range seen {
|
||||
if !found {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
||||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range generators.tagGenerators {
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
PacketCounter.Inc()
|
||||
}
|
||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||
for _, generator := range tg.tagGenerators {
|
||||
rv = append(rv, generator.IpcGetFields())
|
||||
}
|
||||
|
||||
return rv
|
||||
}
|
|
@ -1,149 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generator TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "append single generator",
|
||||
generator: newTagJunkPacketGenerator("t1", "", 10),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
|
||||
// Initial length should be 0
|
||||
require.Equal(t, 0, generators.length)
|
||||
require.Empty(t, generators.tagGenerators)
|
||||
|
||||
// After append, length should be 1 and generator should be added
|
||||
generators.AppendGenerator(tt.generator)
|
||||
require.Equal(t, 1, generators.length)
|
||||
require.Len(t, generators.tagGenerators, 1)
|
||||
require.Equal(t, tt.generator, generators.tagGenerators[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generators []TagJunkPacketGenerator
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "bad start",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "non-consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t2", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
newTagJunkPacketGenerator("t5", "", 10),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nameIndex error",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("error", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "name must be 2 character long",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
for _, gen := range tt.generators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
err := generators.Validate()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errMsg)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
||||
mockByte1 := []byte{0x01, 0x02}
|
||||
mockByte2 := []byte{0x03, 0x04, 0x05}
|
||||
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
||||
mockGen2 := internal.NewMockByteGenerator(mockByte2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupGenerator func() []TagJunkPacketGenerator
|
||||
expected [][]byte
|
||||
}{
|
||||
{
|
||||
name: "generate with no default junk",
|
||||
setupGenerator: func() []TagJunkPacketGenerator {
|
||||
tg1 := newTagJunkPacketGenerator("t1", "", 0)
|
||||
tg1.append(mockGen1)
|
||||
tg1.append(mockGen2)
|
||||
tg2 := newTagJunkPacketGenerator("t2", "", 0)
|
||||
tg2.append(mockGen2)
|
||||
tg2.append(mockGen1)
|
||||
|
||||
return []TagJunkPacketGenerator{tg1, tg2}
|
||||
},
|
||||
expected: [][]byte{
|
||||
append(mockByte1, mockByte2...),
|
||||
append(mockByte2, mockByte1...),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
tagGenerators := tt.setupGenerator()
|
||||
for _, gen := range tagGenerators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
result := generators.GeneratePackets()
|
||||
require.Equal(t, result, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IpcFields struct{ Key, Value string }
|
||||
|
||||
type EnumTag string
|
||||
|
||||
const (
|
||||
BytesEnumTag EnumTag = "b"
|
||||
CounterEnumTag EnumTag = "c"
|
||||
TimestampEnumTag EnumTag = "t"
|
||||
RandomBytesEnumTag EnumTag = "r"
|
||||
WaitTimeoutEnumTag EnumTag = "wt"
|
||||
WaitResponseEnumTag EnumTag = "wr"
|
||||
)
|
||||
|
||||
var generatorCreator = map[EnumTag]newGenerator{
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||
// WaitResponseEnumTag: newWaitResponseGenerator,
|
||||
}
|
||||
|
||||
// helper map to determine enumTags are unique
|
||||
var uniqueTags = map[EnumTag]bool{
|
||||
CounterEnumTag: false,
|
||||
TimestampEnumTag: false,
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name EnumTag
|
||||
Param string
|
||||
}
|
||||
|
||||
func parseTag(input string) (Tag, error) {
|
||||
// Regular expression to match <tagname optional_param>
|
||||
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
|
||||
|
||||
match := re.FindStringSubmatch(input)
|
||||
tag := Tag{
|
||||
Name: EnumTag(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||
inputSlice := strings.Split(input, "<")
|
||||
if len(inputSlice) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
|
||||
}
|
||||
|
||||
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
|
||||
maps.Copy(uniqueTagCheck, uniqueTags)
|
||||
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"empty tag in input: %s",
|
||||
inputSlice,
|
||||
)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
||||
tag, _ := parseTag(inputParam)
|
||||
creator, ok := generatorCreator[tag.Name]
|
||||
if !ok {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
|
||||
}
|
||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||
if present {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"tag %s needs to be unique",
|
||||
tag.Name,
|
||||
)
|
||||
}
|
||||
uniqueTagCheck[tag.Name] = true
|
||||
}
|
||||
generator, err := creator(tag.Param)
|
||||
if err != nil {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
|
||||
}
|
||||
|
||||
// TODO: handle counter tag
|
||||
// if tag.Name == CounterEnumTag {
|
||||
// packetCounter, ok := generator.(*PacketCounterGenerator)
|
||||
// if !ok {
|
||||
// log.Fatalf("packet counter generator expected, got %T", generator)
|
||||
// }
|
||||
// PacketCounter = packetCounter.counter
|
||||
// }
|
||||
|
||||
rv.append(generator)
|
||||
}
|
||||
|
||||
return rv, nil
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
input string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid name",
|
||||
args: args{name: "apple", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{name: "i1", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra >",
|
||||
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra <",
|
||||
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "empty <>",
|
||||
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "invalid tag",
|
||||
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
|
||||
wantErr: fmt.Errorf("invalid tag"),
|
||||
},
|
||||
{
|
||||
name: "counter uniqueness violation",
|
||||
args: args{name: "i1", input: "<c><c>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "timestamp uniqueness violation",
|
||||
args: args{name: "i1", input: "<t><t>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Parse(tt.args.name, tt.args.input)
|
||||
|
||||
// TODO: ErrorAs doesn't work as you think
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
358
device/device.go
358
device/device.go
|
@ -1,60 +1,24 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"github.com/tevino/abool/v2"
|
||||
)
|
||||
|
||||
type Version uint8
|
||||
|
||||
const (
|
||||
VersionDefault Version = iota
|
||||
VersionAwg
|
||||
VersionAwgSpecialHandshake
|
||||
)
|
||||
|
||||
// TODO:
|
||||
type AtomicVersion struct {
|
||||
value atomic.Uint32
|
||||
}
|
||||
|
||||
func NewAtomicVersion(v Version) *AtomicVersion {
|
||||
av := &AtomicVersion{}
|
||||
av.Store(v)
|
||||
return av
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Load() Version {
|
||||
return Version(av.value.Load())
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Store(v Version) {
|
||||
av.value.Store(uint32(v))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
|
||||
return av.value.CompareAndSwap(uint32(old), uint32(new))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Swap(new Version) Version {
|
||||
return Version(av.value.Swap(uint32(new)))
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
state struct {
|
||||
// state holds the device's state. It is accessed atomically.
|
||||
|
@ -128,8 +92,22 @@ type Device struct {
|
|||
closed chan struct{}
|
||||
log *Logger
|
||||
|
||||
version Version
|
||||
awg awg.Protocol
|
||||
isASecOn abool.AtomicBool
|
||||
aSecMux sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
}
|
||||
|
||||
type aSecCfgType struct {
|
||||
isSet bool
|
||||
junkPacketCount int
|
||||
junkPacketMinSize int
|
||||
junkPacketMaxSize int
|
||||
initPacketJunkSize int
|
||||
responsePacketJunkSize int
|
||||
initPacketMagicHeader uint32
|
||||
responsePacketMagicHeader uint32
|
||||
underloadPacketMagicHeader uint32
|
||||
transportPacketMagicHeader uint32
|
||||
}
|
||||
|
||||
// deviceState represents the state of a Device.
|
||||
|
@ -569,6 +547,7 @@ func (device *Device) BindUpdate() error {
|
|||
}
|
||||
|
||||
device.log.Verbosef("UDP bind has been updated")
|
||||
device.log.Verbosef(netc.bind.GetOffloadInfo())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -578,261 +557,250 @@ func (device *Device) BindClose() error {
|
|||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
func (device *Device) isAWG() bool {
|
||||
return device.version >= VersionAwg
|
||||
func (device *Device) isAdvancedSecurityOn() bool {
|
||||
return device.isASecOn.IsSet()
|
||||
}
|
||||
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
MessageInitiationType = 1
|
||||
MessageResponseType = 2
|
||||
MessageCookieReplyType = 3
|
||||
MessageTransportType = 4
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
|
||||
return nil
|
||||
}
|
||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||
|
||||
var errs []error
|
||||
if !tempASecCfg.isSet {
|
||||
return err
|
||||
}
|
||||
|
||||
isASecOn := false
|
||||
device.awg.ASecMux.Lock()
|
||||
if tempAwg.ASecCfg.JunkPacketCount < 0 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
device.aSecMux.Lock()
|
||||
if tempASecCfg.junkPacketCount < 0 {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketCount should be non negative",
|
||||
),
|
||||
)
|
||||
}
|
||||
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
if tempAwg.ASecCfg.JunkPacketCount != 0 {
|
||||
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
||||
if tempASecCfg.junkPacketCount != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
|
||||
if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
|
||||
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
||||
if tempASecCfg.junkPacketMinSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if device.awg.ASecCfg.JunkPacketCount > 0 &&
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
|
||||
if device.aSecCfg.junkPacketCount > 0 &&
|
||||
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
||||
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
|
||||
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
|
||||
device.awg.ASecCfg.JunkPacketMinSize = 0
|
||||
device.awg.ASecCfg.JunkPacketMaxSize = 1
|
||||
errs = append(errs, ipcErrorf(
|
||||
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||
device.aSecCfg.junkPacketMinSize = 0
|
||||
device.aSecCfg.junkPacketMaxSize = 1
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
))
|
||||
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
)
|
||||
}
|
||||
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d",
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize,
|
||||
tempAwg.ASecCfg.JunkPacketMinSize,
|
||||
))
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
|
||||
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
|
||||
if tempASecCfg.junkPacketMaxSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
|
||||
|
||||
if newInitSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.InitHeaderJunkSize,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
|
||||
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
|
||||
if tempASecCfg.initPacketJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
|
||||
|
||||
if newResponseSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.ResponseHeaderJunkSize,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
|
||||
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
|
||||
if tempASecCfg.responsePacketJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
|
||||
|
||||
if newCookieSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
|
||||
|
||||
if newTransportSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.TransportHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
|
||||
if tempASecCfg.initPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
|
||||
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
|
||||
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
|
||||
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageInitiationType = 1
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
|
||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
|
||||
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
|
||||
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
|
||||
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageResponseType = 2
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
|
||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
|
||||
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageCookieReplyType = 3
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
|
||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
|
||||
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
|
||||
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
|
||||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
MessageTransportType = 4
|
||||
}
|
||||
|
||||
isSameHeaderMap := map[uint32]struct{}{
|
||||
MessageInitiationType: {},
|
||||
MessageResponseType: {},
|
||||
MessageCookieReplyType: {},
|
||||
MessageTransportType: {},
|
||||
}
|
||||
isSameMap := map[uint32]bool{}
|
||||
isSameMap[MessageInitiationType] = true
|
||||
isSameMap[MessageResponseType] = true
|
||||
isSameMap[MessageCookieReplyType] = true
|
||||
isSameMap[MessageTransportType] = true
|
||||
|
||||
// size will be different if same values
|
||||
if len(isSameHeaderMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
if len(isSameMap) != 4 {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
isSameSizeMap := map[int]struct{}{
|
||||
newInitSize: {},
|
||||
newResponseSize: {},
|
||||
newCookieSize: {},
|
||||
newTransportSize: {},
|
||||
}
|
||||
|
||||
if len(isSameSizeMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
||||
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
||||
|
||||
if newInitSize == newResponseSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
|
||||
`new init size:%d; and new response size:%d; should differ; %w`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
newCookieSize,
|
||||
newTransportSize,
|
||||
),
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
|
||||
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
|
||||
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
|
||||
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new init size:%d; and new response size:%d; should differ`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
)
|
||||
}
|
||||
|
||||
} else {
|
||||
packetSizeToMsgType = map[int]uint32{
|
||||
newInitSize: MessageInitiationType,
|
||||
newResponseSize: MessageResponseType,
|
||||
newCookieSize: MessageCookieReplyType,
|
||||
newTransportSize: MessageTransportType,
|
||||
MessageCookieReplySize: MessageCookieReplyType,
|
||||
MessageTransportSize: MessageTransportType,
|
||||
}
|
||||
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
||||
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
||||
MessageCookieReplyType: 0,
|
||||
MessageTransportType: 0,
|
||||
}
|
||||
}
|
||||
|
||||
device.awg.IsASecOn.SetTo(isASecOn)
|
||||
var err error
|
||||
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
device.isASecOn.SetTo(isASecOn)
|
||||
device.aSecMux.Unlock()
|
||||
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||
} else {
|
||||
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
|
||||
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
device.version = VersionAwgSpecialHandshake
|
||||
}
|
||||
} else {
|
||||
device.version = VersionAwg
|
||||
}
|
||||
|
||||
device.awg.ASecMux.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,28 +1,25 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
|
@ -53,7 +50,7 @@ func uapiCfg(cfg ...string) string {
|
|||
|
||||
// genConfigs generates a pair of configs that connect to each other.
|
||||
// The configs use distinct, probably-usable ports.
|
||||
func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
||||
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||
var key1, key2 NoisePrivateKey
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
|
@ -65,8 +62,7 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
args0 := append([]string(nil), cfg...)
|
||||
args0 = append(args0, []string{
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -74,16 +70,12 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.2/32",
|
||||
}...)
|
||||
cfgs[0] = uapiCfg(args0...)
|
||||
|
||||
)
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
|
||||
args1 := append([]string(nil), cfg...)
|
||||
args1 = append(args1, []string{
|
||||
cfgs[1] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -91,9 +83,66 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.1/32",
|
||||
}...)
|
||||
)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
cfgs[1] = uapiCfg(args1...)
|
||||
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||
var key1, key2 NoisePrivateKey
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
_, err = rand.Read(key2[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "501",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.2/32",
|
||||
)
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
cfgs[1] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "501",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.1/32",
|
||||
)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
|
@ -136,10 +185,9 @@ func (pair *testPair) Send(
|
|||
// pong is the new ping
|
||||
p0, p1 = p1, p0
|
||||
}
|
||||
|
||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||
p1.tun.Outbound <- msg
|
||||
timer := time.NewTimer(6 * time.Second)
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
defer timer.Stop()
|
||||
var err error
|
||||
select {
|
||||
|
@ -166,12 +214,14 @@ func (pair *testPair) Send(
|
|||
// genTestPair creates a testPair.
|
||||
func genTestPair(
|
||||
tb testing.TB,
|
||||
realSocket bool,
|
||||
extraCfg ...string,
|
||||
realSocket, withASecurity bool,
|
||||
) (pair testPair) {
|
||||
var cfg, endpointCfg [2]string
|
||||
cfg, endpointCfg = genConfigs(tb, extraCfg...)
|
||||
|
||||
if withASecurity {
|
||||
cfg, endpointCfg = genASecurityConfigs(tb)
|
||||
} else {
|
||||
cfg, endpointCfg = genConfigs(tb)
|
||||
}
|
||||
var binds [2]conn.Bind
|
||||
if realSocket {
|
||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||
|
@ -215,7 +265,7 @@ func genTestPair(
|
|||
|
||||
func TestTwoDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true)
|
||||
pair := genTestPair(t, true, false)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
})
|
||||
|
@ -224,23 +274,9 @@ func TestTwoDevicePing(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
||||
func TestAWGDevicePing(t *testing.T) {
|
||||
func TestTwoDevicePingASecurity(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
|
||||
pair := genTestPair(t, true,
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"s3", "50",
|
||||
"s4", "5",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h3", "123123",
|
||||
"h4", "32345",
|
||||
)
|
||||
pair := genTestPair(t, true, true)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
})
|
||||
|
@ -249,58 +285,13 @@ func TestAWGDevicePing(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Needs to be stopped with Ctrl-C
|
||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||
t.Skip("This test is intended to be run manually, not as part of the test suite.")
|
||||
|
||||
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
isRunning := atomic.NewBool(true)
|
||||
go func() {
|
||||
<-signalContext.Done()
|
||||
fmt.Println("Waiting to finish")
|
||||
isRunning.Store(false)
|
||||
}()
|
||||
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
"i2", "<b 0xf6ab3267fa><r 100>",
|
||||
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||
"itime", "60",
|
||||
// "jc", "1",
|
||||
// "jmin", "500",
|
||||
// "jmax", "1000",
|
||||
// "s1", "30",
|
||||
// "s2", "40",
|
||||
// "h1", "123456",
|
||||
// "h2", "67543",
|
||||
// "h4", "32345",
|
||||
// "h3", "123123",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Ping, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Pong, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpDown(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
const itrials = 50
|
||||
const otrials = 10
|
||||
|
||||
for n := 0; n < otrials; n++ {
|
||||
pair := genTestPair(t, false)
|
||||
pair := genTestPair(t, false, false)
|
||||
for i := range pair {
|
||||
for k := range pair[i].dev.peers.keyMap {
|
||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||
|
@ -334,7 +325,7 @@ func TestUpDown(t *testing.T) {
|
|||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||
// It is intended to be used with the race detector to catch data races.
|
||||
func TestConcurrencySafety(t *testing.T) {
|
||||
pair := genTestPair(t, true)
|
||||
pair := genTestPair(t, true, false)
|
||||
done := make(chan struct{})
|
||||
|
||||
const warmupIters = 10
|
||||
|
@ -415,7 +406,7 @@ func TestConcurrencySafety(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkLatency(b *testing.B) {
|
||||
pair := genTestPair(b, true)
|
||||
pair := genTestPair(b, true, false)
|
||||
|
||||
// Establish a connection.
|
||||
pair.Send(b, Ping, nil)
|
||||
|
@ -429,7 +420,7 @@ func BenchmarkLatency(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkThroughput(b *testing.B) {
|
||||
pair := genTestPair(b, true)
|
||||
pair := genTestPair(b, true, false)
|
||||
|
||||
// Establish a connection.
|
||||
pair.Send(b, Ping, nil)
|
||||
|
@ -473,7 +464,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkUAPIGet(b *testing.B) {
|
||||
pair := genTestPair(b, true)
|
||||
pair := genTestPair(b, true, false)
|
||||
pair.Send(b, Ping, nil)
|
||||
pair.Send(b, Pong, nil)
|
||||
b.ReportAllocs()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -52,18 +52,11 @@ const (
|
|||
WGLabelCookie = "cookie--"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
MessageInitiationType uint32 = 1
|
||||
MessageResponseType uint32 = 2
|
||||
MessageCookieReplyType uint32 = 3
|
||||
MessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -82,10 +75,9 @@ const (
|
|||
MessageTransportOffsetContent = 16
|
||||
)
|
||||
|
||||
var (
|
||||
packetSizeToMsgType map[int]uint32
|
||||
msgTypeToJunkSize map[uint32]int
|
||||
)
|
||||
var packetSizeToMsgType map[int]uint32
|
||||
|
||||
var msgTypeToJunkSize map[uint32]int
|
||||
|
||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||
* by marshalling the messages in little-endian byteorder
|
||||
|
@ -205,12 +197,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
msg := MessageInitiation{
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
}
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
@ -264,12 +256,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
@ -384,9 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
var msg MessageResponse
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
msg.Type = MessageResponseType
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
|
@ -436,12 +428,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
if msg.Type != MessageResponseType {
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -13,7 +13,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
|
@ -114,16 +113,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
|
||||
err := peer.SendBuffers(buffers)
|
||||
if err == nil {
|
||||
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type WaitPool struct {
|
||||
pool sync.Pool
|
||||
cond sync.Cond
|
||||
lock sync.Mutex
|
||||
count uint32 // Get calls not yet Put back
|
||||
count atomic.Uint32
|
||||
max uint32
|
||||
}
|
||||
|
||||
|
@ -26,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
|||
func (p *WaitPool) Get() any {
|
||||
if p.max != 0 {
|
||||
p.lock.Lock()
|
||||
for p.count >= p.max {
|
||||
for p.count.Load() >= p.max {
|
||||
p.cond.Wait()
|
||||
}
|
||||
p.count++
|
||||
p.count.Add(1)
|
||||
p.lock.Unlock()
|
||||
}
|
||||
return p.pool.Get()
|
||||
|
@ -40,9 +41,7 @@ func (p *WaitPool) Put(x any) {
|
|||
if p.max == 0 {
|
||||
return
|
||||
}
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
p.count--
|
||||
p.count.Add(^uint32(0))
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -32,9 +32,7 @@ func TestWaitPool(t *testing.T) {
|
|||
wg.Add(workers)
|
||||
var max atomic.Uint32
|
||||
updateMax := func() {
|
||||
p.lock.Lock()
|
||||
count := p.count
|
||||
p.lock.Unlock()
|
||||
count := p.count.Load()
|
||||
if count > p.max {
|
||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
|
@ -137,14 +137,10 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
var msgType uint32
|
||||
if device.isAWG() {
|
||||
// TODO:
|
||||
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||
// awg.WaitResponse.Channel <- struct{}{}
|
||||
// }
|
||||
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||
// transport size can align with other header types;
|
||||
|
@ -153,29 +149,19 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
if msgType == assumedMsgType {
|
||||
packet = packet[junkSize:]
|
||||
} else {
|
||||
device.log.Verbosef("transport packet lined up with another msg type")
|
||||
device.log.Verbosef("Transport packet lined up with another msg type")
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
} else {
|
||||
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
|
||||
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
if msgType != MessageTransportType {
|
||||
// probably a junk packet
|
||||
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
|
||||
device.log.Verbosef("ASec: Received message with unknown type")
|
||||
continue
|
||||
}
|
||||
|
||||
// remove junk from bufsArrs by shifting the packet
|
||||
// this buffer is also used for decryption, so it needs to be corrected
|
||||
copy(bufsArrs[i][:size], packet[transportJunkSize:])
|
||||
size -= transportJunkSize
|
||||
// need to reinitialize packet as well
|
||||
packet = packet[:size]
|
||||
}
|
||||
} else {
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
@ -259,7 +245,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
default:
|
||||
}
|
||||
}
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
|
@ -318,7 +304,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.awg.ASecMux.RLock()
|
||||
device.aSecMux.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
|
@ -470,7 +456,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.aSecMux.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
|
105
device/send.go
105
device/send.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -9,6 +9,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
@ -124,30 +125,12 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
return err
|
||||
}
|
||||
var sendBuffer [][]byte
|
||||
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.version >= VersionAwg {
|
||||
var junks [][]byte
|
||||
if peer.device.version == VersionAwgSpecialHandshake {
|
||||
peer.device.awg.ASecMux.RLock()
|
||||
// set junks depending on packet type
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||
if junks == nil {
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
||||
if junks != nil {
|
||||
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
|
||||
}
|
||||
} else {
|
||||
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||
}
|
||||
peer.device.awg.ASecMux.RUnlock()
|
||||
} else {
|
||||
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
|
||||
}
|
||||
peer.device.awg.ASecMux.RLock()
|
||||
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
|
||||
peer.device.awg.ASecMux.RUnlock()
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
junks, err := peer.createJunkPackets()
|
||||
peer.device.aSecMux.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
|
@ -163,11 +146,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
peer.device.aSecMux.RUnlock()
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
|
@ -182,7 +173,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
|
||||
sendBuffer = append(sendBuffer, junkedHeader)
|
||||
|
||||
err = peer.SendAndCountBuffers(sendBuffer)
|
||||
err = peer.SendBuffers(sendBuffer)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
|
@ -203,13 +194,22 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
if err != nil {
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
}
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
|
@ -229,7 +229,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
|
||||
err = peer.SendBuffers([][]byte{junkedHeader})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
|
@ -252,19 +252,11 @@ func (device *Device) SendHandshakeCookie(
|
|||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
|
||||
junkedHeader = append(junkedHeader, writer.Bytes()...)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
|
||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -477,6 +469,31 @@ top:
|
|||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) createJunkPackets() ([][]byte, error) {
|
||||
if peer.device.aSecCfg.junkPacketCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
|
||||
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
|
||||
packetSize := rand.Intn(
|
||||
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
|
||||
) + peer.device.aSecCfg.junkPacketMinSize
|
||||
|
||||
junk, err := randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf(
|
||||
"%v - Failed to create junk packet: %v",
|
||||
peer,
|
||||
err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
junks = append(junks, junk)
|
||||
}
|
||||
return junks, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
|
@ -577,7 +594,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
continue
|
||||
}
|
||||
dataSent := false
|
||||
|
@ -585,14 +601,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
for _, elem := range elemsContainer.elems {
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
|
||||
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = append(junkedHeader, elem.packet...)
|
||||
}
|
||||
bufs = append(bufs, elem.packet)
|
||||
}
|
||||
|
@ -600,11 +608,10 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err := peer.SendAndCountBuffers(bufs)
|
||||
err := peer.SendBuffers(bufs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
|
|
|
@ -7,6 +7,6 @@ import (
|
|||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
|
@ -9,7 +9,7 @@
|
|||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code remains platform dependent.
|
||||
* So this code is remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
|
|||
return netlinkCancel, nil
|
||||
}
|
||||
|
||||
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *conn.Endpoint
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
224
device/uapi.go
224
device/uapi.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -18,7 +18,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
)
|
||||
|
||||
|
@ -98,51 +97,33 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.isAWG() {
|
||||
if device.awg.ASecCfg.JunkPacketCount != 0 {
|
||||
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if device.aSecCfg.junkPacketCount != 0 {
|
||||
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
||||
}
|
||||
if device.awg.ASecCfg.JunkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
|
||||
if device.aSecCfg.junkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
||||
}
|
||||
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
|
||||
if device.aSecCfg.junkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
||||
}
|
||||
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
|
||||
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
|
||||
if device.aSecCfg.initPacketJunkSize != 0 {
|
||||
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
||||
}
|
||||
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
|
||||
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
|
||||
if device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
||||
}
|
||||
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
|
||||
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
|
||||
if device.aSecCfg.initPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
|
||||
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
|
||||
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
|
||||
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||
}
|
||||
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
|
||||
for _, field := range controlledJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
if device.awg.HandshakeHandler.ITimeout != 0 {
|
||||
sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second)
|
||||
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -199,13 +180,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
tempAwg := awg.Protocol{}
|
||||
tempASecCfg := aSecCfgType{}
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
err := device.handlePostConfig(&tempAwg)
|
||||
err := device.handlePostConfig(&tempASecCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -236,7 +217,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value, &tempAwg)
|
||||
err = device.handleDeviceLine(key, value, &tempASecCfg)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
|
@ -244,7 +225,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
err = device.handlePostConfig(&tempAwg)
|
||||
err = device.handlePostConfig(&tempASecCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -256,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
|
||||
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
|
@ -297,11 +278,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
|
||||
case "replace_peers":
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set replace_peers, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
@ -309,138 +286,80 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
case "jc":
|
||||
junkPacketCount, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.junkPacketCount = junkPacketCount
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "jmin":
|
||||
junkPacketMinSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "jmax":
|
||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "s1":
|
||||
initPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "s2":
|
||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s3":
|
||||
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
|
||||
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s4":
|
||||
transportPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
|
||||
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "h1":
|
||||
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
||||
}
|
||||
tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "h2":
|
||||
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
||||
}
|
||||
tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "h3":
|
||||
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
||||
}
|
||||
tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
case "h4":
|
||||
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
|
||||
}
|
||||
tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
case "i1", "i2", "i3", "i4", "i5":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
generators, err := awg.Parse(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
case "j1", "j2", "j3":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
generators, err := awg.Parse(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
|
||||
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
case "itime":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty itime")
|
||||
return nil
|
||||
}
|
||||
|
||||
itime, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating itime")
|
||||
|
||||
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
}
|
||||
|
@ -513,11 +432,7 @@ func (device *Device) handlePeerLine(
|
|||
case "update_only":
|
||||
// allow disabling of creation
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set update only, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||
}
|
||||
if peer.created && !peer.dummy {
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
|
@ -563,11 +478,7 @@ func (device *Device) handlePeerLine(
|
|||
|
||||
secs, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set persistent keepalive interval: %w",
|
||||
err,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||
}
|
||||
|
||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||
|
@ -578,11 +489,7 @@ func (device *Device) handlePeerLine(
|
|||
case "replace_allowed_ips":
|
||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to replace allowedips, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||
}
|
||||
if peer.dummy {
|
||||
return nil
|
||||
|
@ -590,14 +497,7 @@ func (device *Device) handlePeerLine(
|
|||
device.allowedips.RemoveByPeer(peer.Peer)
|
||||
|
||||
case "allowed_ip":
|
||||
add := true
|
||||
verb := "Adding"
|
||||
if len(value) > 0 && value[0] == '-' {
|
||||
add = false
|
||||
verb = "Removing"
|
||||
value = value[1:]
|
||||
}
|
||||
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
|
||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
|
@ -605,11 +505,7 @@ func (device *Device) handlePeerLine(
|
|||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
if add {
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
} else {
|
||||
device.allowedips.Remove(prefix, peer.Peer)
|
||||
}
|
||||
|
||||
case "protocol_version":
|
||||
if value != "1" {
|
||||
|
@ -661,11 +557,7 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||
return
|
||||
}
|
||||
if nextByte != '\n' {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"trailing character in UAPI get: %q",
|
||||
nextByte,
|
||||
)
|
||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
||||
break
|
||||
}
|
||||
err = device.IpcGetOperation(buffered.Writer)
|
||||
|
|
25
device/util.go
Normal file
25
device/util.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func appendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk, err := randomJunkWithSize(size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create header junk: %v", err)
|
||||
}
|
||||
_, err = writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomJunkWithSize(size int) ([]byte, error) {
|
||||
junk := make([]byte, size)
|
||||
_, err := crand.Read(junk)
|
||||
return junk, err
|
||||
}
|
27
device/util_test.go
Normal file
27
device/util_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_randomJunktWithSize(t *testing.T) {
|
||||
junk, err := randomJunkWithSize(30)
|
||||
fmt.Println(string(junk), len(junk), err)
|
||||
}
|
||||
|
||||
func Test_appendJunk(t *testing.T) {
|
||||
t.Run("", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := appendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Errorf("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
package main
|
||||
|
||||
|
|
20
go.mod
20
go.mod
|
@ -1,23 +1,17 @@
|
|||
module github.com/amnezia-vpn/amneziawg-go
|
||||
|
||||
go 1.24.4
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tevino/abool v1.2.0
|
||||
github.com/tevino/abool/v2 v2.1.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/net v0.41.0
|
||||
golang.org/x/sys v0.33.0
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/net v0.21.0
|
||||
golang.org/x/sys v0.17.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/time v0.9.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
)
|
||||
|
|
48
go.sum
48
go.sum
|
@ -1,40 +1,16 @@
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
|
||||
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
|
||||
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
|
||||
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
|
||||
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
2
main.go
2
main.go
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package replay
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package rwcancel implements cancelable read/write operations on
|
||||
|
@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool {
|
|||
|
||||
func (rw *RWCancel) ReadyWrite() bool {
|
||||
closeFd := int32(rw.closingReader.Fd())
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}}
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(pollFds, -1)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
|
122
tun/checksum.go
122
tun/checksum.go
|
@ -1,86 +1,102 @@
|
|||
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
import "encoding/binary"
|
||||
|
||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
|
||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||
tmp := make([]byte, 8)
|
||||
binary.NativeEndian.PutUint64(tmp, initial)
|
||||
ac := binary.BigEndian.Uint64(tmp)
|
||||
var carry uint64
|
||||
ac := initial
|
||||
|
||||
for len(b) >= 128 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
|
||||
b = b[128:]
|
||||
}
|
||||
if len(b) >= 64 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint32(b))
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
|
||||
ac += carry
|
||||
ac += uint64(binary.BigEndian.Uint16(b))
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
|
||||
ac, carry = bits.Add64(ac, uint64(tmp), 0)
|
||||
ac += carry
|
||||
ac += uint64(b[0]) << 8
|
||||
}
|
||||
|
||||
binary.NativeEndian.PutUint64(tmp, ac)
|
||||
return binary.BigEndian.Uint64(tmp)
|
||||
return ac
|
||||
}
|
||||
|
||||
func checksum(b []byte, initial uint64) uint16 {
|
||||
|
|
|
@ -1,74 +1,11 @@
|
|||
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func checksumRef(b []byte, initial uint16) uint16 {
|
||||
ac := uint64(initial)
|
||||
|
||||
for len(b) >= 2 {
|
||||
ac += uint64(binary.BigEndian.Uint16(b))
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
ac += uint64(b[0]) << 8
|
||||
}
|
||||
|
||||
for (ac >> 16) > 0 {
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
}
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||
sum := checksumRef(srcAddr, 0)
|
||||
sum = checksumRef(dstAddr, sum)
|
||||
sum = checksumRef([]byte{0, protocol}, sum)
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
return checksumRef(tmp, sum)
|
||||
}
|
||||
|
||||
func TestChecksum(t *testing.T) {
|
||||
for length := 0; length <= 9001; length++ {
|
||||
buf := make([]byte, length)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rng.Read(buf)
|
||||
csum := checksum(buf, 0x1234)
|
||||
csumRef := checksumRef(buf, 0x1234)
|
||||
if csum != csumRef {
|
||||
t.Error("Expected checksum", csumRef, "got", csum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPseudoHeaderChecksum(t *testing.T) {
|
||||
for _, addrLen := range []int{4, 16} {
|
||||
for length := 0; length <= 9001; length++ {
|
||||
srcAddr := make([]byte, addrLen)
|
||||
dstAddr := make([]byte, addrLen)
|
||||
buf := make([]byte, length)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rng.Read(srcAddr)
|
||||
rng.Read(dstAddr)
|
||||
rng.Read(buf)
|
||||
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
|
||||
csum := checksum(buf, phSum)
|
||||
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
|
||||
csumRef := checksumRef(buf, phSumRef)
|
||||
if csum != csumRef {
|
||||
t.Error("Expected checksumRef", csumRef, "got", csum)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkChecksum(b *testing.B) {
|
||||
lengths := []int{
|
||||
64,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue