Compare commits


83 commits

Author SHA1 Message Date
pokamest
1abd24b5b9
Merge pull request #85 from amnezia-vpn/hotfix/docker-script
fix: restore Dockerfile changes
2025-07-07 16:20:58 +03:00
Yaroslav Gurov
3f19f1c657 fix: restore Dockerfile 2025-07-07 15:15:29 +02:00
Mykola Baibuz
c207898480
AmneziaWG v1.5 (#84) 2025-07-07 13:34:51 +01:00
pokamest
fe75b639fa
Merge pull request #78 from jmwample/jmwample/upstream
Sync with Major Upstream changes
2025-07-02 03:01:39 +01:00
jmwample
169ed49a46
fix formatting discrepancy 2025-06-23 14:56:43 -06:00
Jason A. Donenfeld
eeb8aae13e
version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:56:43 -06:00
Jason A. Donenfeld
99f2e6d66f
conn: don't enable GRO on Linux < 5.12
Kernels below 5.12 are missing this:

    commit 98184612aca0a9ee42b8eb0262a49900ee9eef0d
    Author: Norman Maurer <norman_maurer@apple.com>
    Date:   Thu Apr 1 08:59:17 2021

        net: udp: Add support for getsockopt(..., ..., UDP_GRO, ..., ...);

        Support for UDP_GRO was added in the past but the implementation for
        getsockopt was missed which did lead to an error when we tried to
        retrieve the setting for UDP_GRO. This patch adds the missing switch
        case for UDP_GRO

        Fixes: e20cf8d3f1f7 ("udp: implement GRO for plain UDP sockets.")
        Signed-off-by: Norman Maurer <norman_maurer@apple.com>
        Reviewed-by: David Ahern <dsahern@kernel.org>
        Signed-off-by: David S. Miller <davem@davemloft.net>

That means we can't set the option and then read it back later. Given
how buggy UDP_GRO is in general on odd kernels, just disable it on older
kernels altogether.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:56:43 -06:00
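The kernel-version gate this commit describes can be sketched as a small helper. `kernelSupportsUDPGRO` and its release-string parsing are illustrative assumptions for this note, not the repository's actual code:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// kernelSupportsUDPGRO reports whether a Linux kernel release string
// (e.g. "5.11.0-27-generic") is at least 5.12, the first series where
// getsockopt(..., UDP_GRO, ...) works. Hypothetical helper name.
func kernelSupportsUDPGRO(release string) bool {
	fields := strings.SplitN(release, ".", 3)
	if len(fields) < 2 {
		return false
	}
	major, err := strconv.Atoi(fields[0])
	if err != nil {
		return false
	}
	minorStr := fields[1]
	// Strip any non-numeric suffix such as "-rc1".
	if i := strings.IndexFunc(minorStr, func(r rune) bool { return r < '0' || r > '9' }); i >= 0 {
		minorStr = minorStr[:i]
	}
	minor, err := strconv.Atoi(minorStr)
	if err != nil {
		return false
	}
	return major > 5 || (major == 5 && minor >= 12)
}

func main() {
	fmt.Println(kernelSupportsUDPGRO("5.11.0-27-generic")) // false
	fmt.Println(kernelSupportsUDPGRO("6.2.0"))             // true
}
```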
Jason A. Donenfeld
d5359f52f0
device: add support for removing allowedips individually
This pairs with the recent change in wireguard-tools.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:56:43 -06:00
Jason A. Donenfeld
6768090667
version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:56:43 -06:00
Kurnia D Win
2cad62c40b
rwcancel: fix wrong poll event flag on ReadyWrite
It should be POLLIN because closeFd is a read-only file.

Signed-off-by: Kurnia D Win <kurnia.d.win@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:27:20 -06:00
Tom Holford
8051f17147
device: use rand.NewSource instead of rand.Seed
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:27:06 -06:00
Tom Holford
ace3e11ef2
global: replaced unused function params with _
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:26:55 -06:00
ruokeqx
8a2b2bf4f4
tun: darwin: fetch flags and mtu from if_msghdr directly
Signed-off-by: ruokeqx <ruokeqx@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:26:37 -06:00
Tu Dinh Ngoc
75d6c67a67
tun: use add-with-carry in checksumNoFold()
Use parallel summation with native byte order per RFC 1071.
An add-with-carry operation is used to add 4 words per operation. A byteswap
is performed before and after checksumming for compatibility with the old
`checksumNoFold()`.  With this we get a 30-80% speedup in `checksum()`,
depending on packet size.

Add unit tests with comparison to a per-word implementation.

**Intel(R) Xeon(R) Silver 4210R CPU @ 2.40GHz**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 12.64   | 9.183   | 1.376456 |
| 128  | 18.52   | 12.72   | 1.455975 |
| 256  | 31.01   | 18.13   | 1.710425 |
| 512  | 54.46   | 29.03   | 1.87599  |
| 1024 | 102     | 52.2    | 1.954023 |
| 1500 | 146.8   | 81.36   | 1.804326 |
| 2048 | 196.9   | 102.5   | 1.920976 |
| 4096 | 389.8   | 200.8   | 1.941235 |
| 8192 | 767.3   | 413.3   | 1.856521 |
| 9000 | 851.7   | 448.8   | 1.897727 |
| 9001 | 854.8   | 451.9   | 1.891569 |

**AMD EPYC 7352 24-Core Processor**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 9.159   | 6.949   | 1.318031 |
| 128  | 13.59   | 10.59   | 1.283286 |
| 256  | 22.37   | 14.91   | 1.500335 |
| 512  | 41.42   | 24.22   | 1.710157 |
| 1024 | 81.59   | 45.05   | 1.811099 |
| 1500 | 120.4   | 68.35   | 1.761522 |
| 2048 | 162.8   | 90.14   | 1.806079 |
| 4096 | 321.4   | 180.3   | 1.782585 |
| 8192 | 650.4   | 360.8   | 1.802661 |
| 9000 | 706.3   | 398.1   | 1.774177 |
| 9001 | 712.4   | 398.2   | 1.789051 |

Signed-off-by: Tu Dinh Ngoc <dinhngoc.tu@irit.fr>
[Jason: simplified and cleaned up unit tests]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:26:25 -06:00
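For reference, a plain per-word RFC 1071 one's-complement sum (the kind of baseline the commit's unit tests compare against) looks roughly like this; `checksumFold` is a hypothetical name, not the repository's API:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// checksumFold computes the RFC 1071 Internet checksum: sum 16-bit
// words into a wide accumulator, fold the carries back in, complement.
func checksumFold(b []byte) uint16 {
	var sum uint64
	for len(b) >= 2 {
		sum += uint64(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint64(b[0]) << 8 // pad the trailing odd byte
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + sum>>16 // fold carries into low 16 bits
	}
	return ^uint16(sum)
}

func main() {
	fmt.Printf("%04x\n", checksumFold([]byte{0x00, 0x01, 0xf2, 0x03})) // 0dfb
}
```

The optimized version in this commit gets its speedup by summing several words per add-with-carry instead of one word per iteration, which this baseline makes easy to verify against.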
Jason A. Donenfeld
ac8a885a03
tun/netstack: cleanup network stack at closing time
Colin's commit went a step further and protected tun.incomingPacket with
a lock on shutdown, but let's see if the tun.stack.Close() call actually
solves that on its own.

Suggested-by: kshangx <hikeshang@hotmail.com>
Suggested-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:26:07 -06:00
Jason A. Donenfeld
6a7c878409
tun/netstack: remove usage of pkt.IsNil()
Since 3c75945fd ("netstack: remove PacketBuffer.IsNil()") this has been
invalid. Follow the replacement pattern of that commit.

The old definition inlined to the same code anyway:

 func (pk *PacketBuffer) IsNil() bool {
 	return pk == nil
 }

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:25:57 -06:00
Jason A. Donenfeld
704d57c27a
mod: bump deps
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:25:36 -06:00
Jason A. Donenfeld
c0b6e6a200
global: bump copyright notice
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:20:32 -06:00
Jordan Whited
c803ce1e5b
device: fix missed return of QueueOutboundElementsContainer to its WaitPool
Fixes: 3bb8fec ("conn, device, tun: implement vectorized I/O plumbing")
Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:18:41 -06:00
Jordan Whited
deedce495a
device: fix WaitPool sync.Cond usage
The sync.Locker used with a sync.Cond must be acquired when changing
the associated condition, otherwise there is a window within
sync.Cond.Wait() where a wake-up may be missed.

Fixes: 4846070 ("device: use a waiting sync.Pool instead of a channel")
Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2025-06-23 14:18:28 -06:00
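The rule this fix enforces — mutate the condition only while holding the sync.Cond's locker — can be sketched with a minimal pool (illustrative, not the repository's WaitPool):

```go
package main

import (
	"fmt"
	"sync"
)

// pool demonstrates correct sync.Cond usage: count (the condition)
// changes only while mu (the Cond's locker) is held, so a waiter
// inside Wait() cannot miss a wake-up.
type pool struct {
	mu    sync.Mutex
	cond  *sync.Cond
	count int
}

func newPool() *pool {
	p := &pool{}
	p.cond = sync.NewCond(&p.mu)
	return p
}

func (p *pool) put() {
	p.mu.Lock() // acquire the locker BEFORE changing the condition
	p.count++
	p.mu.Unlock()
	p.cond.Signal()
}

func (p *pool) get() {
	p.mu.Lock()
	for p.count == 0 {
		p.cond.Wait() // atomically unlocks mu and parks
	}
	p.count--
	p.mu.Unlock()
}

func main() {
	p := newPool()
	done := make(chan struct{})
	go func() { p.get(); close(done) }()
	p.put()
	<-done
	fmt.Println("no missed wake-up")
}
```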
pokamest
27e661d68e
Merge pull request #70 from marko1777/junk-improvements
Junk improvements
2025-04-07 15:31:41 +01:00
Mark Puha
71be0eb3a6 faster and more secure junk creation 2025-03-18 08:34:23 +01:00
pokamest
e3f1273f8a
Merge pull request #64 from drkivi/master
Patch for golang crypto and net submodules
2025-02-18 11:50:35 +00:00
drkivi
c97b5b7615
Update go.sum
Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:58 +03:30
drkivi
668ddfd455
Update go.mod
Submodules Version Up

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:17 +03:30
drkivi
b8da08c106
Update Dockerfile
golang -> 1.23.6
AWGTOOLS_RELEASE -> 1.0.20241018

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:43:02 +03:30
Iurii Egorov
2e3f7d122c Update Go version in Dockerfile 2024-07-01 13:47:44 +03:00
Iurii Egorov
2e7780471a
Remove GetOffloadInfo() (#32)
* Remove GetOffloadInfo()
* Remove GetOffloadInfo() from bind_windows as well
* Allow lightweight tags to be used in the version
2024-05-24 16:18:23 +01:00
albexk
87d8c00f86 Up go to 1.22.3, up crypto to 0.21.0 2024-05-21 08:09:58 -07:00
albexk
c00bda9200 Fix output of the version command 2024-05-14 03:51:01 -07:00
albexk
d2b0fc9789 Add resetting of message types when closing the device 2024-05-14 03:51:01 -07:00
albexk
77d39ff3b9 Minor naming changes 2024-05-14 03:51:01 -07:00
albexk
e433d13df6 Add disabling UDP GSO when an error occurs due to inconsistent peer mtu 2024-05-14 03:51:01 -07:00
RomikB
3ddf952973 unsafe rebranding: change pipe name 2024-05-13 11:10:42 -07:00
albexk
3f0a3bcfa0 Fix wg reconnection problem after awg connection 2024-03-16 14:16:13 +00:00
AlexanderGalkov
4dddf62e57 Update Dockerfile
add wg and wg-quick symlinks

Signed-off-by: AlexanderGalkov <143902290+AlexanderGalkov@users.noreply.github.com>
2024-02-20 20:32:38 +07:00
tiaga
827ec6e14b
Merge pull request #21 from amnezia-vpn/fix-dockerfile
Fix Dockerfile
2024-02-13 21:47:55 +07:00
tiaga
92e28a0d14 Fix Dockerfile
Fix AmneziaWG tools installation.
2024-02-13 21:44:41 +07:00
tiaga
52fed4d362
Merge pull request #20 from amnezia-vpn/update_dockerfile
Update Dockerfile
2024-02-13 21:28:17 +07:00
tiaga
9c6b3ff332 Update Dockerfile
- rename `wg` and `wg-quick` to `awg` and `awg-quick` accordingly
- add iptables
- update AmneziaWG tools version
2024-02-13 21:27:34 +07:00
pokamest
7de7a9a754
Merge pull request #19 from amnezia-vpn/fix/go_sum
Fix go.sum
2024-02-12 05:31:57 -08:00
albexk
0c347529b8 Fix go.sum 2024-02-12 16:27:56 +03:00
albexk
6705978fc8 Add debug udp offload info 2024-02-12 04:40:23 -08:00
albexk
032e33f577 Fix Android UDP GRO check 2024-02-12 04:40:23 -08:00
albexk
59101fd202 Bump crypto, net, sys modules to the latest versions 2024-02-12 04:40:23 -08:00
tiaga
8bcfbac230
Merge pull request #17 from amnezia-vpn/fix-pipeline
Fix pipeline
2024-02-07 19:19:11 +07:00
tiaga
f0dfb5eacc Fix pipeline
Fix path to GitHub Actions workflow.
2024-02-07 18:53:55 +07:00
tiaga
9195025d8f
Merge pull request #16 from amnezia-vpn/pipeline
Add pipeline
2024-02-07 18:48:12 +07:00
tiaga
cbd414dfec Add pipeline
Build and push Docker image on a tag push.
2024-02-07 18:44:59 +07:00
tiaga
7155d20913
Merge pull request #14 from amnezia-vpn/docker
Update Dockerfile
2024-02-02 23:40:39 +07:00
tiaga
bfeb3954f6 Update Dockerfile
- update Alpine version
- improve `Dockerfile` to use pre-built AmneziaWG tools
2024-02-02 22:56:00 +07:00
Iurii Egorov
e3c9ec8012 Naming unify 2024-01-20 23:18:10 +03:00
pokamest
ce9d3866a3
Merge pull request #10 from amnezia-vpn/upstream-merge
Merge upstream changes
2024-01-14 13:37:00 -05:00
Iurii Egorov
e5f355e843 Fix incorrect configuration handling for zero-valued Jc 2024-01-14 18:22:02 +03:00
Iurii Egorov
c05b2ee2a3 Merge remote-tracking branch 'upstream/master' into upstream-merge 2024-01-09 21:34:30 +03:00
tiaga
180c9284f3
Merge pull request #7 from amnezia-vpn/docker
Add Dockerfile
2023-12-19 18:50:48 +07:00
tiaga
015e11875d Add Dockerfile
Build Docker image with the corresponding wg-tools version.
2023-12-19 18:46:04 +07:00
Martin Basovnik
12269c2761 device: fix possible deadlock in close method
There is a possible deadlock in `device.Close()` when you try to close
the device very soon after its start. The problem is that two different
methods acquire the same locks in different order:

1. device.Close()
 - device.ipcMutex.Lock()
 - device.state.Lock()

2. device.changeState(deviceState)
 - device.state.Lock()
 - device.ipcMutex.Lock()

Reproducer:

    func TestDevice_deadlock(t *testing.T) {
    	d := randDevice(t)
    	d.Close()
    }

Problem:

    $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex
    sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6
    golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c
    testing.tRunner(0xc00014c000, 0x131d7b0)
    --
    sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    sync.(*RWMutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45
    golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72
    golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1)

Signed-off-by: Martin Basovnik <martin.basovnik@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:38:47 +01:00
Jason A. Donenfeld
542e565baa device: do atomic 64-bit add outside of vector loop
Only bother updating the rxBytes counter once we've processed a whole
vector, since additions are atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:35:57 +01:00
Jordan Whited
7c20311b3d device: reduce redundant per-packet overhead in RX path
Peer.RoutineSequentialReceiver() deals with packet vectors and does not
need to perform timer and endpoint operations for every packet in a
given vector. Changing these per-packet operations to per-vector
improves throughput by as much as 10% in some environments.

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
4ffa9c2032 device: change Peer.endpoint locking to reduce contention
Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
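A rough sketch of the narrower locking scheme, under simplified assumptions (`endpoint.val` stands in for `conn.Endpoint`; none of this is the repository's code):

```go
package main

import (
	"fmt"
	"sync"
)

// endpoint carries its own Mutex plus a clearSrcOnTx flag, so callers
// mark the source for clearing rather than clearing it directly.
type endpoint struct {
	sync.Mutex
	val          string // stands in for conn.Endpoint
	clearSrcOnTx bool
}

type peer struct{ endpoint endpoint }

// markEndpointSrcForClearing defers the ClearSrc() to the next send.
func (p *peer) markEndpointSrcForClearing() {
	p.endpoint.Lock()
	p.endpoint.clearSrcOnTx = true
	p.endpoint.Unlock()
}

// sendBuffers is the sole place where the source is actually cleared.
func (p *peer) sendBuffers() string {
	p.endpoint.Lock()
	defer p.endpoint.Unlock()
	if p.endpoint.clearSrcOnTx {
		p.endpoint.val = "" // the ClearSrc() of the commit message
		p.endpoint.clearSrcOnTx = false
	}
	return p.endpoint.val
}

func main() {
	p := &peer{endpoint: endpoint{val: "src"}}
	p.markEndpointSrcForClearing()
	fmt.Printf("%q\n", p.sendBuffers()) // ""
}
```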
Jordan Whited
d0bc03c707 tun: implement UDP GSO/GRO for Linux
Implement UDP GSO and GRO for the Linux tun.Device, which is made
possible by virtio extensions in the kernel's TUN driver starting in
v6.2.

secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is
used to demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly ~13us of round trip
latency between them. secnetperf was invoked with the following command
line options:
-stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0

The first result is from commit 2e0774f without UDP GSO/GRO on the TUN.

[conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us
SendTotalPackets=55859 SendSuspectedLostPackets=61
SendSpuriousLostPackets=59 SendCongestionCount=27
SendEcnCongestionCount=0 RecvTotalPackets=2779122
RecvReorderedPackets=0 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms).

The second result is with UDP GSO/GRO on the TUN.

[conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us
SendTotalPackets=165033 SendSuspectedLostPackets=64
SendSpuriousLostPackets=61 SendCongestionCount=53
SendEcnCongestionCount=0 RecvTotalPackets=11845268
RecvReorderedPackets=25267 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:27:22 +01:00
Jordan Whited
1cf89f5339 tun: fix Device.Read() buf length assumption on Windows
The length of a packet read from the underlying TUN device may exceed
the length of a supplied buffer when MTU exceeds device.MaxMessageSize.

Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:20:49 +01:00
pokamest
b43118018e
Merge pull request #4 from amnezia-vpn/upstream-merge
Upstream merge
2023-11-30 10:49:31 -08:00
Iurii Egorov
7af55a3e6f Merge remote-tracking branch 'upstream/master' 2023-11-17 22:42:13 +03:00
pokamest
c493b95f66
Update README.md
Signed-off-by: pokamest <pokamest@gmail.com>
2023-10-25 22:41:33 +01:00
Jason A. Donenfeld
2e0774f246 device: ratchet up max segment size on android
GRO requires big allocations to be efficient. This isn't great, as there
might be Android memory usage issues. So we should revisit this commit.
But at least it gets things working again.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-22 02:12:13 +02:00
Jason A. Donenfeld
b3df23dcd4 conn: set unused OOB to zero length
Otherwise in the event that we're using GSO without sticky sockets, we
pass garbage OOB buffers to sendmmsg, making a EINVAL, when GSO doesn't
set its header.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:32:07 +02:00
Jason A. Donenfeld
f502ec3fad conn: fix cmsg data padding calculation for gso
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:06:38 +02:00
Jason A. Donenfeld
5d37bd24e1 conn: separate gso and sticky control
Android wants GSO but not sticky.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 18:44:01 +02:00
Jason A. Donenfeld
24ea13351e conn: harmonize GOOS checks between "linux" and "android"
Otherwise GRO gets enabled on Android, but the conn doesn't use it,
resulting in bundled packets being discarded.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:14:13 +02:00
Jason A. Donenfeld
177caa7e44 conn: simplify supportsUDPOffload
This allows a kernel to support UDP_GRO while not supporting
UDP_SEGMENT.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:02:52 +02:00
Mazay B
b81ca925db peer.device.aSecMux.RLock added 2023-10-14 11:42:30 +01:00
James Tucker
42ec952ead go.mod,tun/netstack: bump gvisor
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
James Tucker
ec8f6f82c2 tun: fix crash when ForceMTU is called after close
Close closes the events channel, resulting in a panic from send on
closed channel.

Reported-By: Brad Fitzpatrick <brad@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
Jordan Whited
1ec454f253 device: move Queue{In,Out}boundElement Mutex to container type
Queue{In,Out}boundElement locking can contribute to significant
overhead via sync.Mutex.lockSlow() in some environments. These types
are passed throughout the device package as elements in a slice, so
move the per-element Mutex to a container around the slice.

Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
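The container pattern can be sketched like this; the type and method names are illustrative, not the repository's exact Queue{In,Out}boundElement types:

```go
package main

import (
	"fmt"
	"sync"
)

// element no longer embeds its own Mutex; one Mutex lives in a
// container around the slice, so a whole vector of elements is locked
// with a single acquisition instead of one per element.
type element struct {
	packet []byte
}

type elementsContainer struct {
	sync.Mutex
	elems []*element
}

// add appends a batch of elements under one lock acquisition.
func (c *elementsContainer) add(batch ...*element) {
	c.Lock()
	defer c.Unlock()
	c.elems = append(c.elems, batch...)
}

func main() {
	c := &elementsContainer{}
	c.add(&element{packet: []byte("a")}, &element{packet: []byte("b")})
	fmt.Println(len(c.elems)) // 2
}
```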
Jordan Whited
8a015f7c76 tun: reduce redundant checksumming in tcpGRO()
IPv4 header and pseudo header checksums were being computed on every
merge operation. Additionally, virtioNetHdr was being written at the
same time. This delays those operations until after all coalescing has
occurred.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
895d6c23cd tun: unwind summing loop in checksumNoFold()
$ benchstat old.txt new.txt
goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/tun
cpu: 12th Gen Intel(R) Core(TM) i5-12400
                 │   old.txt    │               new.txt               │
                 │    sec/op    │   sec/op     vs base                │
Checksum/64-12     10.670n ± 2%   4.769n ± 0%  -55.30% (p=0.000 n=10)
Checksum/128-12    19.665n ± 2%   8.032n ± 0%  -59.16% (p=0.000 n=10)
Checksum/256-12     37.68n ± 1%   16.06n ± 0%  -57.37% (p=0.000 n=10)
Checksum/512-12     76.61n ± 3%   32.13n ± 0%  -58.06% (p=0.000 n=10)
Checksum/1024-12   160.55n ± 4%   64.25n ± 0%  -59.98% (p=0.000 n=10)
Checksum/1500-12   231.05n ± 7%   94.12n ± 0%  -59.26% (p=0.000 n=10)
Checksum/2048-12    309.5n ± 3%   128.5n ± 0%  -58.48% (p=0.000 n=10)
Checksum/4096-12    603.8n ± 4%   257.2n ± 0%  -57.41% (p=0.000 n=10)
Checksum/8192-12   1185.0n ± 3%   515.5n ± 0%  -56.50% (p=0.000 n=10)
Checksum/9000-12   1328.5n ± 5%   564.8n ± 0%  -57.49% (p=0.000 n=10)
Checksum/9001-12   1340.5n ± 3%   564.8n ± 0%  -57.87% (p=0.000 n=10)
geomean             185.3n        77.99n       -57.92%

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
4201e08f1d device: distribute crypto work as slice of elements
After reducing UDP stack traversal overhead via GSO and GRO,
runtime.chanrecv() began to account for a high percentage (20% in one
environment) of perf samples during a throughput benchmark. The
individual packet channel ops with the crypto goroutines was the primary
contributor to this overhead.

Updating these channels to pass vectors, which the device package
already handles at its ends, reduced this overhead substantially, and
improved throughput.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is with UDP GSO and GRO, and with single element
channels.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

The second result is with channels updated to pass a slice of
elements.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   sender
[  5]   0.00-10.04  sec  13.2 GBytes  11.3 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
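Passing vectors instead of single elements over the channels can be sketched as follows (hypothetical worker, not the repository's crypto goroutines):

```go
package main

import "fmt"

// element stands in for a queued packet.
type element struct{ packet []byte }

// worker receives a whole slice per channel operation, amortizing
// runtime.chanrecv overhead across the batch instead of paying it
// once per packet.
func worker(in <-chan []*element, done chan<- int) {
	total := 0
	for batch := range in {
		for range batch {
			total++ // per-packet work (e.g. encryption) goes here
		}
	}
	done <- total
}

func main() {
	in := make(chan []*element, 4)
	done := make(chan int)
	go worker(in, done)
	for i := 0; i < 3; i++ {
		batch := make([]*element, 8) // 8 packets per channel op
		for j := range batch {
			batch[j] = &element{}
		}
		in <- batch
	}
	close(in)
	fmt.Println(<-done) // 24
}
```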
Jordan Whited
6a84778f2c conn, device: use UDP GSO and GRO on Linux
StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is
dependent on checksum offload support on the egress netdev. UDP GSO
will be disabled in the event sendmmsg() returns EIO, which is a strong
signal that the egress netdev does not support checksum offload.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is from commit 052af4a without UDP GSO or GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139   3.01 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139  sender
[  5]   0.00-10.04  sec  9.85 GBytes  8.42 Gbits/sec        receiver

The second result is with UDP GSO and GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
pokamest
b34974c476
Merge pull request #3 from amnezia-vpn/bugfix/uapi_adv_sec_onoff
Manage advanced security via uapi
2023-10-09 06:07:20 -07:00
Mazay B
f30419e0d1 Manage advanced sec via uapi 2023-10-09 13:22:49 +01:00
Mark Puha
8f1a6a10b2
Advanced security (#2)
* Advanced security header layer & config
2023-10-05 21:41:27 +01:00
119 changed files with 5762 additions and 1845 deletions

.github/workflows/build-if-tag.yml (vendored, new file, 41 additions)

@@ -0,0 +1,41 @@
name: build-if-tag
on:
  push:
    tags:
      - 'v[0-9]+.[0-9]+.[0-9]+'
env:
  APP: amneziawg-go
jobs:
  build:
    runs-on: ubuntu-latest
    name: build
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name }}
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Setup metadata
        uses: docker/metadata-action@v5
        id: metadata
        with:
          images: amneziavpn/${{ env.APP }}
          tags: type=semver,pattern={{version}}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
          push: true
          tags: ${{ steps.metadata.outputs.tags }}

.gitignore (vendored, 2 changes)

@@ -1 +1 @@
-wireguard-go
+amneziawg-go

Dockerfile (new file, 18 additions)

@@ -0,0 +1,18 @@
FROM golang:1.24.4 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
    go mod verify && \
    go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin

FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
    cd /usr/bin/ && \
    wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
    unzip -j alpine-3.19-amneziawg-tools.zip && \
    chmod +x /usr/bin/awg /usr/bin/awg-quick && \
    ln -s /usr/bin/awg /usr/bin/wg && \
    ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

Makefile

@@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
 generate-version-and-build:
 	@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
-	tag="$$(git describe --dirty 2>/dev/null)" && \
+	tag="$$(git describe --tags --dirty 2>/dev/null)" && \
 	ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
 	[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
 	echo "$$ver" > version.go && \
 	git update-index --assume-unchanged version.go || true
-	@$(MAKE) wireguard-go
+	@$(MAKE) amneziawg-go

-wireguard-go: $(wildcard *.go) $(wildcard */*.go)
+amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
 	go build -v -o "$@"

-install: wireguard-go
-	@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
+install: amneziawg-go
+	@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"

 test:
 	go test ./...

 clean:
-	rm -f wireguard-go
+	rm -f amneziawg-go

 .PHONY: all clean test install generate-version-and-build

README.md

@@ -1,24 +1,27 @@
# Go Implementation of [WireGuard](https://www.wireguard.com/)
# Go Implementation of AmneziaWG
This is an implementation of WireGuard in Go.
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
## Usage
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
Simply run:
```
$ wireguard-go wg0
$ amneziawg-go wg0
```
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
```
$ wireguard-go -f wg0
$ amneziawg-go -f wg0
```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@ -26,52 +29,25 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
### Windows
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
This runs on Windows; you should use it via [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
### FreeBSD
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
### OpenBSD
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
## Building
This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://git.zx2c4.com/wireguard-go
$ cd wireguard-go
$ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd amneziawg-go
$ make
```
## License
Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -8,6 +8,7 @@ package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
@ -29,16 +30,19 @@ var (
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these three fields are not guarded by mu
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
@ -54,23 +58,14 @@ func NewStdNetBind() Bind {
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
@ -179,19 +174,21 @@ again:
}
var fns []ReceiveFunc
if v4conn != nil {
if runtime.GOOS == "linux" {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
if runtime.GOOS == "linux" {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
@ -201,76 +198,101 @@ again:
return fns, uint16(port), nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
return numMsgs, nil
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
@ -293,28 +315,42 @@ func (s *StdNetBind) Close() error {
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
pc6 = s.ipv6PC
br = s.ipv6PC
is6 = true
} else {
pc4 = s.ipv4PC
offload = s.ipv6TxOffload
}
s.mu.Unlock()
@ -324,85 +360,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
return s.send6(conn, pc6, endpoint, bufs)
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
return s.send4(conn, pc4, endpoint, bufs)
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
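Given these limits, splitting a GRO-coalesced read of N bytes back into individual datagrams is a ceiling division by the GSO segment size, which is what `splitCoalescedMessages` computes for `numToSplit`. A minimal standalone sketch (the helper name is hypothetical):

```go
package main

import "fmt"

// segmentsFor returns how many datagrams a coalesced payload of n bytes is
// split back into at the given GSO segment size (ceiling division).
func segmentsFor(n, gsoSize int) int {
	if gsoSize <= 0 {
		return 1 // no GSO metadata: the read is a single datagram
	}
	return (n + gsoSize - 1) / gsoSize
}

func main() {
	fmt.Println(segmentsFor(3000, 1400)) // three segments: 1400 + 1400 + 200
	fmt.Println(segmentsFor(2800, 1400)) // two full-size segments
}
```

The short tail segment (200 bytes above) is legal, but only as the last segment of a batch.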
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
n int
err error
start int
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
if runtime.GOOS == "linux" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
break
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
if err != nil {
break
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
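The eligibility conditions that coalesceMessages checks before appending a datagram to the current base message (payload limit, GSO size, spare capacity, per-batch segment cap, and the short-tail rule) can be distilled into a standalone predicate. A sketch with hypothetical names:

```go
package main

import "fmt"

const (
	udpSegmentMaxDatagrams = 64    // kernel hard limit on segments per GSO batch
	maxIPv4PayloadLen      = 65507 // 2^16-1 minus IPv4 (20) and UDP (8) headers
)

// canAppend reports whether a datagram of msgLen bytes may be coalesced onto a
// base message, mirroring the conditions checked in coalesceMessages.
func canAppend(msgLen, baseLen, baseCap, gsoSize, dgramCnt int, endBatch bool) bool {
	return msgLen+baseLen <= maxIPv4PayloadLen && // stay under the UDP payload limit
		msgLen <= gsoSize && // later segments must not exceed the GSO size
		msgLen <= baseCap-baseLen && // must fit in the base buffer's spare capacity
		dgramCnt < udpSegmentMaxDatagrams && // kernel caps segments per batch
		!endBatch // a short segment has already ended this batch
}

func main() {
	fmt.Println(canAppend(1400, 1400, 65535, 1400, 2, false)) // equal-size segment: coalesce
	fmt.Println(canAppend(1500, 1400, 65535, 1400, 2, false)) // larger than gsoSize: start a new batch
}
```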
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
return n, nil
}


@ -1,6 +1,12 @@
package conn
import "testing"
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
fn(bufs, sizes, eps)
}
}
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func Test_coalesceMessages(t *testing.T) {
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1").To4(),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func mockGetGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_splitCoalescedMessages(t *testing.T) {
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1<<16-1)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -17,7 +17,7 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
)
const (


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@ -12,7 +12,7 @@ import (
"net/netip"
"os"
"golang.zx2c4.com/wireguard/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type ChannelBind struct {


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -13,6 +13,35 @@ import (
"golang.org/x/sys/unix"
)
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
func kernelVersion() (major, minor int) {
var uname unix.Utsname
if err := unix.Uname(&uname); err != nil {
return
}
var (
values [2]int
value, vi int
)
for _, c := range uname.Release {
if '0' <= c && c <= '9' {
value = (value * 10) + int(c-'0')
} else {
// Note that we're assuming N.N.N here.
// If we see anything else, we are likely to mis-parse it.
values[vi] = value
vi++
if vi >= len(values) {
break
}
value = 0
}
}
return values[0], values[1]
}
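The digit scan above assumes an `N.N.N`-style release string; anything non-numeric simply terminates the current field. A self-contained re-illustration of the same scan, together with the 5.12 gating check used for UDP_GRO (the helper name is hypothetical):

```go
package main

import "fmt"

// parseKernelRelease extracts major and minor version numbers from a release
// string such as "5.11.0-46-generic", using the same digit scan as
// kernelVersion: accumulate digits, and advance a field on any non-digit.
func parseKernelRelease(release string) (major, minor int) {
	var values [2]int
	value, vi := 0, 0
	for _, c := range release {
		if '0' <= c && c <= '9' {
			value = value*10 + int(c-'0')
		} else {
			values[vi] = value
			vi++
			if vi >= len(values) {
				break // only major and minor are needed
			}
			value = 0
		}
	}
	return values[0], values[1]
}

func main() {
	major, minor := parseKernelRelease("5.11.0-46-generic")
	groOK := major > 5 || (major == 5 && minor >= 12)
	fmt.Println(major, minor, groOK) // 5.11 predates the getsockopt(UDP_GRO) fix
}
```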
func init() {
controlFns = append(controlFns,
@ -57,5 +86,24 @@ func init() {
}
return err
},
// Attempt to enable UDP_GRO
func(network, address string, c syscall.RawConn) error {
// Kernels below 5.12 are missing 98184612aca0 ("net:
// udp: Add support for getsockopt(..., ..., UDP_GRO,
// ..., ...);"), which means we can't read this back
// later. We could pipe the return value through to
// the rest of the code, but UDP_GRO is kind of buggy
// anyway, so just gate this here.
major, minor := kernelVersion()
if major < 5 || (major == 5 && minor < 12) {
return nil
}
c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
})
return nil
},
)
}


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn

conn/errors_default.go Normal file

@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(_ error) bool {
return false
}

conn/errors_linux.go Normal file

@ -0,0 +1,28 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
// It occurs when the peer mtu + wg headers is greater than path mtu.
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
}
return false
}

conn/features_default.go Normal file

@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
return
}

conn/features_linux.go Normal file

@ -0,0 +1,31 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported on Android,
// so probe with setsockopt as a workaround.
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
rxOffload = errSyscall == nil
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}

conn/gso_default.go Normal file

@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

conn/gso_linux.go Normal file

@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
@ -34,8 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// srcControlSize returns the recommended buffer size for pooling sticky control
// data.
const srcControlSize = 0
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -105,6 +105,8 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = append(*control, ep.src...)
}
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true


@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, srcControlSize)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, srcControlSize)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, unix.CmsgLen(0))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) {
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package winrio

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
}
}
func (node *trieEntry) remove() {
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
return
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
return
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
return
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var node *trieEntry
var exact bool
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
} else {
panic(errors.New("removing unknown address type"))
}
if !exact || node == nil || peer != node.peer {
return
}
node.remove()
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
elem.Value.(*trieEntry).remove()
}
}

View file
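The `remove()` change above splices a deleted node out of the trie and then collapses a parent left redundant. A minimal sketch of that prune-on-removal idea, using a toy recursive binary trie instead of the repo's parent pointers and `unsafe` offset math (all names here are illustrative, not the library's API):

```go
package main

import "fmt"

// node is a toy binary-trie node: peer == "" marks an internal node.
type node struct {
	child [2]*node
	peer  string
}

// insert creates nodes along the bit path and marks the final one.
func insert(n **node, bits []int, peer string) {
	if *n == nil {
		*n = &node{}
	}
	if len(bits) == 0 {
		(*n).peer = peer
		return
	}
	insert(&(*n).child[bits[0]], bits[1:], peer)
}

// remove clears the entry at bits and, on the way back up, splices out
// any node left with no peer and no children.
func remove(n **node, bits []int) {
	if *n == nil {
		return
	}
	if len(bits) == 0 {
		(*n).peer = ""
	} else {
		remove(&(*n).child[bits[0]], bits[1:])
	}
	cur := *n
	if cur.peer == "" && cur.child[0] == nil && cur.child[1] == nil {
		*n = nil // splice out the now-empty node
	}
}

func main() {
	var root *node
	insert(&root, []int{0, 1}, "peerA")
	insert(&root, []int{0, 0}, "peerB")
	remove(&root, []int{0, 1})
	fmt.Println(root.child[0].child[1] == nil) // pruned branch: true
	fmt.Println(root.child[0].child[0].peer)   // peerB
}
```

The real implementation avoids recursion and also collapses single-child internal nodes; this sketch only shows the empty-node pruning half of that invariant.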

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
rng := rand.New(rand.NewSource(1))
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
rng.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
rng.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
for p = 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
rng.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
}
var addr6 [16]byte
rand.Read(addr6[:])
rng.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {

View file
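The test changes above replace the deprecated global `rand.Seed(1)` with a locally scoped `rand.New(rand.NewSource(1))`, so each test owns its own reproducible stream. A minimal sketch of that migration pattern (function name is illustrative):

```go
package main

import (
	"fmt"
	"math/rand"
)

// deterministicBytes fills a buffer from a locally seeded PRNG, so
// repeated runs see the same sequence without mutating global state
// via the deprecated rand.Seed.
func deterministicBytes(n int) []byte {
	rng := rand.New(rand.NewSource(1)) // fixed seed: reproducible
	buf := make([]byte, n)
	rng.Read(buf)
	return buf
}

func main() {
	a := deterministicBytes(4)
	b := deterministicBytes(4)
	fmt.Println(len(a) == len(b)) // same seed, same length and contents
}
```

Because the generator is local, parallel tests can no longer perturb each other's sequences, which is the main reason for the change.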

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
}
}
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1)
rng := rand.New(rand.NewSource(1))
const AddressLength = 4
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
rng.Read(addr[:])
cidr := uint8(rng.Uint32() % (AddressLength * 8))
index := rng.Int() % peerNumber
root.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
rng.Read(addr[:])
trie.lookup(addr[:])
}
}
@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
insert(a, 1, 0, 0, 0, 32)
insert(a, 192, 0, 0, 0, 24)
assertEQ(a, 1, 0, 0, 0)
assertEQ(a, 192, 0, 0, 1)
remove(a, 192, 0, 0, 0, 32)
assertEQ(a, 192, 0, 0, 1)
remove(nil, 192, 0, 0, 0, 24)
assertEQ(a, 192, 0, 0, 1)
remove(b, 192, 0, 0, 0, 24)
assertEQ(a, 192, 0, 0, 1)
remove(a, 192, 0, 0, 0, 24)
assertNEQ(a, 192, 0, 0, 1)
remove(a, 1, 0, 0, 0, 32)
assertNEQ(a, 1, 0, 0, 0)
}
/* Test ported from kernel implementation:
@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
}
}
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := allowedIPs.Lookup(addr)
if p == peer {
t.Error("Assert NEQ failed")
}
}
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
}

144
device/awg/awg.go Normal file
View file

@ -0,0 +1,144 @@
package awg
import (
"bytes"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"github.com/tevino/abool"
)
type aSecCfgType struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitHeaderJunkSize int
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
// InitPacketMagicHeader Limit
// ResponsePacketMagicHeader Limit
// UnderloadPacketMagicHeader Limit
// TransportPacketMagicHeader Limit
}
type Limit struct {
Min uint32
Max uint32
HeaderType uint32
}
func NewLimit(min, max, headerType uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
Min: min,
Max: max,
HeaderType: headerType,
}, nil
}
func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
limits := strings.Split(value, "-")
if len(limits) != 2 {
return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
}
min, err := strconv.ParseUint(limits[0], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse min for key: %s; value: %s; %w", key, limits[0], err)
}
max, err := strconv.ParseUint(limits[1], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse max for key: %s; value: %s; %w", key, limits[1], err)
}
limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
if err != nil {
return Limit{}, fmt.Errorf("new limit for key: %s; %w", key, err)
}
return limit, nil
}
type Limits []Limit
func NewLimits(limits []Limit) Limits {
slices.SortFunc(limits, func(a, b Limit) int {
if a.Min < b.Min {
return -1
} else if a.Min > b.Min {
return 1
}
return 0
})
return Limits(limits)
}
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revisit whether the mutex is needed
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) {
extraSize := 0
if len(optExtraSize) == 1 {
extraSize = optExtraSize[0]
}
var junk []byte
protocol.ASecMux.RLock()
if junkSize != 0 {
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
protocol.ASecMux.RUnlock()
return nil, err
}
junk = writer.Bytes()
}
protocol.ASecMux.RUnlock()
return junk, nil
}

View file
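`createHeaderJunk` above prepends `junkSize` bytes of ChaCha8 noise ahead of the real packet data. A self-contained sketch of that header-junk idea, assuming a ChaCha8 stream seeded from `crypto/rand` as in `NewJunkCreator` (the `prependJunk` name is hypothetical, not the repo's API):

```go
package main

import (
	crand "crypto/rand"
	"fmt"
	v2 "math/rand/v2"
)

// prependJunk returns junkSize random bytes followed by the payload:
// the real packet is pushed past junkSize bytes of noise, mirroring
// the header-junk sizes (S1/S2-style) configured in aSecCfgType.
func prependJunk(payload []byte, junkSize int) ([]byte, error) {
	var seed [32]byte
	if _, err := crand.Read(seed[:]); err != nil {
		return nil, err
	}
	rng := v2.NewChaCha8(seed)
	out := make([]byte, junkSize, junkSize+len(payload))
	rng.Read(out) // fill the junk prefix
	return append(out, payload...), nil
}

func main() {
	pkt, err := prependJunk([]byte("data"), 16)
	fmt.Println(err == nil, len(pkt)) // true 20
}
```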

@ -0,0 +1,37 @@
package internal
type mockGenerator struct {
size int
}
func NewMockGenerator(size int) mockGenerator {
return mockGenerator{size: size}
}
func (m mockGenerator) Generate() []byte {
return make([]byte, m.size)
}
func (m mockGenerator) Size() int {
return m.size
}
func (m mockGenerator) Name() string {
return "mock"
}
type mockByteGenerator struct {
data []byte
}
func NewMockByteGenerator(data []byte) mockByteGenerator {
return mockByteGenerator{data: data}
}
func (bg mockByteGenerator) Generate() []byte {
return bg.data
}
func (bg mockByteGenerator) Size() int {
return len(bg.data)
}

View file

@ -0,0 +1,70 @@
package awg
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
aSecCfg aSecCfgType
cha8Rand *v2.ChaCha8
}
// TODO: refactor params to only pass the junk-related fields
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
if jc.aSecCfg.JunkPacketCount == 0 {
return nil
}
for range jc.aSecCfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
*junks = append(*junks, junk)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
),
) + jc.aSecCfg.JunkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
// TODO: use a memory pool to allocate
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

View file

@ -0,0 +1,115 @@
package awg
import (
"bytes"
"fmt"
"testing"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
jc, err := NewJunkCreator(aSecCfgType{
IsSet: true,
JunkPacketCount: 5,
JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000,
InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40,
InitPacketMagicHeader: 123456,
ResponsePacketMagicHeader: 67543,
UnderloadPacketMagicHeader: 32345,
TransportPacketMagicHeader: 123123,
})
if err != nil {
t.Errorf("failed to create junk creator %v", err)
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
err := jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks: %v and %v", r1, r2)
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
got > jc.aSecCfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.aSecCfg.JunkPacketMinSize,
jc.aSecCfg.JunkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("valid", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.AppendJunk(buffer, 30)
if err != nil ||
buffer.Len() != len(s)+30 {
t.Error("appendWithJunk() sizes don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -0,0 +1,73 @@
package awg
import (
"errors"
"time"
"github.com/tevino/abool"
"go.uber.org/atomic"
)
// TODO: atomic? find a better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
// TODO
var WaitResponse = struct {
Channel chan struct{}
ShouldWait *abool.AtomicBool
}{
make(chan struct{}, 1),
abool.New(),
}
type SpecialHandshakeHandler struct {
isFirstDone bool
SpecialJunk TagJunkPacketGenerators
ControlledJunk TagJunkPacketGenerators
nextItime time.Time
ITimeout time.Duration // seconds
IsSet bool
}
func (handler *SpecialHandshakeHandler) Validate() error {
var errs []error
if err := handler.SpecialJunk.Validate(); err != nil {
errs = append(errs, err)
}
if err := handler.ControlledJunk.Validate(); err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() {
return nil
}
// TODO: create tests
if !handler.isFirstDone {
handler.isFirstDone = true
} else if !handler.isTimeToSendSpecial() {
return nil
}
rv := handler.SpecialJunk.GeneratePackets()
handler.nextItime = time.Now().Add(handler.ITimeout)
return rv
}
func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
return time.Now().After(handler.nextItime)
}
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
if !handler.ControlledJunk.IsDefined() {
return nil
}
return handler.ControlledJunk.GeneratePackets()
}

190
device/awg/tag_generator.go Normal file
View file

@ -0,0 +1,190 @@
package awg
import (
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
v2 "math/rand/v2"
)
type Generator interface {
Generate() []byte
Size() int
}
type newGenerator func(string) (Generator, error)
type BytesGenerator struct {
value []byte
size int
}
func (bg *BytesGenerator) Generate() []byte {
return bg.value
}
func (bg *BytesGenerator) Size() int {
return bg.size
}
func newBytesGenerator(param string) (Generator, error) {
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
if !hasPrefix {
return nil, fmt.Errorf("not correct hex: %s", param)
}
hex, err := hexToBytes(param)
if err != nil {
return nil, fmt.Errorf("hexToBytes: %w", err)
}
return &BytesGenerator{value: hex, size: len(hex)}, nil
}
func hexToBytes(hexStr string) ([]byte, error) {
hexStr = strings.TrimPrefix(hexStr, "0x")
hexStr = strings.TrimPrefix(hexStr, "0X")
// Ensure even length (pad with leading zero if needed)
if len(hexStr)%2 != 0 {
hexStr = "0" + hexStr
}
return hex.DecodeString(hexStr)
}
type RandomPacketGenerator struct {
cha8Rand *v2.ChaCha8
size int
}
func (rpg *RandomPacketGenerator) Generate() []byte {
junk := make([]byte, rpg.size)
rpg.cha8Rand.Read(junk)
return junk
}
func (rpg *RandomPacketGenerator) Size() int {
return rpg.size
}
func newRandomPacketGenerator(param string) (Generator, error) {
size, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("random packet parse int: %w", err)
}
if size > 1000 {
return nil, fmt.Errorf("random packet size must not exceed 1000")
}
buf := make([]byte, 32)
_, err = crand.Read(buf)
if err != nil {
return nil, fmt.Errorf("random packet crand read: %w", err)
}
return &RandomPacketGenerator{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
}
type TimestampGenerator struct {
}
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
return buf
}
func (tg *TimestampGenerator) Size() int {
return 8
}
func newTimestampGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
}
return &TimestampGenerator{}, nil
}
type WaitTimeoutGenerator struct {
waitTimeout time.Duration
}
func (wtg *WaitTimeoutGenerator) Generate() []byte {
time.Sleep(wtg.waitTimeout)
return []byte{}
}
func (wtg *WaitTimeoutGenerator) Size() int {
return 0
}
func newWaitTimeoutGenerator(param string) (Generator, error) {
timeout, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("timeout parse int: %w", err)
}
if timeout > 5000 {
return nil, fmt.Errorf("timeout must not exceed 5000ms")
}
return &WaitTimeoutGenerator{
waitTimeout: time.Duration(timeout) * time.Millisecond,
}, nil
}
type PacketCounterGenerator struct {
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
return buf
}
func (c *PacketCounterGenerator) Size() int {
return 8
}
func newPacketCounterGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
return &PacketCounterGenerator{}, nil
}
type WaitResponseGenerator struct {
}
func (c *WaitResponseGenerator) Generate() []byte {
WaitResponse.ShouldWait.Set()
<-WaitResponse.Channel
WaitResponse.ShouldWait.UnSet()
return []byte{}
}
func (c *WaitResponseGenerator) Size() int {
return 0
}
func newWaitResponseGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
}
return &WaitResponseGenerator{}, nil
}

View file
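`hexToBytes` above left-pads odd-length hex so that, for example, `0xfab` decodes as `0x0f 0xab` rather than failing. A compact sketch of the same parsing rule (`decodeHexTag` is an illustrative name):

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// decodeHexTag strips an 0x/0X prefix and left-pads odd-length input
// with a zero nibble before decoding, matching hexToBytes above.
func decodeHexTag(s string) ([]byte, error) {
	s = strings.TrimPrefix(s, "0x")
	s = strings.TrimPrefix(s, "0X")
	if len(s)%2 != 0 {
		s = "0" + s // pad so DecodeString sees whole bytes
	}
	return hex.DecodeString(s)
}

func main() {
	b, err := decodeHexTag("0xfab3267fa")
	fmt.Println(err, b) // <nil> [15 171 50 103 250]
}
```

This matches the "valid hex with odd length" case in the tests, where `0xfab3267fa` is expected to decode to `{0xf, 0xab, 0x32, 0x67, 0xfa}`.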

@ -0,0 +1,189 @@
package awg
import (
"encoding/binary"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func Test_newBytesGenerator(t *testing.T) {
type args struct {
param string
}
tests := []struct {
name string
args args
want []byte
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "wrong start",
args: args{
param: "123456",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with X",
args: args{
param: "0X12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with x",
args: args{
param: "0x12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "valid hex",
args: args{
param: "0xf6ab3267fa",
},
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
},
{
name: "valid hex with odd length",
args: args{
param: "0xfab3267fa",
},
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newBytesGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
gotValues := got.Generate()
require.Equal(t, tt.want, gotValues)
})
}
}
func Test_newRandomPacketGenerator(t *testing.T) {
type args struct {
param string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "not an int",
args: args{
param: "x",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must not exceed 1000"),
},
{
name: "valid",
args: args{
param: "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newRandomPacketGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
first := got.Generate()
second := got.Generate()
require.NotEqual(t, first, second)
})
}
}
func TestPacketCounterGenerator(t *testing.T) {
tests := []struct {
name string
param string
wantErr bool
}{
{
name: "Valid empty param",
param: "",
wantErr: false,
},
{
name: "Invalid non-empty param",
param: "anything",
wantErr: true,
},
}
for _, tc := range tests {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gen, err := newPacketCounterGenerator(tc.param)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, 8, gen.Size())
// Reset counter to known value for test
initialCount := uint64(42)
PacketCounter.Store(initialCount)
output := gen.Generate()
require.Equal(t, 8, len(output))
// Verify counter value in output
counterValue := binary.BigEndian.Uint64(output)
require.Equal(t, initialCount, counterValue)
// Increment counter and verify change
PacketCounter.Add(1)
output = gen.Generate()
counterValue = binary.BigEndian.Uint64(output)
require.Equal(t, initialCount+1, counterValue)
})
}
}

View file

@ -0,0 +1,59 @@
package awg
import (
"fmt"
"strconv"
)
type TagJunkPacketGenerator struct {
name string
tagValue string
packetSize int
generators []Generator
}
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{
name: name,
tagValue: tagValue,
generators: make([]Generator, 0, size),
}
}
func (tg *TagJunkPacketGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (tg *TagJunkPacketGenerator) Name() string {
return tg.name
}
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
if len(tg.name) != 2 {
return 0, fmt.Errorf("name must be 2 characters long: %s", tg.name)
}
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
return 0, fmt.Errorf("second character of name should be an int: %w", err)
}
return index, nil
}
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
return IpcFields{
Key: tg.name,
Value: tg.tagValue,
}
}

View file

@ -0,0 +1,210 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestNewTagJunkGenerator(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
genName string
size int
expected TagJunkPacketGenerator
}{
{
name: "Create new generator with empty name",
genName: "",
size: 0,
expected: TagJunkPacketGenerator{
name: "",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with valid name",
genName: "T1",
size: 0,
expected: TagJunkPacketGenerator{
name: "T1",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with non-zero size",
genName: "T2",
size: 5,
expected: TagJunkPacketGenerator{
name: "T2",
packetSize: 0,
generators: make([]Generator, 5),
},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
require.Equal(t, tc.expected.name, result.name)
require.Equal(t, tc.expected.packetSize, result.packetSize)
require.Equal(t, cap(result.generators), len(tc.expected.generators))
})
}
}
func TestTagJunkGeneratorAppend(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
initialState TagJunkPacketGenerator
mockSize int
expectedLength int
expectedSize int
}{
{
name: "Append to empty generator",
initialState: newTagJunkPacketGenerator("T1", "", 0),
mockSize: 5,
expectedLength: 1,
expectedSize: 5,
},
{
name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{
name: "T2",
packetSize: 10,
generators: make([]Generator, 2),
},
mockSize: 7,
expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.initialState
mockGen := internal.NewMockGenerator(tc.mockSize)
tg.append(mockGen)
require.Equal(t, tc.expectedLength, len(tg.generators))
require.Equal(t, tc.expectedSize, tg.packetSize)
})
}
}
func TestTagJunkGeneratorGenerate(t *testing.T) {
t.Parallel()
// Create mock generators for testing
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
testCases := []struct {
name string
setupGenerator func() TagJunkPacketGenerator
expected []byte
}{
{
name: "Generate with empty generators",
setupGenerator: func() TagJunkPacketGenerator {
return newTagJunkPacketGenerator("T1", "", 0)
},
expected: []byte{},
},
{
name: "Generate with single generator",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T2", "", 0)
tg.append(mockGen1)
return tg
},
expected: []byte{0x01, 0x02},
},
{
name: "Generate with multiple generators",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T3", "", 0)
tg.append(mockGen1)
tg.append(mockGen2)
return tg
},
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.setupGenerator()
result := tg.generatePacket()
require.Equal(t, tc.expected, result)
})
}
}
func TestTagJunkGeneratorNameIndex(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
generatorName string
expectedIndex int
expectError bool
}{
{
name: "Valid name with digit",
generatorName: "T5",
expectedIndex: 5,
expectError: false,
},
{
name: "Invalid name - too short",
generatorName: "T",
expectError: true,
},
{
name: "Invalid name - too long",
generatorName: "T55",
expectError: true,
},
{
name: "Invalid name - non-digit second character",
generatorName: "TX",
expectError: true,
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := TagJunkPacketGenerator{name: tc.generatorName}
index, err := tg.nameIndex()
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectedIndex, index)
}
})
}
}

View file

@ -0,0 +1,66 @@
package awg
import "fmt"
type TagJunkPacketGenerators struct {
tagGenerators []TagJunkPacketGenerator
length int
DefaultJunkCount int // Jc
}
func (generators *TagJunkPacketGenerators) AppendGenerator(
generator TagJunkPacketGenerator,
) {
generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++
}
func (generators *TagJunkPacketGenerators) IsDefined() bool {
return len(generators.tagGenerators) > 0
}
// validate that packets were defined consecutively
func (generators *TagJunkPacketGenerators) Validate() error {
seen := make([]bool, len(generators.tagGenerators))
for _, generator := range generators.tagGenerators {
index, err := generator.nameIndex()
if err != nil {
return fmt.Errorf("name index: %w", err)
}
if index < 1 || index > len(generators.tagGenerators) {
return fmt.Errorf("junk packet index should be consecutive")
}
seen[index-1] = true
}
for _, found := range seen {
if !found {
return fmt.Errorf("junk packet index should be consecutive")
}
}
return nil
}
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators {
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
PacketCounter.Inc()
}
PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv
}
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
rv := make([]IpcFields, 0, len(tg.tagGenerators))
for _, generator := range tg.tagGenerators {
rv = append(rv, generator.IpcGetFields())
}
return rv
}
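The consecutive-index rule enforced by `Validate` can be illustrated standalone: each generator name encodes a 1-based index (t1, t2, ...), validation marks each index in a `seen` slice, and any unmarked slot means a gap. A minimal sketch of that check, using a hypothetical `validateConsecutive` helper independent of the package types:

```go
package main

import "fmt"

// validateConsecutive reports whether the 1-based indices cover 1..n exactly,
// mirroring the seen-slice approach used by Validate above.
func validateConsecutive(indices []int) error {
	seen := make([]bool, len(indices))
	for _, idx := range indices {
		if idx < 1 || idx > len(indices) {
			return fmt.Errorf("junk packet index should be consecutive")
		}
		seen[idx-1] = true
	}
	for _, found := range seen {
		if !found {
			return fmt.Errorf("junk packet index should be consecutive")
		}
	}
	return nil
}

func main() {
	fmt.Println(validateConsecutive([]int{1, 2, 3})) // t1,t2,t3: ok
	fmt.Println(validateConsecutive([]int{1, 3}))    // t2 missing: error
}
```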


@ -0,0 +1,149 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
tests := []struct {
name string
generator TagJunkPacketGenerator
}{
{
name: "append single generator",
generator: newTagJunkPacketGenerator("t1", "", 10),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
// Initial length should be 0
require.Equal(t, 0, generators.length)
require.Empty(t, generators.tagGenerators)
// After append, length should be 1 and generator should be added
generators.AppendGenerator(tt.generator)
require.Equal(t, 1, generators.length)
require.Len(t, generators.tagGenerators, 1)
require.Equal(t, tt.generator, generators.tagGenerators[0])
})
}
}
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
tests := []struct {
name string
generators []TagJunkPacketGenerator
wantErr bool
errMsg string
}{
{
name: "bad start",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "non-consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t2", "", 10),
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
newTagJunkPacketGenerator("t5", "", 10),
},
},
{
name: "nameIndex error",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("error", "", 10),
},
wantErr: true,
errMsg: "name must be 2 character long",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
for _, gen := range tt.generators {
generators.AppendGenerator(gen)
}
err := generators.Validate()
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
return
}
require.NoError(t, err)
})
}
}
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
mockByte1 := []byte{0x01, 0x02}
mockByte2 := []byte{0x03, 0x04, 0x05}
mockGen1 := internal.NewMockByteGenerator(mockByte1)
mockGen2 := internal.NewMockByteGenerator(mockByte2)
tests := []struct {
name string
setupGenerator func() []TagJunkPacketGenerator
expected [][]byte
}{
{
name: "generate with no default junk",
setupGenerator: func() []TagJunkPacketGenerator {
tg1 := newTagJunkPacketGenerator("t1", "", 0)
tg1.append(mockGen1)
tg1.append(mockGen2)
tg2 := newTagJunkPacketGenerator("t2", "", 0)
tg2.append(mockGen2)
tg2.append(mockGen1)
return []TagJunkPacketGenerator{tg1, tg2}
},
expected: [][]byte{
append(mockByte1, mockByte2...),
append(mockByte2, mockByte1...),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
tagGenerators := tt.setupGenerator()
for _, gen := range tagGenerators {
generators.AppendGenerator(gen)
}
result := generators.GeneratePackets()
require.Equal(t, tt.expected, result)
})
}
}

device/awg/tag_parser.go Normal file

@ -0,0 +1,112 @@
package awg
import (
"fmt"
"maps"
"regexp"
"strings"
)
type IpcFields struct{ Key, Value string }
type EnumTag string
const (
BytesEnumTag EnumTag = "b"
CounterEnumTag EnumTag = "c"
TimestampEnumTag EnumTag = "t"
RandomBytesEnumTag EnumTag = "r"
WaitTimeoutEnumTag EnumTag = "wt"
WaitResponseEnumTag EnumTag = "wr"
)
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomPacketGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
// WaitResponseEnumTag: newWaitResponseGenerator,
}
// helper map tracking which enum tags must be unique
var uniqueTags = map[EnumTag]bool{
CounterEnumTag: false,
TimestampEnumTag: false,
}
type Tag struct {
Name EnumTag
Param string
}
func parseTag(input string) (Tag, error) {
// Regular expression to match <tagname optional_param>
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
match := re.FindStringSubmatch(input)
if match == nil {
return Tag{}, fmt.Errorf("invalid tag format: %s", input)
}
tag := Tag{
Name: EnumTag(match[1]),
}
if len(match) > 2 && match[2] != "" {
tag.Param = strings.TrimSpace(match[2])
}
return tag, nil
}
func Parse(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
}
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
maps.Copy(uniqueTagCheck, uniqueTags)
// skip byproduct of split
inputSlice = inputSlice[1:]
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice {
if len(inputParam) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf(
"empty tag in input: %s",
inputSlice,
)
} else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
}
tag, err := parseTag(inputParam)
if err != nil {
return TagJunkPacketGenerator{}, fmt.Errorf("parse tag: %w", err)
}
creator, ok := generatorCreator[tag.Name]
if !ok {
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
}
if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
return TagJunkPacketGenerator{}, fmt.Errorf(
"tag %s needs to be unique",
tag.Name,
)
}
uniqueTagCheck[tag.Name] = true
}
generator, err := creator(tag.Param)
if err != nil {
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
}
// TODO: handle counter tag
// if tag.Name == CounterEnumTag {
// packetCounter, ok := generator.(*PacketCounterGenerator)
// if !ok {
// log.Fatalf("packet counter generator expected, got %T", generator)
// }
// PacketCounter = packetCounter.counter
// }
rv.append(generator)
}
return rv, nil
}
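The tag grammar handled by `Parse` is a sequence of `<name optional_param>` chunks: the input is split on `<`, the empty first chunk is discarded, and a regexp extracts the name and optional parameter from each remainder. A standalone sketch of that split-then-regexp approach, with a hypothetical `splitTags` helper independent of the package types:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// splitTags breaks an input like "<b 0xf6ab><c><r 10>" into (name, param)
// pairs, mirroring the split-on-'<' plus regexp approach used by Parse.
func splitTags(input string) ([][2]string, error) {
	re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
	chunks := strings.Split(input, "<")
	if len(chunks) <= 1 {
		return nil, fmt.Errorf("empty input: %s", input)
	}
	var tags [][2]string
	for _, chunk := range chunks[1:] { // chunks[0] is the empty split byproduct
		m := re.FindStringSubmatch(chunk)
		if m == nil {
			return nil, fmt.Errorf("ill formated input: %s", input)
		}
		// m[2] is "" when the optional parameter group did not match.
		tags = append(tags, [2]string{m[1], strings.TrimSpace(m[2])})
	}
	return tags, nil
}

func main() {
	tags, err := splitTags("<b 0xf6ab3267fa><c><r 10>")
	fmt.Println(tags, err)
}
```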


@ -0,0 +1,77 @@
package awg
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestParse(t *testing.T) {
type args struct {
name string
input string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "invalid name",
args: args{name: "apple", input: ""},
wantErr: fmt.Errorf("empty input"),
},
{
name: "empty",
args: args{name: "i1", input: ""},
wantErr: fmt.Errorf("empty input"),
},
{
name: "extra >",
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "extra <",
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "empty <>",
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "invalid tag",
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
wantErr: fmt.Errorf("invalid tag"),
},
{
name: "counter uniqueness violation",
args: args{name: "i1", input: "<c><c>"},
wantErr: fmt.Errorf("needs to be unique"),
},
{
name: "timestamp uniqueness violation",
args: args{name: "i1", input: "<t><t>"},
wantErr: fmt.Errorf("needs to be unique"),
},
{
name: "valid",
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.args.name, tt.args.input)
if tt.wantErr != nil {
require.ErrorContains(t, err, tt.wantErr.Error())
return
}
require.NoError(t, err)
})
}
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -8,7 +8,7 @@ package device
import (
"errors"
"golang.zx2c4.com/wireguard/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type DummyDatagram struct {


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
c chan *QueueOutboundElement
c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize),
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
c chan *QueueInboundElement
c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
c: make(chan *QueueInboundElement, QueueInboundSize),
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
c chan *[]*QueueInboundElement
c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
c: make(chan *[]*QueueInboundElement, QueueInboundSize),
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
case elems := <-q.c:
for _, elem := range *elems {
elem.Lock()
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
c chan *[]*QueueOutboundElement
c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
case elems := <-q.c:
for _, elem := range *elems {
elem.Lock()
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
device.PutOutboundElementsContainer(elemsContainer)
default:
return
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,22 +1,60 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"runtime"
"sync"
"sync/atomic"
"time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
type Version uint8
const (
VersionDefault Version = iota
VersionAwg
VersionAwgSpecialHandshake
)
// TODO:
type AtomicVersion struct {
value atomic.Uint32
}
func NewAtomicVersion(v Version) *AtomicVersion {
av := &AtomicVersion{}
av.Store(v)
return av
}
func (av *AtomicVersion) Load() Version {
return Version(av.value.Load())
}
func (av *AtomicVersion) Store(v Version) {
av.value.Store(uint32(v))
}
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
return av.value.CompareAndSwap(uint32(old), uint32(new))
}
func (av *AtomicVersion) Swap(new Version) Version {
return Version(av.value.Swap(uint32(new)))
}
type Device struct {
state struct {
// state holds the device's state. It is accessed atomically.
@ -68,11 +106,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
outboundElementsSlice *WaitPool
inboundElementsSlice *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
}
queue struct {
@ -89,6 +127,9 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
version Version
awg awg.Protocol
}
// deviceState represents the state of a Device.
@ -162,7 +203,8 @@ func (device *Device) changeState(want deviceState) (err error) {
err = errDown
}
}
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return
}
@ -368,10 +410,10 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.state.Lock()
defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
@ -395,6 +437,8 @@ func (device *Device) Close() {
device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed")
close(device.closed)
}
@ -461,11 +505,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@ -515,11 +555,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@ -542,3 +578,261 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = DefaultMessageInitiationType
MessageResponseType = DefaultMessageResponseType
MessageCookieReplyType = DefaultMessageCookieReplyType
MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
var errs []error
isASecOn := false
device.awg.ASecMux.Lock()
if tempAwg.ASecCfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
if tempAwg.ASecCfg.JunkPacketCount != 0 {
isASecOn = true
}
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
isASecOn = true
}
if device.awg.ASecCfg.JunkPacketCount > 0 &&
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
}
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.ASecCfg.JunkPacketMinSize = 0
device.awg.ASecCfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
MaxSegmentSize,
))
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize,
tempAwg.ASecCfg.JunkPacketMinSize,
))
} else {
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
}
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
isASecOn = true
}
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
}
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
isASecOn = true
}
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
}
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
}
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
}
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
}
if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
}
if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
}
if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
}
isSameHeaderMap := map[uint32]struct{}{
MessageInitiationType: {},
MessageResponseType: {},
MessageCookieReplyType: {},
MessageTransportType: {},
}
// duplicate values collapse into one key, so len < 4 means a clash
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
),
)
}
isSameSizeMap := map[int]struct{}{
newInitSize: {},
newResponseSize: {},
newCookieSize: {},
newTransportSize: {},
}
if len(isSameSizeMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
newInitSize,
newResponseSize,
newCookieSize,
newTransportSize,
),
)
} else {
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
}
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
newCookieSize: MessageCookieReplyType,
newTransportSize: MessageTransportType,
}
}
device.awg.IsASecOn.SetTo(isASecOn)
var err error
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
if err != nil {
errs = append(errs, err)
}
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
device.awg.ASecMux.Unlock()
return errors.Join(errs...)
}
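The magic-header and packet-size checks in handlePostConfig rely on a map-as-set trick: insert all four values as keys, and any duplicates collapse so the map length drops below four. A standalone sketch of that distinctness check, with a hypothetical `allDistinct` helper:

```go
package main

import "fmt"

// allDistinct reports whether all values are pairwise different by using
// map keys as a set: duplicate values collapse into a single key.
func allDistinct(vals ...uint32) bool {
	set := make(map[uint32]struct{}, len(vals))
	for _, v := range vals {
		set[v] = struct{}{}
	}
	return len(set) == len(vals)
}

func main() {
	// Distinct magic headers pass; a clash (e.g. init == response) fails.
	fmt.Println(allDistinct(123456, 67543, 123123, 32345))
	fmt.Println(allDistinct(123456, 123456, 123123, 32345))
}
```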


@ -1,29 +1,32 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"
"testing"
"time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
"go.uber.org/atomic"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
@ -50,7 +53,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@ -62,7 +65,8 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
args0 := append([]string(nil), cfg...)
args0 = append(args0, []string{
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
@ -70,12 +74,16 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
}...)
cfgs[0] = uapiCfg(args0...)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
args1 := append([]string(nil), cfg...)
args1 = append(args1, []string{
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
@ -83,7 +91,9 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
}...)
cfgs[1] = uapiCfg(args1...)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
@ -115,16 +125,21 @@ func (d SendDirection) String() string {
return "pong"
}
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
tb.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
// pong is the new ping
p0, p1 = p1, p0
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(5 * time.Second)
timer := time.NewTimer(6 * time.Second)
defer timer.Stop()
var err error
select {
@ -149,8 +164,14 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
}
// genTestPair creates a testPair.
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
cfg, endpointCfg := genConfigs(tb)
func genTestPair(
tb testing.TB,
realSocket bool,
extraCfg ...string,
) (pair testPair) {
var cfg, endpointCfg [2]string
cfg, endpointCfg = genConfigs(tb, extraCfg...)
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@ -203,6 +224,76 @@ func TestTwoDevicePing(t *testing.T) {
})
}
// Run this test without the race detector: setting the default msgTypes twice triggers a benign data race
func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"s3", "50",
"s4", "5",
"h1", "123456",
"h2", "67543",
"h3", "123123",
"h4", "32345",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
// Needs to be stopped with Ctrl-C
func TestAWGHandshakeDevicePing(t *testing.T) {
t.Skip("This test is intended to be run manually, not as part of the test suite.")
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
isRunning := atomic.NewBool(true)
go func() {
<-signalContext.Done()
fmt.Println("Waiting to finish")
isRunning.Store(false)
}()
goroutineLeakCheck(t)
pair := genTestPair(t, true,
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
"i2", "<b 0xf6ab3267fa><r 100>",
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
"j2", "<c><b 0xf6ab><t><wt 1000>",
"j3", "<t><b 0xf6ab><c><r 10>",
"itime", "60",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",
// "s1", "30",
// "s2", "40",
// "h1", "123456",
// "h2", "67543",
// "h4", "32345",
// "h3", "123123",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
for isRunning.Load() {
pair.Send(t, Ping, nil)
time.Sleep(2 * time.Second)
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
for isRunning.Load() {
pair.Send(t, Pong, nil)
time.Sleep(2 * time.Second)
}
})
}
func TestUpDown(t *testing.T) {
goroutineLeakCheck(t)
const itrials = 50
@ -423,29 +514,43 @@ type fakeBindSized struct {
size int
}
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
func (b *fakeBindSized) Open(
port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -11,7 +11,7 @@ import (
"sync/atomic"
"time"
"golang.zx2c4.com/wireguard/replay"
"github.com/amnezia-vpn/amneziawg-go/replay"
)
/* Due to limitations in Go and /x/crypto there is currently


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
peer.endpoint.Lock()
peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.endpoint.Unlock()
}
device.peers.RUnlock()
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
"github.com/amnezia-vpn/amneziawg-go/tai64n"
)
type handshakeState int
@ -53,10 +53,17 @@ const (
)
const (
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
DefaultMessageInitiationType uint32 = 1
DefaultMessageResponseType uint32 = 2
DefaultMessageCookieReplyType uint32 = 3
DefaultMessageTransportType uint32 = 4
)
var (
MessageInitiationType uint32 = DefaultMessageInitiationType
MessageResponseType uint32 = DefaultMessageResponseType
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
MessageTransportType uint32 = DefaultMessageTransportType
)
const (
@ -75,6 +82,11 @@ const (
MessageTransportOffsetContent = 16
)
var (
packetSizeToMsgType map[int]uint32
msgTypeToJunkSize map[uint32]int
)
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now)
@ -193,10 +205,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.awg.ASecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -250,9 +264,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.awg.ASecMux.RLock()
if msg.Type != MessageInitiationType {
device.awg.ASecMux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -367,7 +384,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.awg.ASecMux.RLock()
msg.Type = MessageResponseType
device.awg.ASecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -417,9 +436,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.ASecMux.RLock()
if msg.Type != MessageResponseType {
device.awg.ASecMux.RUnlock()
return nil
}
device.awg.ASecMux.RUnlock()
// lookup handshake by receiver

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -10,8 +10,8 @@ import (
"encoding/binary"
"testing"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun/tuntest"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -12,22 +12,26 @@ import (
"sync/atomic"
"time"
"golang.zx2c4.com/wireguard/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
)
type Peer struct {
isRunning atomic.Bool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes sent to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nanoseconds since epoch
disableRoaming bool
endpoint struct {
sync.Mutex
val conn.Endpoint
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
disableRoaming bool
}
timers struct {
retransmitHandshake *Timer
@ -45,9 +49,9 @@ type Peer struct {
}
queue struct {
staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@ -74,14 +78,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@ -97,7 +99,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers
peer.timersInit()
@ -108,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@ -116,14 +132,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return nil
}
peer.RLock()
defer peer.RUnlock()
if peer.endpoint == nil {
peer.endpoint.Lock()
endpoint := peer.endpoint.val
if endpoint == nil {
peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, peer.endpoint)
err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
@ -267,10 +288,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if peer.disableRoaming {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return
}
peer.Lock()
peer.endpoint = endpoint
peer.Unlock()
peer.endpoint.clearSrcOnTx = false
peer.endpoint.val = endpoint
}
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
}

View file

@ -1,20 +1,19 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync"
"sync/atomic"
)
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count atomic.Uint32
count uint32 // Get calls not yet Put back
max uint32
}
@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count.Load() >= p.max {
for p.count >= p.max {
p.cond.Wait()
}
p.count.Add(1)
p.count++
p.lock.Unlock()
}
return p.pool.Get()
@ -41,18 +40,20 @@ func (p *WaitPool) Put(x any) {
if p.max == 0 {
return
}
p.count.Add(^uint32(0))
p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.cond.Signal()
}
func (device *Device) PopulatePools() {
device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &s
})
device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &s
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
@ -65,28 +66,32 @@ func (device *Device) PopulatePools() {
})
}
func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
for i := range *s {
(*s)[i] = nil
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
*s = (*s)[:0]
device.pool.outboundElementsSlice.Put(s)
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
for i := range *s {
(*s)[i] = nil
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
*s = (*s)[:0]
device.pool.inboundElementsSlice.Put(s)
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
wg.Add(workers)
var max atomic.Uint32
updateMax := func() {
count := p.count.Load()
p.lock.Lock()
count := p.count
p.lock.Unlock()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}

View file

@ -1,11 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import "golang.zx2c4.com/wireguard/conn"
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */
@ -14,6 +14,6 @@ const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2200
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)

View file

@ -2,12 +2,12 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import "golang.zx2c4.com/wireguard/conn"
import "github.com/amnezia-vpn/amneziawg-go/conn"
const (
QueueStagedSize = conn.IdealBatchSize

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -13,10 +13,10 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@ -35,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
type QueueInboundElementsContainer struct {
sync.Mutex
elems []*QueueInboundElement
}
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@ -66,7 +70,10 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
func (device *Device) RoutineReceiveIncoming(
maxBatchSize int,
recv conn.ReceiveFunc,
) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@ -87,7 +94,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
@ -122,6 +129,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
}
deathSpiral = 0
device.awg.ASecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -129,9 +137,44 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
}
// check size of packet
packet := bufsArrs[i][:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var msgType uint32
if device.isAWG() {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// a transport packet's size can coincide with another header type's,
// so verify that we actually have the expected msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
if msgType != MessageTransportType {
// probably a junk packet
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
continue
}
// remove junk from bufsArrs by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy(bufsArrs[i][:size], packet[transportJunkSize:])
size -= transportJunkSize
// need to reinitialize packet as well
packet = packet[:size]
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
}
switch msgType {
@ -170,15 +213,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsSlice()
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
@ -217,18 +259,17 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
default:
}
}
for peer, elems := range elemsByPeer {
device.awg.ASecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
for _, elem := range *elems {
device.queue.decryption.c <- elem
}
peer.queue.inbound.c <- elemsContainer
device.queue.decryption.c <- elemsContainer
} else {
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
@ -241,26 +282,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
for elem := range device.queue.decryption.c {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
for elemsContainer := range device.queue.decryption.c {
for _, elem := range elemsContainer.elems {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
}
}
elem.Unlock()
elemsContainer.Unlock()
}
}
@ -275,6 +318,8 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.awg.ASecMux.RLock()
// handle cookie fields and ratelimiting
switch elem.msgType {
@ -302,9 +347,14 @@ func (device *Device) RoutineHandshake(id int) {
// consume reply
if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
device.log.Verbosef(
"Receiving cookie response from %s",
elem.endpoint.DstToString(),
)
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response")
device.log.Verbosef(
"Could not decrypt invalid cookie response",
)
}
}
@ -346,9 +396,7 @@ func (device *Device) RoutineHandshake(id int) {
switch elem.msgType {
case MessageInitiationType:
// unmarshal
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
@ -358,7 +406,6 @@ func (device *Device) RoutineHandshake(id int) {
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
@ -423,6 +470,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.awg.ASecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
@ -437,12 +485,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
for elems := range peer.queue.inbound.c {
if elems == nil {
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
return
}
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
validTailPacket := -1
dataPacketReceived := false
rxBytesLen := uint64(0)
for i, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
@ -452,21 +503,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
continue
}
peer.SetEndpointFromPacket(elem.endpoint)
validTailPacket = i
if peer.ReceivedWithKeypair(elem.keypair) {
peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
peer.timersDataReceived()
dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
@ -503,11 +552,28 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
}
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
device.log.Verbosef(
"Packet with invalid IP version from %v",
peer,
)
continue
}
bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
}
peer.rxBytes.Add(rxBytesLen)
if validTailPacket >= 0 {
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
}
if dataPacketReceived {
peer.timersDataReceived()
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
@ -515,11 +581,11 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@ -14,10 +14,11 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@ -45,7 +46,6 @@ import (
*/
type QueueOutboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@ -53,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@ -78,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
elems := peer.device.GetOutboundElementsSlice()
*elems = append(*elems, elem)
elemsContainer := peer.device.GetOutboundElementsContainer()
elemsContainer.elems = append(elemsContainer.elems, elem)
select {
case peer.queue.staged <- elems:
case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@ -119,17 +123,66 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
var sendBuffer [][]byte
// so that only the real packet is processed for cookie generation
var junkedHeader []byte
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.ASecMux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
if junks != nil {
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
}
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.ASecMux.RUnlock()
} else {
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
}
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffers([][]byte{packet})
sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendAndCountBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@ -151,11 +204,19 @@ func (peer *Peer) SendHandshakeResponse() error {
return err
}
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession()
if err != nil {
@ -168,28 +229,42 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{packet})
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
)
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
return err
}
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
junkedHeader = append(junkedHeader, writer.Bytes()...)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
return nil
}
@ -218,7 +293,7 @@ func (device *Device) RoutineReadFromTUN() {
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
@ -275,10 +350,10 @@ func (device *Device) RoutineReadFromTUN() {
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsSlice()
elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
@ -288,11 +363,11 @@ func (device *Device) RoutineReadFromTUN() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
for _, elem := range *elemsForPeer {
for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elemsForPeer)
device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
@ -316,7 +391,7 @@ func (device *Device) RoutineReadFromTUN() {
}
}
func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
case peer.queue.staged <- elems:
@ -325,11 +400,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
}
select {
case tooOld := <-peer.queue.staged:
for _, elem := range *tooOld {
for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(tooOld)
peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@ -348,54 +423,52 @@ top:
}
for {
var elemsOOO *[]*QueueOutboundElement
var elemsContainerOOO *QueueOutboundElementsContainer
select {
case elems := <-peer.queue.staged:
case elemsContainer := <-peer.queue.staged:
i := 0
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsOOO == nil {
elemsOOO = peer.device.GetOutboundElementsSlice()
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
*elemsOOO = append(*elemsOOO, elem)
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
(*elems)[i] = elem
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
elem.Lock()
}
*elems = (*elems)[:i]
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsOOO != nil {
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(*elems) == 0 {
peer.device.PutOutboundElementsSlice(elems)
if len(elemsContainer.elems) == 0 {
peer.device.PutOutboundElementsContainer(elemsContainer)
goto top
}
// add to parallel and sequential queue
if peer.isRunning.Load() {
peer.queue.outbound.c <- elems
for _, elem := range *elems {
peer.device.queue.encryption.c <- elem
}
peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elemsContainer
} else {
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsOOO != nil {
if elemsContainerOOO != nil {
goto top
}
default:
@@ -407,12 +480,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
case elems := <-peer.queue.staged:
for _, elem := range *elems {
case elemsContainer := <-peer.queue.staged:
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@@ -446,32 +519,34 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
for elem := range device.queue.encryption.c {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
elem.Unlock()
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
}
elemsContainer.Unlock()
}
}
@@ -485,9 +560,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
for elems := range peer.queue.outbound.c {
for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
if elems == nil {
if elemsContainer == nil {
return
}
if !peer.isRunning.Load() {
@@ -497,18 +572,27 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
// The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
continue
}
dataSent := false
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
if err != nil {
device.log.Errorf("%v - %v", device, err)
continue
}
elem.packet = append(junkedHeader, elem.packet...)
}
bufs = append(bufs, elem.packet)
}
@@ -516,15 +600,23 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendBuffers(bufs)
err := peer.SendAndCountBuffers(bufs)
if dataSent {
peer.timersDataSent()
}
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
device.PutOutboundElementsContainer(elemsContainer)
if err != nil {
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
}
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue


@@ -3,10 +3,10 @@
package device
import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -9,7 +9,7 @@
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
* So this code remains platform dependent.
*/
package device
@@ -20,8 +20,8 @@ import (
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
@@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
@@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.endpoint.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock()
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.endpoint.Unlock()
break
}
pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
pePtr.peer.Unlock()
pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil {
peer.RUnlock()
peer.endpoint.Lock()
if peer.endpoint.val == nil {
peer.endpoint.Unlock()
continue
}
nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.RUnlock()
peer.endpoint.Unlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock()
peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
peer.RUnlock()
peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
}


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,7 +8,7 @@ package device
import (
"fmt"
"golang.zx2c4.com/wireguard/tun"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
const DefaultMTU = 1420


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -18,7 +18,8 @@ import (
"sync"
"time"
"golang.zx2c4.com/wireguard/ipc"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
type IPCError struct {
@@ -97,35 +98,81 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.isAWG() {
if device.awg.ASecCfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
}
if device.awg.ASecCfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
}
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
}
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
}
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
}
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
}
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
}
if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
}
if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
}
if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
}
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
}
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
for _, field := range controlledJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
}
if device.awg.HandshakeHandler.ITimeout != 0 {
sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second)
}
}
for _, peer := range device.peers.keyMap {
// Serialize peer state.
// Do the work in an anonymous function so that we can use defer.
func() {
peer.RLock()
defer peer.RUnlock()
peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
sendf("protocol_version=1")
if peer.endpoint != nil {
sendf("endpoint=%s", peer.endpoint.DstToString())
}
nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}()
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}
}()
@@ -152,17 +199,26 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempAwg := awg.Protocol{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempAwg)
if err != nil {
return err
}
peer.handlePostConfig()
return nil
}
key, value, ok := strings.Cut(line, "=")
if !ok {
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
return ipcErrorf(
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
}
if key == "public_key" {
@@ -180,7 +236,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value)
err = device.handleDeviceLine(key, value, &tempAwg)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -188,6 +244,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempAwg)
if err != nil {
return err
}
peer.handlePostConfig()
if err := scanner.Err(); err != nil {
@@ -196,7 +256,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string) error {
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -237,11 +297,150 @@ func (device *Device) handleDeviceLine(key, value string) error {
case "replace_peers":
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
tempAwg.ASecCfg.IsSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.ASecCfg.IsSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.ASecCfg.IsSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
}
tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
}
tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
}
tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
}
tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}
generators, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.IsSet = true
case "j1", "j2", "j3":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}
generators, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.IsSet = true
case "itime":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty itime")
return nil
}
itime, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
}
device.log.Verbosef("UAPI: Updating itime")
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
tempAwg.HandshakeHandler.IsSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
@@ -262,7 +461,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
@@ -273,7 +472,10 @@
}
}
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
@@ -303,12 +505,19 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
return nil
}
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
switch key {
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set update only, invalid value: %v",
value,
)
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
@@ -345,16 +554,20 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.Lock()
defer peer.Unlock()
peer.endpoint = endpoint
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set persistent keepalive interval: %w",
err,
)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@@ -365,7 +578,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to replace allowedips, invalid value: %v",
value,
)
}
if peer.dummy {
return nil
@@ -373,7 +590,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
add := true
verb := "Adding"
if len(value) > 0 && value[0] == '-' {
add = false
verb = "Removing"
value = value[1:]
}
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
@@ -381,7 +605,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if peer.dummy {
return nil
}
device.allowedips.Insert(prefix, peer.Peer)
if add {
device.allowedips.Insert(prefix, peer.Peer)
} else {
device.allowedips.Remove(prefix, peer.Peer)
}
case "protocol_version":
if value != "1" {
@@ -433,7 +661,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
return
}
if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
err = ipcErrorf(
ipc.IpcErrorInvalid,
"trailing character in UAPI get: %q",
nextByte,
)
break
}
err = device.IpcGetOperation(buffered.Writer)


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package main

go.mod (23 lines changed)

@@ -1,16 +1,23 @@
module golang.zx2c4.com/wireguard
module github.com/amnezia-vpn/amneziawg-go
go 1.20
go 1.24.4
require (
golang.org/x/crypto v0.6.0
golang.org/x/net v0.7.0
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0
github.com/tevino/abool/v2 v2.1.0
go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.39.0
golang.org/x/net v0.41.0
golang.org/x/sys v0.33.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489
)
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/time v0.9.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum (50 lines changed)

@@ -1,14 +1,40 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=


@@ -20,8 +20,8 @@ import (
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func randomPipePath() string {


@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package ipc


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -9,8 +9,8 @@ import (
"net"
"os"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
type UAPIListener struct {


@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -26,7 +26,7 @@ const (
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/wireguard"
var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package ipc


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -8,8 +8,8 @@ package ipc
import (
"net"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// TODO: replace these with actual standard windows error numbers from the win package
@@ -62,7 +62,7 @@ func init() {
func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
if err != nil {
return nil, err
}

main.go (32 lines changed)

@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -14,11 +14,11 @@ import (
"runtime"
"strconv"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
const (
@ -46,20 +46,20 @@ func warning() {
return
}
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
}
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
return
}
@ -145,7 +145,7 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting wireguard-go version %s", Version)
logger.Verbosef("Starting amneziawg-go version %s", Version)
if err != nil {
logger.Errorf("Failed to create TUN device: %v", err)


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package main
@@ -12,11 +12,11 @@ import (
	"golang.org/x/sys/windows"
-	"golang.zx2c4.com/wireguard/conn"
-	"golang.zx2c4.com/wireguard/device"
-	"golang.zx2c4.com/wireguard/ipc"
+	"github.com/amnezia-vpn/amneziawg-go/conn"
+	"github.com/amnezia-vpn/amneziawg-go/device"
+	"github.com/amnezia-vpn/amneziawg-go/ipc"
-	"golang.zx2c4.com/wireguard/tun"
+	"github.com/amnezia-vpn/amneziawg-go/tun"
)
const (
@@ -30,13 +30,13 @@ func main() {
	}
	interfaceName := os.Args[1]
-	fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
+	fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
	logger := device.NewLogger(
		device.LogLevelVerbose,
		fmt.Sprintf("(%s) ", interfaceName),
	)
-	logger.Verbosef("Starting wireguard-go version %s", Version)
+	logger.Verbosef("Starting amneziawg-go version %s", Version)
	tun, err := tun.CreateTUN(interfaceName, 0)
	if err == nil {


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package ratelimiter


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package ratelimiter


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package replay


@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
// Package rwcancel implements cancelable read/write operations on
@@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool {
func (rw *RWCancel) ReadyWrite() bool {
	closeFd := int32(rw.closingReader.Fd())
-	pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
+	pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}}
	var err error
	for {
		_, err = unix.Poll(pollFds, -1)


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package tai64n


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package tai64n


@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
package tun


@@ -1,26 +1,86 @@
package tun
-import "encoding/binary"
+import (
+	"encoding/binary"
+	"math/bits"
+)
// TODO: Explore SIMD and/or other assembly optimizations.
func checksumNoFold(b []byte, initial uint64) uint64 {
-	ac := initial
-	i := 0
-	n := len(b)
-	for n >= 4 {
-		ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
-		n -= 4
-		i += 4
+	tmp := make([]byte, 8)
+	binary.NativeEndian.PutUint64(tmp, initial)
+	ac := binary.BigEndian.Uint64(tmp)
+	var carry uint64
+	for len(b) >= 128 {
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
+		ac += carry
+		b = b[128:]
	}
-	for n >= 2 {
-		ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
-		n -= 2
-		i += 2
+	if len(b) >= 64 {
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+		ac += carry
+		b = b[64:]
	}
-	if n == 1 {
-		ac += uint64(b[i]) << 8
+	if len(b) >= 32 {
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+		ac += carry
+		b = b[32:]
	}
-	return ac
+	if len(b) >= 16 {
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+		ac += carry
+		b = b[16:]
+	}
+	if len(b) >= 8 {
+		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+		ac += carry
+		b = b[8:]
+	}
+	if len(b) >= 4 {
+		ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
+		ac += carry
+		b = b[4:]
+	}
+	if len(b) >= 2 {
+		ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
+		ac += carry
+		b = b[2:]
+	}
+	if len(b) == 1 {
+		tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
+		ac, carry = bits.Add64(ac, uint64(tmp), 0)
+		ac += carry
+	}
+	binary.NativeEndian.PutUint64(tmp, ac)
+	return binary.BigEndian.Uint64(tmp)
}
func checksum(b []byte, initial uint64) uint16 {

tun/checksum_test.go (new file)

@ -0,0 +1,98 @@
package tun
import (
"encoding/binary"
"fmt"
"math/rand"
"testing"
"golang.org/x/sys/unix"
)
func checksumRef(b []byte, initial uint16) uint16 {
ac := uint64(initial)
for len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}
for (ac >> 16) > 0 {
ac = (ac >> 16) + (ac & 0xffff)
}
return uint16(ac)
}
func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
sum := checksumRef(srcAddr, 0)
sum = checksumRef(dstAddr, sum)
sum = checksumRef([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumRef(tmp, sum)
}
func TestChecksum(t *testing.T) {
for length := 0; length <= 9001; length++ {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
csum := checksum(buf, 0x1234)
csumRef := checksumRef(buf, 0x1234)
if csum != csumRef {
t.Error("Expected checksum", csumRef, "got", csum)
}
}
}
func TestPseudoHeaderChecksum(t *testing.T) {
for _, addrLen := range []int{4, 16} {
for length := 0; length <= 9001; length++ {
srcAddr := make([]byte, addrLen)
dstAddr := make([]byte, addrLen)
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(srcAddr)
rng.Read(dstAddr)
rng.Read(buf)
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csum := checksum(buf, phSum)
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csumRef := checksumRef(buf, phSumRef)
if csum != csumRef {
t.Error("Expected checksumRef", csumRef, "got", csum)
}
}
}
}
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,
128,
256,
512,
1024,
1500,
2048,
4096,
8192,
9000,
9001,
}
for _, length := range lengths {
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum(buf, 0)
}
})
}
}

Some files were not shown because too many files have changed in this diff.