Compare commits

..

183 commits

Author SHA1 Message Date
pokamest
27e661d68e
Merge pull request #70 from marko1777/junk-improvements
Junk improvements
2025-04-07 15:31:41 +01:00
Mark Puha
71be0eb3a6 faster and more secure junk creation 2025-03-18 08:34:23 +01:00
pokamest
e3f1273f8a
Merge pull request #64 from drkivi/master
Patch for golang crypto and net submodules
2025-02-18 11:50:35 +00:00
drkivi
c97b5b7615
Update go.sum
Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:58 +03:30
drkivi
668ddfd455
Update go.mod
Submodules Version Up

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:17 +03:30
drkivi
b8da08c106
Update Dockerfile
golang -> 1.23.6
AWGTOOLS_RELEASE -> 1.0.20241018

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:43:02 +03:30
Iurii Egorov
2e3f7d122c Update Go version in Dockerfile 2024-07-01 13:47:44 +03:00
Iurii Egorov
2e7780471a
Remove GetOffloadInfo() (#32)
* Remove GetOffloadInfo()
* Remove GetOffloadInfo() from bind_windows as well
* Allow lightweight tags to be used in the version
2024-05-24 16:18:23 +01:00
albexk
87d8c00f86 Up go to 1.22.3, up crypto to 0.21.0 2024-05-21 08:09:58 -07:00
albexk
c00bda9200 Fix output of the version command 2024-05-14 03:51:01 -07:00
albexk
d2b0fc9789 Add resetting of message types when closing the device 2024-05-14 03:51:01 -07:00
albexk
77d39ff3b9 Minor naming changes 2024-05-14 03:51:01 -07:00
albexk
e433d13df6 Add disabling UDP GSO when an error occurs due to inconsistent peer mtu 2024-05-14 03:51:01 -07:00
RomikB
3ddf952973 unsafe rebranding: change pipe name 2024-05-13 11:10:42 -07:00
albexk
3f0a3bcfa0 Fix wg reconnection problem after awg connection 2024-03-16 14:16:13 +00:00
AlexanderGalkov
4dddf62e57 Update Dockerfile
add wg and wg-quick symlinks

Signed-off-by: AlexanderGalkov <143902290+AlexanderGalkov@users.noreply.github.com>
2024-02-20 20:32:38 +07:00
tiaga
827ec6e14b
Merge pull request #21 from amnezia-vpn/fix-dockerfile
Fix Dockerfile
2024-02-13 21:47:55 +07:00
tiaga
92e28a0d14 Fix Dockerfile
Fix AmneziaWG tools installation.
2024-02-13 21:44:41 +07:00
tiaga
52fed4d362
Merge pull request #20 from amnezia-vpn/update_dockerfile
Update Dockerfile
2024-02-13 21:28:17 +07:00
tiaga
9c6b3ff332 Update Dockerfile
- rename `wg` and `wg-quick` to `awg` and `awg-quick` accordingly
- add iptables
- update AmneziaWG tools version
2024-02-13 21:27:34 +07:00
pokamest
7de7a9a754
Merge pull request #19 from amnezia-vpn/fix/go_sum
Fix go.sum
2024-02-12 05:31:57 -08:00
albexk
0c347529b8 Fix go.sum 2024-02-12 16:27:56 +03:00
albexk
6705978fc8 Add debug udp offload info 2024-02-12 04:40:23 -08:00
albexk
032e33f577 Fix Android UDP GRO check 2024-02-12 04:40:23 -08:00
albexk
59101fd202 Bump crypto, net, sys modules to the latest versions 2024-02-12 04:40:23 -08:00
tiaga
8bcfbac230
Merge pull request #17 from amnezia-vpn/fix-pipeline
Fix pipeline
2024-02-07 19:19:11 +07:00
tiaga
f0dfb5eacc Fix pipeline
Fix path to GitHub Actions workflow.
2024-02-07 18:53:55 +07:00
tiaga
9195025d8f
Merge pull request #16 from amnezia-vpn/pipeline
Add pipeline
2024-02-07 18:48:12 +07:00
tiaga
cbd414dfec Add pipeline
Build and push Docker image on a tag push.
2024-02-07 18:44:59 +07:00
tiaga
7155d20913
Merge pull request #14 from amnezia-vpn/docker
Update Dockerfile
2024-02-02 23:40:39 +07:00
tiaga
bfeb3954f6 Update Dockerfile
- update Alpine version
- improve `Dockerfile` to use pre-built AmneziaWG tools
2024-02-02 22:56:00 +07:00
Iurii Egorov
e3c9ec8012 Naming unify 2024-01-20 23:18:10 +03:00
pokamest
ce9d3866a3
Merge pull request #10 from amnezia-vpn/upstream-merge
Merge upstream changes
2024-01-14 13:37:00 -05:00
Iurii Egorov
e5f355e843 Fix incorrect configuration handling for zero-valued Jc 2024-01-14 18:22:02 +03:00
Iurii Egorov
c05b2ee2a3 Merge remote-tracking branch 'upstream/master' into upstream-merge 2024-01-09 21:34:30 +03:00
tiaga
180c9284f3
Merge pull request #7 from amnezia-vpn/docker
Add Dockerfile
2023-12-19 18:50:48 +07:00
tiaga
015e11875d Add Dockerfile
Build Docker image with the corresponding wg-tools version.
2023-12-19 18:46:04 +07:00
Martin Basovnik
12269c2761 device: fix possible deadlock in close method
There is a possible deadlock in `device.Close()` when you try to close
the device very soon after its start. The problem is that two different
methods acquire the same locks in different order:

1. device.Close()
 - device.ipcMutex.Lock()
 - device.state.Lock()

2. device.changeState(deviceState)
 - device.state.Lock()
 - device.ipcMutex.Lock()

Reproducer:

    func TestDevice_deadlock(t *testing.T) {
    	d := randDevice(t)
    	d.Close()
    }

Problem:

    $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex
    sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6
    golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c
    testing.tRunner(0xc00014c000, 0x131d7b0)
    --
    sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    sync.(*RWMutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45
    golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72
    golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1)

Signed-off-by: Martin Basovnik <martin.basovnik@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:38:47 +01:00
Jason A. Donenfeld
542e565baa device: do atomic 64-bit add outside of vector loop
Only bother updating the rxBytes counter once we've processed a whole
vector, since additions are atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:35:57 +01:00
Jordan Whited
7c20311b3d device: reduce redundant per-packet overhead in RX path
Peer.RoutineSequentialReceiver() deals with packet vectors and does not
need to perform timer and endpoint operations for every packet in a
given vector. Changing these per-packet operations to per-vector
improves throughput by as much as 10% in some environments.

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
4ffa9c2032 device: change Peer.endpoint locking to reduce contention
Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
d0bc03c707 tun: implement UDP GSO/GRO for Linux
Implement UDP GSO and GRO for the Linux tun.Device, which is made
possible by virtio extensions in the kernel's TUN driver starting in
v6.2.

secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is
used to demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly ~13us of round trip
latency between them. secnetperf was invoked with the following command
line options:
-stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0

The first result is from commit 2e0774f without UDP GSO/GRO on the TUN.

[conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us
SendTotalPackets=55859 SendSuspectedLostPackets=61
SendSpuriousLostPackets=59 SendCongestionCount=27
SendEcnCongestionCount=0 RecvTotalPackets=2779122
RecvReorderedPackets=0 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms).

The second result is with UDP GSO/GRO on the TUN.

[conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us
SendTotalPackets=165033 SendSuspectedLostPackets=64
SendSpuriousLostPackets=61 SendCongestionCount=53
SendEcnCongestionCount=0 RecvTotalPackets=11845268
RecvReorderedPackets=25267 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:27:22 +01:00
Jordan Whited
1cf89f5339 tun: fix Device.Read() buf length assumption on Windows
The length of a packet read from the underlying TUN device may exceed
the length of a supplied buffer when MTU exceeds device.MaxMessageSize.

Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:20:49 +01:00
pokamest
b43118018e
Merge pull request #4 from amnezia-vpn/upstream-merge
Upstream merge
2023-11-30 10:49:31 -08:00
Iurii Egorov
7af55a3e6f Merge remote-tracking branch 'upstream/master' 2023-11-17 22:42:13 +03:00
pokamest
c493b95f66
Update README.md
Signed-off-by: pokamest <pokamest@gmail.com>
2023-10-25 22:41:33 +01:00
Jason A. Donenfeld
2e0774f246 device: ratchet up max segment size on android
GRO requires big allocations to be efficient. This isn't great, as there
might be Android memory usage issues. So we should revisit this commit.
But at least it gets things working again.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-22 02:12:13 +02:00
Jason A. Donenfeld
b3df23dcd4 conn: set unused OOB to zero length
Otherwise in the event that we're using GSO without sticky sockets, we
pass garbage OOB buffers to sendmmsg, making a EINVAL, when GSO doesn't
set its header.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:32:07 +02:00
Jason A. Donenfeld
f502ec3fad conn: fix cmsg data padding calculation for gso
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:06:38 +02:00
Jason A. Donenfeld
5d37bd24e1 conn: separate gso and sticky control
Android wants GSO but not sticky.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 18:44:01 +02:00
Jason A. Donenfeld
24ea13351e conn: harmonize GOOS checks between "linux" and "android"
Otherwise GRO gets enabled on Android, but the conn doesn't use it,
resulting in bundled packets being discarded.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:14:13 +02:00
Jason A. Donenfeld
177caa7e44 conn: simplify supportsUDPOffload
This allows a kernel to support UDP_GRO while not supporting
UDP_SEGMENT.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:02:52 +02:00
Mazay B
b81ca925db peer.device.aSecMux.RLock added 2023-10-14 11:42:30 +01:00
James Tucker
42ec952ead go.mod,tun/netstack: bump gvisor
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
James Tucker
ec8f6f82c2 tun: fix crash when ForceMTU is called after close
Close closes the events channel, resulting in a panic from send on
closed channel.

Reported-By: Brad Fitzpatrick <brad@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
Jordan Whited
1ec454f253 device: move Queue{In,Out}boundElement Mutex to container type
Queue{In,Out}boundElement locking can contribute to significant
overhead via sync.Mutex.lockSlow() in some environments. These types
are passed throughout the device package as elements in a slice, so
move the per-element Mutex to a container around the slice.

Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
8a015f7c76 tun: reduce redundant checksumming in tcpGRO()
IPv4 header and pseudo header checksums were being computed on every
merge operation. Additionally, virtioNetHdr was being written at the
same time. This delays those operations until after all coalescing has
occurred.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
895d6c23cd tun: unwind summing loop in checksumNoFold()
$ benchstat old.txt new.txt
goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/tun
cpu: 12th Gen Intel(R) Core(TM) i5-12400
                 │   old.txt    │               new.txt               │
                 │    sec/op    │   sec/op     vs base                │
Checksum/64-12     10.670n ± 2%   4.769n ± 0%  -55.30% (p=0.000 n=10)
Checksum/128-12    19.665n ± 2%   8.032n ± 0%  -59.16% (p=0.000 n=10)
Checksum/256-12     37.68n ± 1%   16.06n ± 0%  -57.37% (p=0.000 n=10)
Checksum/512-12     76.61n ± 3%   32.13n ± 0%  -58.06% (p=0.000 n=10)
Checksum/1024-12   160.55n ± 4%   64.25n ± 0%  -59.98% (p=0.000 n=10)
Checksum/1500-12   231.05n ± 7%   94.12n ± 0%  -59.26% (p=0.000 n=10)
Checksum/2048-12    309.5n ± 3%   128.5n ± 0%  -58.48% (p=0.000 n=10)
Checksum/4096-12    603.8n ± 4%   257.2n ± 0%  -57.41% (p=0.000 n=10)
Checksum/8192-12   1185.0n ± 3%   515.5n ± 0%  -56.50% (p=0.000 n=10)
Checksum/9000-12   1328.5n ± 5%   564.8n ± 0%  -57.49% (p=0.000 n=10)
Checksum/9001-12   1340.5n ± 3%   564.8n ± 0%  -57.87% (p=0.000 n=10)
geomean             185.3n        77.99n       -57.92%

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
4201e08f1d device: distribute crypto work as slice of elements
After reducing UDP stack traversal overhead via GSO and GRO,
runtime.chanrecv() began to account for a high percentage (20% in one
environment) of perf samples during a throughput benchmark. The
individual packet channel ops with the crypto goroutines was the primary
contributor to this overhead.

Updating these channels to pass vectors, which the device package
already handles at its ends, reduced this overhead substantially, and
improved throughput.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is with UDP GSO and GRO, and with single element
channels.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

The second result is with channels updated to pass a slice of
elements.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   sender
[  5]   0.00-10.04  sec  13.2 GBytes  11.3 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
6a84778f2c conn, device: use UDP GSO and GRO on Linux
StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is
dependent on checksum offload support on the egress netdev. UDP GSO
will be disabled in the event sendmmsg() returns EIO, which is a strong
signal that the egress netdev does not support checksum offload.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is from commit 052af4a without UDP GSO or GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139   3.01 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139  sender
[  5]   0.00-10.04  sec  9.85 GBytes  8.42 Gbits/sec        receiver

The second result is with UDP GSO and GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
pokamest
b34974c476
Merge pull request #3 from amnezia-vpn/bugfix/uapi_adv_sec_onoff
Manage advanced security via uapi
2023-10-09 06:07:20 -07:00
Mazay B
f30419e0d1 Manage advanced sec via uapi 2023-10-09 13:22:49 +01:00
Mark Puha
8f1a6a10b2
Advanced security (#2)
* Advanced security header layer & config
2023-10-05 21:41:27 +01:00
Dimitri Papadopoulos Orfanos
469159ecf7 netstack: fix typo
Signed-off-by: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-07-04 15:56:30 +02:00
Brad Fitzpatrick
6e755e132a all: adjust build tags for wasip1/wasm
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-07-04 15:54:42 +02:00
springhack
1f25eac395 conn: windows: add missing return statement in DstToString AF_INET
Signed-off-by: SpringHack <springhack@live.cn>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 18:02:50 +02:00
James Tucker
25eb973e00 conn: store IP_PKTINFO cmsg in StdNetendpoint src
Replace the src storage inside StdNetEndpoint with a copy of the raw
control message buffer, to reduce allocation and perform less work on a
per-packet basis.

Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 17:48:32 +02:00
James Tucker
b7cd547315 device: wait for and lock ipc operations during close
If an IPC operation is in flight while close starts, it is possible for
both processes to deadlock. Prevent this by taking the IPC lock at the
start of close and for the duration.

Signed-off-by: James Tucker <jftucker@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 17:43:35 +02:00
Jordan Whited
052af4a807 tun: use correct IP header comparisons in tcpGRO() and tcpPacketsCanCoalesce()
tcpGRO() was using an incorrect IPv4 more fragments bit mask.

tcpPacketsCanCoalesce() was not distinguishing tcp6 from tcp4, and TTL
values were not compared. TTL values should be equal at the IP layer,
otherwise the packets should not coalesce. This tracks with the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-25 23:13:38 +01:00
Jordan Whited
aad7fca9c5 tun: disqualify tcp4 packets w/IP options from coalescing
IP options were not being compared prior to coalescing. They are not
commonly used. Disqualification due to nonzero options is in line with
the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-25 23:13:26 +01:00
Jason A. Donenfeld
6f895be10d conn: move booleans to bottom of StdNetBind struct
This results in a more compact structure.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 17:05:07 +01:00
Jason A. Donenfeld
6a07b2a355 conn: use ipv6 message pool for ipv6 receiving
Looks like a simple copy&paste error.

Fixes: 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 16:20:16 +01:00
Jordan Whited
334b605e72 conn: fix StdNetEndpoint data race by dynamically allocating endpoints
In 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux"), the
Linux-specific Bind implementation was collapsed into StdNetBind. This
introduced a race on StdNetEndpoint from getSrcFromControl() and
setSrcControl().

Remove the sync.Pool involved in the race, and simplify StdNetBind's
receive path to allocate StdNetEndpoint on the heap instead, with the
intent for it to be cleaned up by the GC, later. This essentially
reverts ef5c587 ("conn: remove the final alloc per packet receive"),
adding back that allocation, unfortunately.

This does slightly increase resident memory usage in higher throughput
scenarios. StdNetBind is the only Bind implementation that was using
this Endpoint recycling technique prior to this commit.

This is considered a stop-gap solution, and there are plans to replace
the allocation with a better mechanism.

Reported-by: lsc <lsc@lv6.tw>
Link: https://lore.kernel.org/wireguard/ac87f86f-6837-4e0e-ec34-1df35f52540e@lv6.tw/
Fixes: 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux")
Cc: Josh Bleecher Snyder <josharian@gmail.com>
Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 14:37:13 +01:00
Jason A. Donenfeld
3a9e75374f conn: disable sticky sockets on Android
We can't have the netlink listener socket, so it's not possible to
support it. Plus, android networking stack complexity makes it a bit
tricky anyway, so best to leave it disabled.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-23 18:39:00 +01:00
Jason A. Donenfeld
cc20c08c96 global: remove old style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-23 18:34:09 +01:00
Jordan Whited
1417a47c8f tun: replace ErrorBatch() with errors.Join()
Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-17 15:18:04 +01:00
Jordan Whited
7f511c3bb1 go.mod: bump to Go 1.20
Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-17 15:18:04 +01:00
Jordan Whited
07a1e55270 conn: fix getSrcFromControl() iteration
We only expect a single control message in the normal case, but this
would loop infinitely if there were more.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-16 17:45:41 +01:00
Jordan Whited
fff53afca7 conn: use CmsgSpace() for ancillary data buf sizing
CmsgLen() does not account for data alignment.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-16 17:45:41 +01:00
Jason A. Donenfeld
0ad14a89f5 global: buff -> buf
This always struck me as kind of weird and non-standard.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-13 17:55:53 +01:00
Jason A. Donenfeld
7d327ed35a conn: use right cmsghdr len types on 32-bit in sticky test
Cmsghdr uses uint32 and uint64 on 32-bit and 64-bit respectively for the
Len member, which makes assignments and comparisons slightly more
irksome than usual.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 16:19:18 +01:00
Jordan Whited
f41f474466 conn: make StdNetBind.BatchSize() return 1 for non-Linux
This commit updates StdNetBind.BatchSize() to return 1 instead of
IdealBatchSize for non-Linux platforms. Non-Linux platforms do not
yet benefit from values > 1, which only serves to increase memory
consumption.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:53:09 +01:00
Jordan Whited
5819c6af28 tun/netstack: enable TCP Selective Acknowledgements
Enable TCP SACK for the gVisor Stack used in tun/netstack. This can
improve throughput by an order of magnitude in the presence of packet
loss.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:39 +01:00
Jordan Whited
6901984f6a conn: ensure control message size is respected in StdNetBind
This commit re-slices received control messages in StdNetBind to the
value the OS reports on a successful read. Previously, the len of this
slice would always be srcControlSize, which could result in control
message values leaking through a sync.Pool round trip. This is
unlikely with the IP_PKTINFO socket option set successfully, but
should be guarded against.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:32 +01:00
Jordan Whited
2fcdaf9799 conn: fix StdNetBind fallback on Windows
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind.
StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving
datagrams, specifically via the {Read,Write}Batch methods.
These methods are unimplemented on Windows and will return runtime
errors as a result. Additionally, only Linux benefits from these
x/net types for reading and writing, so we update StdNetBind to fall
back to the standard library net package for all platforms other than
Linux.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:24 +01:00
Jason A. Donenfeld
dbd949307e conn: inch BatchSize toward being non-dynamic
There's not really a use at the moment for making this configurable, and
once bind_windows.go behaves like bind_std.go, we'll be able to use
constants everywhere. So begin that simplification now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:22 +01:00
Jordan Whited
f26efb65f2 conn: set SO_{SND,RCV}BUF to 7MB on the Bind UDP socket
The conn.Bind UDP sockets' send and receive buffers are now being sized
to 7MB, whereas they were previously inheriting the system defaults.
The system defaults are considerably small and can result in dropped
packets on high speed links. By increasing the size of these buffers we
are able to achieve higher throughput in the aforementioned case.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with 32-core Xeon Platinum CPUs @ 2.9Ghz. There is
roughly ~125us of round trip latency between them.

The first result is from commit 792b49c which uses the system defaults,
e.g. net.core.{r,w}mem_max = 212992. The TCP retransmits are correlated
with buffer full drops on both sides.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  4.74 GBytes  4.08 Gbits/sec  2742   285 KBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  4.74 GBytes  4.08 Gbits/sec  2742   sender
[  5]   0.00-10.04  sec  4.74 GBytes  4.06 Gbits/sec         receiver

The second result is after increasing SO_{SND,RCV}BUF to 7MB, i.e.
applying this commit.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  6.14 GBytes  5.27 Gbits/sec    0   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  6.14 GBytes  5.27 Gbits/sec    0    sender
[  5]   0.00-10.04  sec  6.14 GBytes  5.25 Gbits/sec         receiver

The specific value of 7MB is chosen as it is the max supported by a
default configuration of macOS. A value greater than 7MB may further
benefit throughput for environments with higher network latency and
lower CPU clocks, but will also increase latency under load
(bufferbloat). Some platforms will silently clamp the value to other
maximums. On Linux, we use SO_{SND,RCV}BUFFORCE in case 7MB is beyond
net.core.{r,w}mem_max.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:20 +01:00
Jason A. Donenfeld
f67c862a2a go.mod: bump deps
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:18 +01:00
Jordan Whited
9e2f386022 conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:17 +01:00
Jordan Whited
3bb8fec7e4 conn, device, tun: implement vectorized I/O plumbing
Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:13 +01:00
Jason A. Donenfeld
21636207a6 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-23 19:12:33 +01:00
Jason A. Donenfeld
c7b76d3d9e device: uniformly check ECDH output for zeros
For some reason, this was omitted for response messages.

Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-16 16:33:14 +01:00
Jordan Whited
1e2c3e5a3c tun: guard Device.Events() against chan writes
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-09 12:35:58 -03:00
Jason A. Donenfeld
ebbd4a4330 global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:39:29 -03:00
Soren L. Hansen
0ae4b3177c tun/netstack: make http examples communicate with each other
This seems like a much better demonstration as it removes the need for
external components.

Signed-off-by: Søren L. Hansen <sorenisanerd@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:38:19 -03:00
Colin Adler
077ce8ecab tun/netstack: bump gvisor
Bump gVisor to a recent known-good version.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:10:52 -03:00
Jason A. Donenfeld
bb719d3a6e global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:21:32 +02:00
Colin Adler
fde0a9525a tun/netstack: ensure (*netTun).incomingPacket chan is closed
Without this, `device.Close()` will deadlock.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:17:29 +02:00
Brad Fitzpatrick
b51010ba13 all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-04 12:57:30 +02:00
Jason A. Donenfeld
d1d08426b2 tun/netstack: remove separate module
Now that the gvisor deps aren't insane, we can just do this in the main
module.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:14:05 -04:00
Shengjing Zhu
3381e21b18 tun/netstack: bump to latest gvisor
To build with go1.19, gvisor needs
99325baf ("Bump gVisor build tags to go1.19").

However gvisor.dev/gvisor/pkg/tcpip/buffer is no longer available,
so refactor to use gvisor.dev/gvisor/pkg/tcpip/link/channel directly.

Signed-off-by: Shengjing Zhu <i@zhsj.me>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:01:05 -04:00
Brad Fitzpatrick
c31a7b1ab4 conn, device, tun: set CLOEXEC on fds
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-07-04 01:42:12 +02:00
Tobias Klauser
6a08d81f6b tun: use ByteSliceToString from golang.org/x/sys/unix
Use unix.ByteSliceToString in (*NativeTun).nameSlice to convert the
TUNGETIFF ioctl result []byte to a string.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-06-01 15:00:07 +02:00
Josh Bleecher Snyder
ef5c587f78 conn: remove the final alloc per packet receive
This does bind_std only; other platforms remain.

The remaining alloc per iteration in the Throughput benchmark
comes from the tuntest package, and should not appear in regular use.

name           old time/op      new time/op      delta
Latency-10         25.2µs ± 1%      25.0µs ± 0%   -0.58%  (p=0.006 n=10+10)
Throughput-10      2.44µs ± 3%      2.41µs ± 2%     ~     (p=0.140 n=10+8)

name           old alloc/op     new alloc/op     delta
Latency-10           854B ± 5%        741B ± 3%  -13.22%  (p=0.000 n=10+10)
Throughput-10        265B ±34%        267B ±39%     ~     (p=0.670 n=10+10)

name           old allocs/op    new allocs/op    delta
Latency-10           16.0 ± 0%        14.0 ± 0%  -12.50%  (p=0.000 n=10+10)
Throughput-10        2.00 ± 0%        1.00 ± 0%  -50.00%  (p=0.000 n=10+10)

name           old packet-loss  new packet-loss  delta
Throughput-10        0.01 ±82%       0.01 ±282%     ~     (p=0.321 n=9+8)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-04-07 03:31:10 +02:00
Jason A. Donenfeld
193cf8d6a5 conn: use netip for std bind
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-17 22:23:02 -06:00
Jason A. Donenfeld
ee1c8e0e87 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 21:32:14 -06:00
Jason A. Donenfeld
95b48cdb39 tun/netstack: bump mod
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 18:01:34 -06:00
Jason A. Donenfeld
5aff28b14c mod: bump packages and remove compat netip
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 17:51:47 -06:00
Josh Bleecher Snyder
46826fc4e5 all: use any in place of interface{}
Enabled by using Go 1.18. A bit less verbose.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:40:24 -07:00
Josh Bleecher Snyder
42c9af45e1 all: update to Go 1.18
Bump go.mod and README.

Switch to upstream net/netip.

Use strings.Cut.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:09:48 -07:00
Alexander Neumann
ae6bc4dd64 tun/netstack: check error returned by SetDeadline()
Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: don't wrap deadline error.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Alexander Neumann
2cec4d1a62 tun/netstack: update to latest wireguard-go
This commit fixes all callsites of netip.AddrFromSlice(), which has
changed its signature and now returns two values.

Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: remove error handling from AddrFromSlice.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Jason A. Donenfeld
3b95c81cc1 tun/netstack: simplify read timeout on ping socket
I'm not 100% sure this is correct, but it certainly is a lot simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:30:31 +01:00
Thomas H. Ptacek
b9669b734e tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:09:37 +01:00
Jason A. Donenfeld
e0b8f11489 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-17 17:37:42 +01:00
Jason A. Donenfeld
114a3db918 ipc: bsd: try again if kqueue returns EINTR
Reported-by: J. Michael McAtee <mmcatee@jumptrading.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-14 16:10:43 +01:00
Jason A. Donenfeld
9c9e7e2724 global: apply gofumpt
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-12-09 23:15:55 +01:00
Jason A. Donenfeld
2dd424e2d8 device: handle peer post config on blank line
We missed a function exit point. This was exacerbated by e3134bf
("device: defer state machine transitions until configuration is
complete"), but the bug existed prior. Minus provided the following
useful reproducer script:

    #!/usr/bin/env bash

    set -eux

    make wireguard-go || exit 125

    ip netns del test-ns || true
    ip netns add test-ns
    ip link add test-kernel type wireguard
    wg set test-kernel listen-port 0 private-key <(echo "QMCfZcp1KU27kEkpcMCgASEjDnDZDYsfMLHPed7+538=") peer "eDPZJMdfnb8ZcA/VSUnLZvLB2k8HVH12ufCGa7Z7rHI=" allowed-ips 10.51.234.10/32
    ip link set test-kernel netns test-ns up
    ip -n test-ns addr add 10.51.234.1/24 dev test-kernel
    port=$(ip netns exec test-ns wg show test-kernel listen-port)

    ip link del test-go || true
    ./wireguard-go test-go
    wg set test-go private-key <(echo "WBM7qimR3vFk1QtWNfH+F4ggy/hmO+5hfIHKxxI4nF4=") peer "+nj9Dkqpl4phsHo2dQliGm5aEiWJJgBtYKbh7XjeNjg=" allowed-ips 0.0.0.0/0 endpoint 127.0.0.1:$port
    ip addr add 10.51.234.10/24 dev test-go
    ip link set test-go up

    ping -c2 -W1 10.51.234.1

Reported-by: minus <minus@mnus.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-29 12:31:54 -05:00
Josh Bleecher Snyder
387f7c461a device: reduce peer lock critical section in UAPI
The deferred RUnlock calls weren't executing until all peers
had been processed. Add an anonymous function so that each
peer may be unlocked as soon as it is completed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Josh Bleecher Snyder
4d87c9e824 device: remove code using unsafe
There is no performance impact.

name                             old time/op  new time/op  delta
TrieIPv4Peers100Addresses1000-8  78.6ns ± 1%  79.4ns ± 3%    ~     (p=0.604 n=10+9)
TrieIPv4Peers10Addresses10-8     29.1ns ± 2%  28.8ns ± 1%  -1.12%  (p=0.014 n=10+9)
TrieIPv6Peers100Addresses1000-8  78.9ns ± 1%  78.6ns ± 1%    ~     (p=0.492 n=10+10)
TrieIPv6Peers10Addresses10-8     29.3ns ± 2%  28.6ns ± 2%  -2.16%  (p=0.000 n=10+10)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
ef8d6804d7 global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
de7c702ace device: only propagate roaming value before peer is referenced elsewhere
A peer.endpoint never becomes nil after being not-nil, so creation is
the only time we actually need to set this. This prevents a race from
when the variable is actually used elsewhere, and allows us to avoid an
expensive atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:16:04 +01:00
Jason A. Donenfeld
fc4f975a4d device: align 64-bit atomic member in Device
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
9d699ba730 device: start peers before running handshake test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
425f7c726b Makefile: don't use test -v because it hides failures in scrollback
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
David Anderson
3cae233d69 device: fix nil pointer dereference in uapi read
Signed-off-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 20:43:26 +01:00
Jason A. Donenfeld
111e0566dc device: make new peers inherit broken mobile semantics
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
e3134bf665 device: defer state machine transitions until configuration is complete
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
63abb5537b device: do not consume handshake messages if not running
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
851efb1bb6 tun: move wintun to its own repo
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:55 +01:00
Jason A. Donenfeld
c07dd60cdb namedpipe: rename from winpipe to keep in sync with CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:52 +01:00
Jason A. Donenfeld
eb6302c7eb device: timers: use pre-seeded per-thread unlocked fastrandn for jitter
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:47:50 +02:00
Jason A. Donenfeld
60683d7361 device: timers: seed unsafe rng before use for jitter
Forgetting to seed the unsafe rng, the jitter before followed a fixed
pattern, which didn't help when a fleet of computers all boot at once.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:34:21 +02:00
Jason A. Donenfeld
e42c6c4bc2 wintun: align 64-bit argument on ARM32
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-26 14:53:40 +02:00
Jason A. Donenfeld
828a885a71 README: raise minimum Go to 1.17
Suggested-by: Adam Bliss <abliss@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-25 17:53:11 +02:00
Mikael Magnusson
f1f626090e tun/netstack: update gvisor
Update gvisor to v0.0.0-20211020211948-f76a604701b6, which requires some
changes to tun.go:

WriteRawPacket: Add function with not implemented error.

CreateNetTUN: Replace stack.AddAddress with stack.AddProtocolAddress, and
fix IPv6 address in error message.

Signed-off-by: Mikael Magnusson <mikma@users.sourceforge.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-22 13:22:29 -06:00
Brad Fitzpatrick
82e0b734e5 ipc, rwcancel: compile on js/wasm
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-10-20 14:50:05 -06:00
Jason A. Donenfeld
fdf57a1fa4 wintun: allow retrieving DLL version
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-20 12:13:44 -06:00
Jason A. Donenfeld
f87e87af0d version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 23:27:13 -06:00
Jason A. Donenfeld
ba9e364dab wintun: remove memmod option for dll loading
Only wireguard-windows used this, and it's moving to wgnt exclusively.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 22:49:38 -06:00
Jason A. Donenfeld
dfd688b6aa global: remove old-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 12:02:10 -06:00
Jason A. Donenfeld
c01d52b66a global: add newer-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 11:46:53 -06:00
Jason A. Donenfeld
82d2aa87aa wintun: use new swdevice-based API for upcoming Wintun 0.14
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 00:26:46 -06:00
Jason A. Donenfeld
982d5d2e84 conn,wintun: use unsafe.Slice instead of unsafeSlice
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:57:53 -06:00
Jason A. Donenfeld
642a56e165 memmod: import from wireguard-windows
We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:53:36 -06:00
Jason A. Donenfeld
bb745b2ea3 rwcancel: use unix.Poll again but bump x/sys so it uses ppoll under the hood
This reverts commit fcc601dbf0 but then
bumps go.mod.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-27 14:19:15 -06:00
Jason A. Donenfeld
fcc601dbf0 rwcancel: use ppoll on Linux for Android
This is a temporary measure while we wait for
https://go-review.googlesource.com/c/sys/+/352310 to land.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-26 17:16:38 -06:00
Tobias Klauser
217ac1016b tun: make operateonfd.go build tags more specific
(*NativeTun).operateOnFd is only used on darwin and freebsd. Adjust the
build tags accordingly.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:54:01 -06:00
Tobias Klauser
eae5e0f3a3 tun: avoid leaking sock fd in CreateTUN error cases
At these points, the socket file descriptor is not yet wrapped in an
*os.File, so it needs to be closed explicitly on error.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:53:49 -06:00
Jason A. Donenfeld
2ef39d4754 global: add new go 1.17 build comments
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-05 16:00:43 +02:00
Jason A. Donenfeld
3957e9b9dd memmod: register exception handler tables
Otherwise recent WDK binaries fail on ARM64, where an exception handler
is used for trapping an illegal instruction when ARMv8.1 atomics are
being tested for functionality.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-08-05 14:56:48 +02:00
Jason A. Donenfeld
bad6caeb82 memmod: fix protected delayed load the right way
The reason this was failing before is that dloadsup.h's
DloadObtainSection was doing a linear search of sections to find which
header corresponds with the IMAGE_DELAYLOAD_DESCRIPTOR section, and we
were stupidly overwriting the VirtualSize field, so the linear search
wound up matching the .text section, which then it found to not be
marked writable and failed with FAST_FAIL_DLOAD_PROTECTION_FAILURE.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:27:40 +02:00
Jason A. Donenfeld
c89f5ca665 memmod: disable protected delayed load for now
Probably a bad idea, but we don't currently support it, and those huge
windows.NewCallback trampolines make juicer targets anyway.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:13:03 +02:00
Jason A. Donenfeld
15b24b6179 ipc: allow admins but require high integrity label
Might be more reasonable.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-24 17:01:02 +02:00
Jason A. Donenfeld
f9b48a961c device: zero out allowedip node pointers when removing
This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-04 16:33:28 +02:00
Jason A. Donenfeld
d0cf96114f device: limit allowedip fuzzer to 4 times through
Trying this for every peer winds up being very slow and precludes it
from acceptable runtime in the CI, so reduce this to 4.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 18:22:50 +02:00
Jason A. Donenfeld
841756e328 device: simplify allowedips lookup signature
The inliner should handle this for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
c382222eab device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
b41f4cc768 device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 15:08:42 +02:00
Jason A. Donenfeld
4a57024b94 device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 13:51:03 +02:00
Josh Bleecher Snyder
64cb82f2b3 go.mod: bump golang.org/x/sys again
To pick up https://go-review.googlesource.com/c/sys/+/307129.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-25 16:34:54 +02:00
Jason A. Donenfeld
c27ff9b9f6 device: allow reducing queue constants on iOS
Heavier network extensions might require the wireguard-go component to
use less ram, so let users of this reduce these as needed.

At some point we'll put this behind a configuration method of sorts, but
for now, just expose the consts as vars.

Requested-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-22 01:00:51 +02:00
Jason A. Donenfeld
99e8b4ba60 tun: linux: account for interface removal from outside
On Linux we can run `ip link del wg0`, in which case the fd becomes
stale, and we should exit. Since this is an intentional action, don't
treat it as an error.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:26:01 +02:00
Jason A. Donenfeld
bd83f0ac99 conn: linux: protect read fds
The -1 protection was removed and the wrong error was returned, causing
us to read from a bogus fd. As well, remove the useless closures that
aren't doing anything, since this is all synchronized anyway.

Fixes: 10533c3 ("all: make conn.Bind.Open return a slice of receive functions")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:09:55 +02:00
Jason A. Donenfeld
50d779833e rwcancel: use ordinary os.ErrClosed instead of custom error
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:56:36 +02:00
Jason A. Donenfeld
a9b377e9e1 rwcancel: use poll instead of select
Suggested-by: Lennart Poettering <lennart@poettering.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:42:34 +02:00
Jason A. Donenfeld
9087e444e6 device: optimize Peer.String even more
This reduces the allocation, branches, and amount of base64 encoding.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-18 17:43:53 +02:00
Josh Bleecher Snyder
25ad08a591 device: optimize Peer.String
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-05-14 00:37:30 +02:00
Jason A. Donenfeld
5846b62283 conn: windows: set count=0 on retry
When retrying, if count is not 0, we forget to dequeue another request,
and so the ring fills up and errors out.

Reported-by: Sascha Dierberg <dierberg@dresearch-fe.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-11 16:47:17 +02:00
Jason A. Donenfeld
9844c74f67 main: replace crlf on windows in fmt test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 22:23:32 +02:00
Jason A. Donenfeld
4e9e5dad09 main: check that code is formatted in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 17:48:26 +02:00
Jason A. Donenfeld
39e0b6dade tun: format
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:27 +02:00
Jason A. Donenfeld
7121927b87 device: add ID to repeated routines
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:21 +02:00
Jason A. Donenfeld
326aec10af device: remove unusual ... in messages
We don't use ... in any other present progressive messages except these.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:17:41 +02:00
Jason A. Donenfeld
efb8818550 device: avoid verbose log line during ordinary shutdown sequence
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:39:06 +02:00
Jason A. Donenfeld
69b39db0b4 tun: windows: set event before waiting
In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:26:24 +02:00
Jason A. Donenfeld
db733ccd65 tun: windows: rearrange struct to avoid alignment trap on 32bit
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:19:00 +02:00
Jason A. Donenfeld
a7aec4449f tun: windows: check alignment in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:15:50 +02:00
Josh Bleecher Snyder
60a26371f4 device: log all errors received by RoutineReceiveIncoming
When debugging, it's useful to know why a receive func exited.

We were already logging that, but only in the "death spiral" case.
Move the logging up, to capture it always.
Reduce the verbosity, since it is not an error case any more.
Put the receive func name in the log line.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-06 11:22:13 +02:00
Jason A. Donenfeld
a544776d70 tun/netstack: update go mod and remove GSO argument
Reported-by: John Xiong <xiaoyang1258@yeah.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-06 11:07:26 +02:00
Jason A. Donenfeld
69a42a4eef tun: windows: send MTU update when forced MTU changes
Otherwise the padding doesn't get updated.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-05 11:42:45 +02:00
Jason A. Donenfeld
097af6e135 tun: windows: protect reads from closing
The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:22:45 -04:00
Jason A. Donenfeld
8246d251ea conn: windows: do not error out when receiving UDP jumbogram
If we receive a large UDP packet, don't return an error to receive.go,
which then terminates the receive loop. Instead, simply retry.

Considering Winsock's general finickiness, we might consider other
places where an attacker on the wire can generate error conditions like
this.

Reported-by: Sascha Dierberg <sascha.dierberg@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:07:03 -04:00
130 changed files with 7184 additions and 4997 deletions

41
.github/workflows/build-if-tag.yml vendored Normal file
View file

@ -0,0 +1,41 @@
name: build-if-tag
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
env:
APP: amneziawg-go
jobs:
build:
runs-on: ubuntu-latest
name: build
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup metadata
uses: docker/metadata-action@v5
id: metadata
with:
images: amneziavpn/${{ env.APP }}
tags: type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.metadata.outputs.tags }}

2
.gitignore vendored
View file

@ -1 +1 @@
wireguard-go amneziawg-go

17
Dockerfile Normal file
View file

@ -0,0 +1,17 @@
FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
go mod verify && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
unzip -j alpine-3.19-amneziawg-tools.zip && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

View file

@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --tags --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\nconst Version = "%s"\n' "$$tag")" && \ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \ echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true git update-index --assume-unchanged version.go || true
@$(MAKE) wireguard-go @$(MAKE) amneziawg-go
wireguard-go: $(wildcard *.go) $(wildcard */*.go) amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
go build -v -o "$@" go build -v -o "$@"
install: wireguard-go install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
test: test:
go test -v ./... go test ./...
clean: clean:
rm -f wireguard-go rm -f amneziawg-go
.PHONY: all clean test install generate-version-and-build .PHONY: all clean test install generate-version-and-build

View file

@ -1,24 +1,27 @@
# Go Implementation of [WireGuard](https://www.wireguard.com/) # Go Implementation of AmneziaWG
This is an implementation of WireGuard in Go. AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
The precursor, WireGuard, was known for its efficiency but had issues with detection due to its distinctive packet signatures.
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
## Usage ## Usage
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run: Simply run:
``` ```
$ wireguard-go wg0 $ amneziawg-go wg0
``` ```
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down. This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
To run wireguard-go without forking to the background, pass `-f` or `--foreground`: To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
``` ```
$ wireguard-go -f wg0 $ amneziawg-go -f wg0
``` ```
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@ -26,52 +29,24 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions. This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
### Windows
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module. This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
### FreeBSD
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
### OpenBSD
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
## Building
This requires an installation of [go](https://golang.org) ≥ 1.16. This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd amneziawg-go
$ make
```
## License
Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,579 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"net"
"strconv"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// ipv4Source caches the local IPv4 source address and interface index
// learned from an IP_PKTINFO control message, so replies can be pinned
// ("sticky") to the same local address/interface.
type ipv4Source struct {
	Src     [4]byte
	Ifindex int32
}
// ipv6Source caches the local IPv6 source address learned from an
// IPV6_PKTINFO control message. No interface index is stored here:
type ipv6Source struct {
	src [16]byte
	// ifindex belongs in dst.ZoneId
}
// LinuxSocketEndpoint is a peer address as used by LinuxSocketBind.
// dst holds a raw unix.SockaddrInet4 or unix.SockaddrInet6 (selected by
// isV6) and src holds the matching ipv4Source/ipv6Source sticky-source
// cache; both buffers are sized for the larger (v6) layout and are
// reinterpreted through the unsafe accessors below. mu serializes sends
// that read dst/src.
type LinuxSocketEndpoint struct {
	mu   sync.Mutex
	dst  [unsafe.Sizeof(unix.SockaddrInet6{})]byte
	src  [unsafe.Sizeof(ipv6Source{})]byte
	isV6 bool
}
// Src4 returns the cached IPv4 sticky source; meaningful only when !IsV6().
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }

// Dst4 returns the destination as a raw IPv4 sockaddr; meaningful only when !IsV6().
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }

// IsV6 reports whether the endpoint holds an IPv6 (rather than IPv4) address.
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
// src4 reinterprets the src buffer as the IPv4 sticky-source cache.
// Safe because src is sized for the larger ipv6Source layout.
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
	return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
}

// src6 reinterprets the src buffer as the IPv6 sticky-source cache.
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
	return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
}

// dst4 reinterprets the dst buffer as an IPv4 sockaddr.
// Safe because dst is sized for the larger unix.SockaddrInet6 layout.
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
	return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}

// dst6 reinterprets the dst buffer as an IPv6 sockaddr.
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
	return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct {
	// mu guards sock4 and sock6 and the associated fds.
	// As long as someone holds mu (read or write), the associated fds are valid.
	mu    sync.RWMutex
	sock4 int // AF_INET UDP fd, or -1 when unopened/closed
	sock6 int // AF_INET6 UDP fd, or -1 when unopened/closed
}
// NewLinuxSocketBind returns a Bind backed by raw AF_INET/AF_INET6 sockets,
// with both fds initialized to the closed sentinel -1.
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }

// NewDefaultBind is the platform-default Bind constructor on Linux.
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
// Compile-time interface conformance checks.
var _ Endpoint = (*LinuxSocketEndpoint)(nil)
var _ Bind = (*LinuxSocketBind)(nil)
// ParseEndpoint parses "host:port" (with optional IPv6 zone) into a
// LinuxSocketEndpoint, filling the raw sockaddr for the address family.
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
	addr, err := parseEndpoint(s)
	if err != nil {
		return nil, err
	}

	var end LinuxSocketEndpoint
	if v4 := addr.IP.To4(); v4 != nil {
		end.isV6 = false
		dst := end.dst4()
		dst.Port = addr.Port
		copy(dst.Addr[:], v4)
		end.ClearSrc()
		return &end, nil
	}

	if v6 := addr.IP.To16(); v6 != nil {
		zone, err := zoneToUint32(addr.Zone)
		if err != nil {
			return nil, err
		}
		end.isV6 = true
		dst := end.dst6()
		dst.Port = addr.Port
		dst.ZoneId = zone
		copy(dst.Addr[:], v6[:])
		end.ClearSrc()
		return &end, nil
	}

	return nil, errors.New("invalid IP address")
}
// Open binds UDP sockets for both address families on the requested port
// and returns one ReceiveFunc per successfully bound family plus the port
// actually used. A port of 0 requests an ephemeral port; if the v4 bind
// then loses the race for the port the v6 bind obtained (EADDRINUSE), the
// whole attempt is retried (up to 100 times) so both families share one
// port number. EAFNOSUPPORT from either family is tolerated as long as at
// least one socket binds.
//
// Fix: both error paths previously called unix.Close(sock6)
// unconditionally, even when the v6 bind had failed with EAFNOSUPPORT and
// sock6 held the -1 sentinel; the close is now guarded.
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	var err error
	var newPort uint16
	var tries int

	if bind.sock4 != -1 || bind.sock6 != -1 {
		return nil, 0, ErrBindAlreadyOpen
	}

	originalPort := port

again:
	port = originalPort
	var sock4, sock6 int
	// Attempt ipv6 bind, update port if successful.
	sock6, newPort, err = create6(port)
	if err != nil {
		if !errors.Is(err, syscall.EAFNOSUPPORT) {
			return nil, 0, err
		}
	} else {
		port = newPort
	}

	// Attempt ipv4 bind, update port if successful.
	sock4, newPort, err = create4(port)
	if err != nil {
		if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
			// sock6 is -1 when the v6 bind failed with EAFNOSUPPORT above.
			if sock6 != -1 {
				unix.Close(sock6)
			}
			tries++
			goto again
		}
		if !errors.Is(err, syscall.EAFNOSUPPORT) {
			if sock6 != -1 {
				unix.Close(sock6)
			}
			return nil, 0, err
		}
	} else {
		port = newPort
	}

	var fns []ReceiveFunc
	if sock4 != -1 {
		fns = append(fns, bind.makeReceiveIPv4(sock4))
		bind.sock4 = sock4
	}
	if sock6 != -1 {
		fns = append(fns, bind.makeReceiveIPv6(sock6))
		bind.sock6 = sock6
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}
	return fns, port, nil
}
// SetMark applies the SO_MARK routing mark to every open socket so tunnel
// traffic can be matched by policy-routing rules.
func (bind *LinuxSocketBind) SetMark(value uint32) error {
	bind.mu.RLock()
	defer bind.mu.RUnlock()

	// Same order as before: v6 first, then v4; stop at the first failure.
	for _, fd := range []int{bind.sock6, bind.sock4} {
		if fd == -1 {
			continue
		}
		if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(value)); err != nil {
			return err
		}
	}
	return nil
}
// Close shuts down and closes both sockets. Shutdown happens under a read
// lock so that receivers concurrently blocked in recvmsg (which also hold
// the read lock) wake up; only then is the write lock taken to close the
// fds and reset them to -1, guaranteeing no reader still uses them.
// The v6 error takes precedence over the v4 error.
func (bind *LinuxSocketBind) Close() error {
	// Take a readlock to shut down the sockets...
	bind.mu.RLock()
	if bind.sock6 != -1 {
		unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
	}
	if bind.sock4 != -1 {
		unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
	}
	bind.mu.RUnlock()
	// ...and a write lock to close the fd.
	// This ensures that no one else is using the fd.
	bind.mu.Lock()
	defer bind.mu.Unlock()
	var err1, err2 error
	if bind.sock6 != -1 {
		err1 = unix.Close(bind.sock6)
		bind.sock6 = -1
	}
	if bind.sock4 != -1 {
		err2 = unix.Close(bind.sock4)
		bind.sock4 = -1
	}
	if err1 != nil {
		return err1
	}
	return err2
}
// makeReceiveIPv6 wraps an AF_INET6 fd in a ReceiveFunc that allocates a
// fresh endpoint per received datagram.
func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
	return func(buff []byte) (int, Endpoint, error) {
		end := new(LinuxSocketEndpoint)
		n, err := receive6(sock, buff, end)
		return n, end, err
	}
}
// makeReceiveIPv4 wraps an AF_INET fd in a ReceiveFunc that allocates a
// fresh endpoint per received datagram.
func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
	return func(buff []byte) (int, Endpoint, error) {
		end := new(LinuxSocketEndpoint)
		n, err := receive4(sock, buff, end)
		return n, end, err
	}
}
// Send transmits buff to end over the socket matching the endpoint's
// address family, returning net.ErrClosed if that socket is not open.
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
	nend, ok := end.(*LinuxSocketEndpoint)
	if !ok {
		return ErrWrongEndpointType
	}
	bind.mu.RLock()
	defer bind.mu.RUnlock()
	if nend.isV6 {
		if bind.sock6 == -1 {
			return net.ErrClosed
		}
		return send6(bind.sock6, nend, buff)
	}
	if bind.sock4 == -1 {
		return net.ErrClosed
	}
	return send4(bind.sock4, nend, buff)
}
// SrcIP returns the cached sticky source address (the zero address if
// none has been learned yet). For v6 the returned slice aliases the
// endpoint's internal cache.
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
	if end.isV6 {
		return end.src6().src[:]
	}
	s := end.src4().Src
	return net.IPv4(s[0], s[1], s[2], s[3])
}
// DstIP returns the destination address of the endpoint. For v6 the
// returned slice aliases the endpoint's internal sockaddr storage.
func (end *LinuxSocketEndpoint) DstIP() net.IP {
	if end.isV6 {
		return end.dst6().Addr[:]
	}
	a := end.dst4().Addr
	return net.IPv4(a[0], a[1], a[2], a[3])
}
// DstToBytes returns a raw byte view of the destination sockaddr, covering
// the fields up to and including the address (so family/port/... plus the
// address itself, per the unix.SockaddrInet{4,6} layout). The slice
// aliases the endpoint's internal storage; callers must not mutate it.
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
	if !end.isV6 {
		return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
	} else {
		return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
	}
}
// SrcToString renders the cached sticky source address as a string.
func (end *LinuxSocketEndpoint) SrcToString() string {
	return end.SrcIP().String()
}
// DstToString renders the destination as "ip:port" ("[ip]:port" for v6).
func (end *LinuxSocketEndpoint) DstToString() string {
	var port int
	if end.isV6 {
		port = end.dst6().Port
	} else {
		port = end.dst4().Port
	}
	ua := net.UDPAddr{IP: end.DstIP(), Port: port}
	return ua.String()
}
// ClearDst zeroes the destination sockaddr storage.
func (end *LinuxSocketEndpoint) ClearDst() {
	for i := 0; i < len(end.dst); i++ {
		end.dst[i] = 0
	}
}
// ClearSrc zeroes the sticky-source cache.
func (end *LinuxSocketEndpoint) ClearSrc() {
	for i := 0; i < len(end.src); i++ {
		end.src[i] = 0
	}
}
// zoneToUint32 resolves an IPv6 zone to an interface index: the empty
// zone maps to 0, a known interface name maps to its index, and anything
// else must parse as a decimal index.
func zoneToUint32(zone string) (uint32, error) {
	if zone == "" {
		return 0, nil
	}
	iface, err := net.InterfaceByName(zone)
	if err == nil {
		return uint32(iface.Index), nil
	}
	idx, parseErr := strconv.ParseUint(zone, 10, 32)
	return uint32(idx), parseErr
}
// create4 opens an AF_INET UDP socket bound to port (0 = ephemeral),
// enabling IP_PKTINFO so recvmsg reports each datagram's local destination
// address (needed for sticky replies). Returns the fd and the port
// actually bound; the fd is closed on any setup failure.
func create4(port uint16) (int, uint16, error) {
	// create socket
	fd, err := unix.Socket(
		unix.AF_INET,
		unix.SOCK_DGRAM,
		0,
	)
	if err != nil {
		return -1, 0, err
	}

	addr := unix.SockaddrInet4{
		Port: int(port),
	}

	// set sockopts and bind
	if err := func() error {
		if err := unix.SetsockoptInt(
			fd,
			unix.IPPROTO_IP,
			unix.IP_PKTINFO,
			1,
		); err != nil {
			return err
		}
		return unix.Bind(fd, &addr)
	}(); err != nil {
		unix.Close(fd)
		return -1, 0, err
	}

	// Discover the kernel-chosen port when an ephemeral port was requested.
	sa, err := unix.Getsockname(fd)
	if err == nil {
		addr.Port = sa.(*unix.SockaddrInet4).Port
	}

	return fd, uint16(addr.Port), err
}
// create6 opens an AF_INET6 UDP socket bound to port (0 = ephemeral).
// IPV6_RECVPKTINFO enables sticky-source reporting on receive, and
// IPV6_V6ONLY keeps this socket out of the IPv4 port space (IPv4 is
// handled by its own socket from create4). Returns the fd and the port
// actually bound; the fd is closed on any setup failure.
func create6(port uint16) (int, uint16, error) {
	// create socket
	fd, err := unix.Socket(
		unix.AF_INET6,
		unix.SOCK_DGRAM,
		0,
	)
	if err != nil {
		return -1, 0, err
	}

	// set sockopts and bind
	addr := unix.SockaddrInet6{
		Port: int(port),
	}

	if err := func() error {
		if err := unix.SetsockoptInt(
			fd,
			unix.IPPROTO_IPV6,
			unix.IPV6_RECVPKTINFO,
			1,
		); err != nil {
			return err
		}

		if err := unix.SetsockoptInt(
			fd,
			unix.IPPROTO_IPV6,
			unix.IPV6_V6ONLY,
			1,
		); err != nil {
			return err
		}

		return unix.Bind(fd, &addr)
	}(); err != nil {
		unix.Close(fd)
		return -1, 0, err
	}

	// Discover the kernel-chosen port when an ephemeral port was requested.
	sa, err := unix.Getsockname(fd)
	if err == nil {
		addr.Port = sa.(*unix.SockaddrInet6).Port
	}

	return fd, uint16(addr.Port), err
}
// send4 transmits buff over an IPv4 socket with sendmsg, attaching an
// IP_PKTINFO control message so the kernel sends from the cached sticky
// source address/interface. If the kernel rejects that source with EINVAL
// (e.g. the address no longer exists), the cache is cleared and the send
// is retried once without source pinning.
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
	// construct message header
	cmsg := struct {
		cmsghdr unix.Cmsghdr
		pktinfo unix.Inet4Pktinfo
	}{
		unix.Cmsghdr{
			Level: unix.IPPROTO_IP,
			Type:  unix.IP_PKTINFO,
			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
		},
		unix.Inet4Pktinfo{
			Spec_dst: end.src4().Src,
			Ifindex:  end.src4().Ifindex,
		},
	}

	// end.mu serializes concurrent senders using this endpoint's storage.
	end.mu.Lock()
	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
	end.mu.Unlock()

	if err == nil {
		return nil
	}

	// clear src and retry
	if err == unix.EINVAL {
		end.ClearSrc()
		cmsg.pktinfo = unix.Inet4Pktinfo{}
		end.mu.Lock()
		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
		end.mu.Unlock()
	}

	return err
}
// send6 transmits buff over an IPv6 socket with sendmsg, attaching an
// IPV6_PKTINFO control message carrying the cached sticky source address;
// the interface index is taken from the destination's ZoneId. If no
// sticky source is cached (zero address), the index is cleared so the
// kernel picks freely. An EINVAL rejection clears the cache and retries
// once without source pinning.
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
	// construct message header
	cmsg := struct {
		cmsghdr unix.Cmsghdr
		pktinfo unix.Inet6Pktinfo
	}{
		unix.Cmsghdr{
			Level: unix.IPPROTO_IPV6,
			Type:  unix.IPV6_PKTINFO,
			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
		},
		unix.Inet6Pktinfo{
			Addr:    end.src6().src,
			Ifindex: end.dst6().ZoneId,
		},
	}

	if cmsg.pktinfo.Addr == [16]byte{} {
		cmsg.pktinfo.Ifindex = 0
	}

	// end.mu serializes concurrent senders using this endpoint's storage.
	end.mu.Lock()
	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
	end.mu.Unlock()

	if err == nil {
		return nil
	}

	// clear src and retry
	if err == unix.EINVAL {
		end.ClearSrc()
		cmsg.pktinfo = unix.Inet6Pktinfo{}
		end.mu.Lock()
		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
		end.mu.Unlock()
	}

	return err
}
// receive4 reads one datagram from an IPv4 socket via recvmsg, records the
// sender as the endpoint's destination, and caches the datagram's local
// destination address/interface from the IP_PKTINFO control message so a
// reply can be sent from the same source ("sticky" behavior).
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
	// construct message header
	var cmsg struct {
		cmsghdr unix.Cmsghdr
		pktinfo unix.Inet4Pktinfo
	}

	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
	if err != nil {
		return 0, err
	}
	end.isV6 = false

	if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
		*end.dst4() = *newDst4
	}

	// update source cache
	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
		end.src4().Src = cmsg.pktinfo.Spec_dst
		end.src4().Ifindex = cmsg.pktinfo.Ifindex
	}

	return size, nil
}
// receive6 reads one datagram from an IPv6 socket via recvmsg, records the
// sender as the endpoint's destination, and caches the datagram's local
// destination address (and, as ZoneId, the arrival interface) from the
// IPV6_PKTINFO control message so a reply can be sent from the same source.
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
	// construct message header
	var cmsg struct {
		cmsghdr unix.Cmsghdr
		pktinfo unix.Inet6Pktinfo
	}

	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
	if err != nil {
		return 0, err
	}
	end.isV6 = true

	if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
		*end.dst6() = *newDst6
	}

	// update source cache
	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
		end.src6().src = cmsg.pktinfo.Addr
		end.dst6().ZoneId = cmsg.pktinfo.Ifindex
	}

	return size, nil
}

View file

@ -1,72 +1,126 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
import ( import (
"context"
"errors" "errors"
"fmt"
"net" "net"
"net/netip"
"runtime"
"strconv"
"sync" "sync"
"syscall" "syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
// StdNetBind is meant to be a temporary solution on platforms for which var (
// the sticky socket / source caching behavior has not yet been implemented. _ Bind = (*StdNetBind)(nil)
// It uses the Go's net package to implement networking. )
// See LinuxSocketBind for a proper implementation on the Linux platform.
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects following fields mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool blackhole4 bool
blackhole6 bool blackhole6 bool
} }
func NewStdNetBind() Bind { return &StdNetBind{} } func NewStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
type StdNetEndpoint net.UDPAddr msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
}
}
var _ Bind = (*StdNetBind)(nil) type StdNetEndpoint struct {
var _ Endpoint = (*StdNetEndpoint)(nil) // AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s) e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(addr), err if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
} }
func (*StdNetEndpoint) ClearSrc() {} func (e *StdNetEndpoint) ClearSrc() {
if e.src != nil {
func (e *StdNetEndpoint) DstIP() net.IP { // Truncate src, no need to reallocate.
return (*net.UDPAddr)(e).IP e.src = e.src[:0]
}
} }
func (e *StdNetEndpoint) SrcIP() net.IP { func (e *StdNetEndpoint) DstIP() netip.Addr {
return nil // not supported return e.AddrPort.Addr()
} }
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte { func (e *StdNetEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e) b, _ := e.AddrPort.MarshalBinary()
out := addr.IP.To4() return b
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
} }
func (e *StdNetEndpoint) DstToString() string { func (e *StdNetEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String() return e.AddrPort.String()
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
} }
func listenNet(network string, port int) (*net.UDPConn, int, error) { func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -80,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return conn, uaddr.Port, nil return conn.(*net.UDPConn), uaddr.Port, nil
} }
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
var err error var err error
var tries int var tries int
if bind.ipv4 != nil || bind.ipv6 != nil { if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
@ -98,92 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure. // If uport is 0, we can retry on failure.
again: again:
port := int(uport) port := int(uport)
var ipv4, ipv6 *net.UDPConn var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
ipv4, port, err = listenNet("udp4", port) v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err return nil, 0, err
} }
// Listen on the same port as we're using for ipv4. // Listen on the same port as we're using for ipv4.
ipv6, port, err = listenNet("udp6", port) v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
ipv4.Close() v4conn.Close()
tries++ tries++
goto again goto again
} }
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close() v4conn.Close()
return nil, 0, err return nil, 0, err
} }
var fns []ReceiveFunc var fns []ReceiveFunc
if ipv4 != nil { if v4conn != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4)) s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
bind.ipv4 = ipv4 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
} }
if ipv6 != nil { if v6conn != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6)) s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
bind.ipv6 = ipv6 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
} }
if len(fns) == 0 { if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT return nil, 0, syscall.EAFNOSUPPORT
} }
return fns, uint16(port), nil return fns, uint16(port), nil
} }
func (bind *StdNetBind) Close() error { func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
bind.mu.Lock() for i := range *msgs {
defer bind.mu.Unlock() (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error var err1, err2 error
if bind.ipv4 != nil { if s.ipv4 != nil {
err1 = bind.ipv4.Close() err1 = s.ipv4.Close()
bind.ipv4 = nil s.ipv4 = nil
s.ipv4PC = nil
} }
if bind.ipv6 != nil { if s.ipv6 != nil {
err2 = bind.ipv6.Close() err2 = s.ipv6.Close()
bind.ipv6 = nil s.ipv6 = nil
s.ipv6PC = nil
} }
bind.blackhole4 = false s.blackhole4 = false
bind.blackhole6 = false s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 return err2
} }
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { type ErrUDPGSODisabled struct {
return func(buff []byte) (int, Endpoint, error) { onLaddr string
n, endpoint, err := conn.ReadFromUDP(buff) RetryErr error
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
}
} }
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { func (e ErrUDPGSODisabled) Error() string {
return func(buff []byte) (int, Endpoint, error) { return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
n, endpoint, err := conn.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err
}
} }
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { func (e ErrUDPGSODisabled) Unwrap() error {
var err error return e.RetryErr
nend, ok := endpoint.(*StdNetEndpoint) }
if !ok {
return ErrWrongEndpointType
}
bind.mu.Lock() func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
blackhole := bind.blackhole4 s.mu.Lock()
conn := bind.ipv4 blackhole := s.blackhole4
if nend.IP.To4() == nil { conn := s.ipv4
blackhole = bind.blackhole6 offload := s.ipv4TxOffload
conn = bind.ipv6 br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
br = s.ipv6PC
is6 = true
offload = s.ipv6TxOffload
} }
bind.mu.Unlock() s.mu.Unlock()
if blackhole { if blackhole {
return nil return nil
@ -191,6 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil { if conn == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err return err
} }
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
return err
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}

250
conn/bind_std_test.go Normal file
View file

@ -0,0 +1,250 @@
package conn
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
// TestStdNetBindReceiveFuncAfterClose verifies that a ReceiveFunc obtained
// from Open can still be invoked (without panicking) after Close has
// nilled the bind's connection fields.
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
	bind := NewStdNetBind().(*StdNetBind)
	fns, _, err := bind.Open(0)
	if err != nil {
		t.Fatal(err)
	}
	bind.Close()
	bufs := make([][]byte, 1)
	bufs[0] = make([]byte, 1)
	sizes := make([]int, 1)
	eps := make([]Endpoint, 1)
	for _, fn := range fns {
		// The ReceiveFuncs must not access conn-related fields on StdNetBind
		// unguarded. Close() nils the conn-related fields resulting in a panic
		// if they violate the mutex.
		fn(bufs, sizes, eps)
	}
}
// mockSetGSOSize is a test stand-in for setGSOSize: it widens control to
// its full capacity and stores gsoSize little-endian in the first two bytes.
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
	buf := (*control)[:cap(*control)]
	binary.LittleEndian.PutUint16(buf, gsoSize)
	*control = buf
}
// Test_coalesceMessages exercises coalesceMessages with buffers of varying
// length and capacity, checking both the coalesced message lengths and the
// GSO segment sizes written into each message's control data.
func Test_coalesceMessages(t *testing.T) {
	cases := []struct {
		name     string
		buffs    [][]byte // input packet buffers; cap in excess of len permits coalescing
		wantLens []int    // expected len(Buffers[0]) per resulting message
		wantGSO  []int    // expected GSO size per resulting message (0 = not coalesced)
	}{
		{
			name: "one message no coalesce",
			buffs: [][]byte{
				make([]byte, 1, 1),
			},
			wantLens: []int{1},
			wantGSO:  []int{0},
		},
		{
			name: "two messages equal len coalesce",
			buffs: [][]byte{
				make([]byte, 1, 2),
				make([]byte, 1, 1),
			},
			wantLens: []int{2},
			wantGSO:  []int{1},
		},
		{
			name: "two messages unequal len coalesce",
			buffs: [][]byte{
				make([]byte, 2, 3),
				make([]byte, 1, 1),
			},
			wantLens: []int{3},
			wantGSO:  []int{2},
		},
		{
			name: "three messages second unequal len coalesce",
			buffs: [][]byte{
				make([]byte, 2, 3),
				make([]byte, 1, 1),
				make([]byte, 2, 2),
			},
			wantLens: []int{3, 2},
			wantGSO:  []int{2, 0},
		},
		{
			name: "three messages limited cap coalesce",
			buffs: [][]byte{
				make([]byte, 2, 4),
				make([]byte, 2, 2),
				make([]byte, 2, 2),
			},
			wantLens: []int{4, 2},
			wantGSO:  []int{2, 0},
		},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			addr := &net.UDPAddr{
				IP:   net.ParseIP("127.0.0.1").To4(),
				Port: 1,
			}
			msgs := make([]ipv6.Message, len(tt.buffs))
			for i := range msgs {
				msgs[i].Buffers = make([][]byte, 1)
				msgs[i].OOB = make([]byte, 0, 2)
			}
			got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
			if got != len(tt.wantLens) {
				t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
			}
			for i := 0; i < got; i++ {
				// Every coalesced message must carry the destination address.
				if msgs[i].Addr != addr {
					t.Errorf("msgs[%d].Addr != passed addr", i)
				}
				gotLen := len(msgs[i].Buffers[0])
				if gotLen != tt.wantLens[i] {
					t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
				}
				gotGSO, err := mockGetGSOSize(msgs[i].OOB)
				if err != nil {
					t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
				}
				if gotGSO != tt.wantGSO[i] {
					t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
				}
			}
		})
	}
}
// mockGetGSOSize is a test stand-in for the platform GSO control-message
// parser: it reads a little-endian uint16 segment size from control, or
// reports 0 when control is too short to hold one.
func mockGetGSOSize(control []byte) (int, error) {
	if len(control) >= 2 {
		return int(binary.LittleEndian.Uint16(control)), nil
	}
	return 0, nil
}
// Test_splitCoalescedMessages exercises splitCoalescedMessages over batches
// containing unpopulated, uncoalesced, and GSO-coalesced messages, including
// the overflow case where splitting produces more segments than free slots.
func Test_splitCoalescedMessages(t *testing.T) {
	// newMsg builds a message with n received bytes and, when gso > 0, a
	// 2-byte little-endian GSO control message announcing the segment size.
	newMsg := func(n, gso int) ipv6.Message {
		msg := ipv6.Message{
			Buffers: [][]byte{make([]byte, 1<<16-1)},
			N:       n,
			OOB:     make([]byte, 2),
		}
		binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
		if gso > 0 {
			msg.NN = 2
		}
		return msg
	}

	cases := []struct {
		name        string
		msgs        []ipv6.Message
		firstMsgAt  int   // index of the first populated message
		wantNumEval int   // expected count of messages to evaluate
		wantMsgLens []int // expected msg.N per slot after splitting
		wantErr     bool
	}{
		{
			name: "second last split last empty",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(3, 1),
				newMsg(0, 0),
			},
			firstMsgAt:  2,
			wantNumEval: 3,
			wantMsgLens: []int{1, 1, 1, 0},
			wantErr:     false,
		},
		{
			name: "second last no split last empty",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(1, 0),
				newMsg(0, 0),
			},
			firstMsgAt:  2,
			wantNumEval: 1,
			wantMsgLens: []int{1, 0, 0, 0},
			wantErr:     false,
		},
		{
			name: "second last no split last no split",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(1, 0),
				newMsg(1, 0),
			},
			firstMsgAt:  2,
			wantNumEval: 2,
			wantMsgLens: []int{1, 1, 0, 0},
			wantErr:     false,
		},
		{
			name: "second last no split last split",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(1, 0),
				newMsg(3, 1),
			},
			firstMsgAt:  2,
			wantNumEval: 4,
			wantMsgLens: []int{1, 1, 1, 1},
			wantErr:     false,
		},
		{
			name: "second last split last split",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(2, 1),
				newMsg(2, 1),
			},
			firstMsgAt:  2,
			wantNumEval: 4,
			wantMsgLens: []int{1, 1, 1, 1},
			wantErr:     false,
		},
		{
			name: "second last no split last split overflow",
			msgs: []ipv6.Message{
				newMsg(0, 0),
				newMsg(0, 0),
				newMsg(1, 0),
				newMsg(4, 1),
			},
			firstMsgAt:  2,
			wantNumEval: 4,
			wantMsgLens: []int{1, 1, 1, 1},
			wantErr:     true,
		},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// Use the per-case firstMsgAt rather than a hardcoded offset so
			// the table field is actually honored.
			got, err := splitCoalescedMessages(tt.msgs, tt.firstMsgAt, mockGetGSOSize)
			if err != nil && !tt.wantErr {
				t.Fatalf("err: %v", err)
			}
			if err == nil && tt.wantErr {
				t.Fatal("expected error, got nil")
			}
			if got != tt.wantNumEval {
				t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
			}
			for i, msg := range tt.msgs {
				if msg.N != tt.wantMsgLens[i] {
					t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
				}
			}
		})
	}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
@ -9,6 +9,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -16,7 +17,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio" "github.com/amnezia-vpn/amneziawg-go/conn/winrio"
) )
const ( const (
@ -73,7 +74,7 @@ type afWinRingBind struct {
type WinRingBind struct { type WinRingBind struct {
v4, v6 afWinRingBind v4, v6 afWinRingBind
mu sync.RWMutex mu sync.RWMutex
isOpen uint32 isOpen atomic.Uint32 // 0, 1, or 2
} }
func NewDefaultBind() Bind { return NewWinRingBind() } func NewDefaultBind() Bind { return NewWinRingBind() }
@ -90,8 +91,10 @@ type WinRingEndpoint struct {
data [30]byte data [30]byte
} }
var _ Bind = (*WinRingBind)(nil) var (
var _ Endpoint = (*WinRingEndpoint)(nil) _ Bind = (*WinRingBind)(nil)
_ Endpoint = (*WinRingEndpoint)(nil)
)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s) host, port, err := net.SplitHostPort(s)
@ -121,27 +124,25 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS return nil, windows.ERROR_INVALID_ADDRESS
} }
var src []byte
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen)) copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
copy(dst[:], src)
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
} }
func (*WinRingEndpoint) ClearSrc() {} func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() net.IP { func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family { switch e.family {
case windows.AF_INET: case windows.AF_INET:
return append([]byte{}, e.data[2:6]...) return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6: case windows.AF_INET6:
return append([]byte{}, e.data[6:22]...) return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
} }
return nil return netip.Addr{}
} }
func (e *WinRingEndpoint) SrcIP() net.IP { func (e *WinRingEndpoint) SrcIP() netip.Addr {
return nil // not supported return netip.Addr{} // not supported
} }
func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToBytes() []byte {
@ -163,15 +164,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string { func (e *WinRingEndpoint) DstToString() string {
switch e.family { switch e.family {
case windows.AF_INET: case windows.AF_INET:
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
return addr.String()
case windows.AF_INET6: case windows.AF_INET6:
var zone string var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10) zone = strconv.FormatUint(uint64(scope), 10)
} }
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
return addr.String()
} }
return "" return ""
} }
@ -213,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
} }
func (bind *WinRingBind) closeAndZero() { func (bind *WinRingBind) closeAndZero() {
atomic.StoreUint32(&bind.isOpen, 0) bind.isOpen.Store(0)
bind.v4.CloseAndZero() bind.v4.CloseAndZero()
bind.v6.CloseAndZero() bind.v6.CloseAndZero()
} }
@ -277,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
bind.closeAndZero() bind.closeAndZero()
} }
}() }()
if atomic.LoadUint32(&bind.isOpen) != 0 { if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
var sa windows.Sockaddr var sa windows.Sockaddr
@ -300,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
return nil, 0, err return nil, 0, err
} }
} }
atomic.StoreUint32(&bind.isOpen, 1) bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
} }
func (bind *WinRingBind) Close() error { func (bind *WinRingBind) Close() error {
bind.mu.RLock() bind.mu.RLock()
if atomic.LoadUint32(&bind.isOpen) != 1 { if bind.isOpen.Load() != 1 {
bind.mu.RUnlock() bind.mu.RUnlock()
return nil return nil
} }
atomic.StoreUint32(&bind.isOpen, 2) bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@ -322,6 +321,13 @@ func (bind *WinRingBind) Close() error {
return nil return nil
} }
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error { func (bind *WinRingBind) SetMark(mark uint32) error {
return nil return nil
} }
@ -346,17 +352,21 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
//go:linkname procyield runtime.procyield //go:linkname procyield runtime.procyield
func procyield(cycles uint32) func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if atomic.LoadUint32(isOpen) != 1 { if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
} }
bind.rx.mu.Lock() bind.rx.mu.Lock()
defer bind.rx.mu.Unlock() defer bind.rx.mu.Unlock()
var err error
var count uint32 var count uint32
var results [1]winrio.Result var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ { for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 { if tries > 0 {
if atomic.LoadUint32(isOpen) != 1 { if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
} }
procyield(1) procyield(1)
@ -364,7 +374,7 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
count = winrio.DequeueCompletion(bind.rx.cq, results[:]) count = winrio.DequeueCompletion(bind.rx.cq, results[:])
} }
if count == 0 { if count == 0 {
err := winrio.Notify(bind.rx.cq) err = winrio.Notify(bind.rx.cq)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
@ -375,20 +385,28 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
if atomic.LoadUint32(isOpen) != 1 { if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
} }
count = winrio.DequeueCompletion(bind.rx.cq, results[:]) count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 { if count == 0 {
return 0, nil, io.ErrNoProgress return 0, nil, io.ErrNoProgress
} }
} }
bind.rx.Return(1) bind.rx.Return(1)
err := bind.InsertReceiveRequest() err = bind.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 { if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status) return 0, nil, windows.Errno(results[0].Status)
} }
@ -398,20 +416,26 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
return n, &ep, nil return n, &ep, nil
} }
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen) n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
} }
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen) n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
} }
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if atomic.LoadUint32(isOpen) != 1 { if isOpen.Load() != 1 {
return net.ErrClosed return net.ErrClosed
} }
if len(buf) > bytesPerPacket { if len(buf) > bytesPerPacket {
@ -433,7 +457,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
if err != nil { if err != nil {
return err return err
} }
if atomic.LoadUint32(isOpen) != 1 { if isOpen.Load() != 1 {
return net.ErrClosed return net.ErrClosed
} }
count = winrio.DequeueCompletion(bind.tx.cq, results[:]) count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@ -462,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint) nend, ok := endpoint.(*WinRingEndpoint)
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
switch nend.family { for _, buf := range bufs {
case windows.AF_INET: switch nend.family {
if bind.v4.blackhole { case windows.AF_INET:
return nil if bind.v4.blackhole {
continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
case windows.AF_INET6:
if bind.v6.blackhole {
continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
} }
return bind.v4.Send(buf, nend, &bind.isOpen)
case windows.AF_INET6:
if bind.v6.blackhole {
return nil
}
return bind.v6.Send(buf, nend, &bind.isOpen)
} }
return nil return nil
} }
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -500,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil { if err != nil {
return err return err
} }
bind.blackhole4 = blackhole s.blackhole4 = blackhole
return nil return nil
} }
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -520,13 +550,14 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil { if err != nil {
return err return err
} }
bind.blackhole6 = blackhole s.blackhole6 = blackhole
return nil return nil
} }
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
if atomic.LoadUint32(&bind.isOpen) != 1 { if bind.isOpen.Load() != 1 {
return net.ErrClosed return net.ErrClosed
} }
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@ -540,7 +571,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
if atomic.LoadUint32(&bind.isOpen) != 1 { if bind.isOpen.Load() != 1 {
return net.ErrClosed return net.ErrClosed
} }
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
@ -568,21 +599,3 @@ func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error
const IPV6_UNICAST_IF = 31 const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
} }
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package bindtest package bindtest
@ -9,10 +9,10 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"strconv"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
) )
type ChannelBind struct { type ChannelBind struct {
@ -25,8 +25,10 @@ type ChannelBind struct {
type ChannelEndpoint uint16 type ChannelEndpoint uint16
var _ conn.Bind = (*ChannelBind)(nil) var (
var _ conn.Endpoint = (*ChannelEndpoint)(nil) _ conn.Bind = (*ChannelBind)(nil)
_ conn.Endpoint = (*ChannelEndpoint)(nil)
)
func NewChannelBinds() [2]conn.Bind { func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192) arx4 := make(chan []byte, 8192)
@ -61,9 +63,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil } func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool) c.closeSignal = make(chan bool)
@ -87,45 +89,48 @@ func (c *ChannelBind) Close() error {
return nil return nil
} }
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(b []byte) (n int, ep conn.Endpoint, err error) { return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select { select {
case <-c.closeSignal: case <-c.closeSignal:
return 0, nil, net.ErrClosed return 0, net.ErrClosed
case rx := <-ch: case rx := <-ch:
return copy(b, rx), c.target6, nil copied := copy(bufs[0], rx)
sizes[0] = copied
eps[0] = c.target6
return 1, nil
} }
} }
} }
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
select { for _, b := range bufs {
case <-c.closeSignal: select {
return net.ErrClosed case <-c.closeSignal:
default: return net.ErrClosed
bc := make([]byte, len(b)) default:
copy(bc, b) bc := make([]byte, len(b))
if ep.(ChannelEndpoint) == c.target4 { copy(bc, b)
*c.tx4 <- bc if ep.(ChannelEndpoint) == c.target4 {
} else if ep.(ChannelEndpoint) == c.target6 { *c.tx4 <- bc
*c.tx6 <- bc } else if ep.(ChannelEndpoint) == c.target6 {
} else { *c.tx6 <- bc
return os.ErrInvalid } else {
return os.ErrInvalid
}
} }
} }
return nil return nil
} }
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
_, port, err := net.SplitHostPort(s) addr, err := netip.ParseAddrPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
i, err := strconv.ParseUint(port, 10, 16) return ChannelEndpoint(addr.Port()), nil
if err != nil {
return nil, err
}
return ChannelEndpoint(i), nil
} }

View file

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return return
} }
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package conn implements WireGuard's network connections. // Package conn implements WireGuard's network connections.
@ -9,16 +9,23 @@ package conn
import ( import (
"errors" "errors"
"fmt" "fmt"
"net" "net/netip"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
) )
// A ReceiveFunc receives a single inbound packet from the network. const (
// It writes the data into b. n is the length of the packet. IdealBatchSize = 128 // maximum number of packets handled per read and write
// ep is the remote endpoint. )
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
// //
@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK. // This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error SetMark(mark uint32) error
// Send writes a packet b to address ep. // Send writes one or more packets in bufs to address ep. The length of
Send(b []byte, ep Endpoint) error // bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string. // ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error) ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
} }
// BindSocketToInterface is implemented by Bind objects that support being // BindSocketToInterface is implemented by Bind objects that support being
@ -68,8 +80,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port) SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port) DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP DstIP() netip.Addr
SrcIP() net.IP SrcIP() netip.Addr
} }
var ( var (
@ -119,33 +131,3 @@ func (fn ReceiveFunc) PrettyName() string {
} }
return name return name
} }
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}

24
conn/conn_test.go Normal file
View file

@ -0,0 +1,24 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"testing"
)
// TestPrettyName checks that ReceiveFunc.PrettyName derives a readable name
// from the function's definition site; a closure defined directly inside this
// test must pretty-print as the test's own name.
func TestPrettyName(t *testing.T) {
	recvFunc := ReceiveFunc(func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return })

	const want = "TestPrettyName"
	t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
		got := recvFunc.PrettyName()
		if got != want {
			t.Errorf("PrettyName() = %v, want %v", got, want)
		}
	})
}

43
conn/controlfns.go Normal file
View file

@ -0,0 +1,43 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"syscall"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation). The per-platform controlfns files apply this value
// via SO_RCVBUF/SO_SNDBUF.
const socketBufferSize = 7 << 20

// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind. The first error returned by any controlFn aborts socket setup.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns holds the platform-specific socket option callbacks, registered
// by the per-platform controlfns files via init(), that listenConfig applies
// to each socket prior to bind.
var controlFns = []controlFn{}

// listenConfig returns a net.ListenConfig whose Control hook runs every
// registered controlFn against the raw socket before bind, applying socket
// buffer sizing and packet-information OOB configuration for sticky sockets.
// The first controlFn error aborts setup and is returned to the caller.
func listenConfig() *net.ListenConfig {
	return &net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			var err error
			for _, fn := range controlFns {
				if err = fn(network, address, c); err != nil {
					break
				}
			}
			return err
		},
	}
}

61
conn/controlfns_linux.go Normal file
View file

@ -0,0 +1,61 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
// init registers the Linux-specific socket control functions that
// listenConfig applies to every UDP socket before bind.
func init() {
	controlFns = append(controlFns,

		// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
		// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
		// fail silently - the result of failure is lower performance on very fast
		// links or high latency links.
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				// Set up to *mem_max
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
				// Set beyond *mem_max if CAP_NET_ADMIN
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
			})
		},

		// Enable receiving of the packet information (IP_PKTINFO for IPv4,
		// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
		// NOTE(review): the error returned by c.Control itself is discarded in
		// both branches below; err only reflects the setsockopt result —
		// confirm that is intentional.
		func(network, address string, c syscall.RawConn) error {
			var err error
			switch network {
			case "udp4":
				// Android lacks permission for IP_PKTINFO-style options, so
				// sticky-socket setup is skipped there.
				if runtime.GOOS != "android" {
					c.Control(func(fd uintptr) {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
					})
				}
			case "udp6":
				c.Control(func(fd uintptr) {
					if runtime.GOOS != "android" {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
						if err != nil {
							return
						}
					}
					// v6-only so the v4 and v6 sockets can bind independently.
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			default:
				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
			}
			return err
		},
	)
}

35
conn/controlfns_unix.go Normal file
View file

@ -0,0 +1,35 @@
//go:build !windows && !linux && !wasm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/unix"
)
// init registers the generic Unix (non-Linux, non-Windows) socket control
// functions that listenConfig applies to every UDP socket before bind.
func init() {
	controlFns = append(controlFns,
		// Raise the socket send/receive buffers to socketBufferSize. Failures
		// are deliberately ignored: a smaller buffer costs throughput, not
		// correctness.
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
			})
		},
		// Make IPv6 sockets v6-only so the v4 and v6 sockets can bind the
		// same port independently.
		// NOTE(review): the error returned by c.Control itself is discarded;
		// err only reflects the setsockopt result — confirm that is intentional.
		func(network, address string, c syscall.RawConn) error {
			var err error
			if network == "udp6" {
				c.Control(func(fd uintptr) {
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			}
			return err
		},
	)
}

View file

@ -0,0 +1,23 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/windows"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
})
},
)
}

View file

@ -1,8 +1,8 @@
// +build !linux,!windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn

12
conn/errors_default.go Normal file
View file

@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// errShouldDisableUDPGSO reports whether err indicates that UDP generic
// segmentation offload should be disabled. GSO is a Linux-only feature,
// so on this platform no error can ever warrant disabling it.
func errShouldDisableUDPGSO(_ error) bool {
	return false
}

28
conn/errors_linux.go Normal file
View file

@ -0,0 +1,28 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
// errShouldDisableUDPGSO reports whether err was caused by sending with
// UDP_SEGMENT on a socket or route that cannot support it, in which case
// the caller should fall back to non-GSO sends.
func errShouldDisableUDPGSO(err error) bool {
	var sysErr *os.SyscallError
	if !errors.As(err, &sysErr) {
		return false
	}
	switch sysErr.Err {
	case unix.EIO:
		// EIO is returned by udp_send_skb() if the device driver does not
		// have tx checksumming enabled, which is a hard requirement of
		// UDP_SEGMENT. See:
		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
		return true
	case unix.EINVAL:
		// If gso_size + udp + ip headers > fragment size EINVAL is
		// returned. It occurs when the peer mtu + wg headers is greater
		// than path mtu. See:
		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
		return true
	default:
		return false
	}
}

15
conn/features_default.go Normal file
View file

@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
// supportsUDPOffload reports whether conn supports UDP segmentation (tx)
// and receive coalescing (rx) offload. These are Linux-only kernel
// features, so this stub reports both as unsupported.
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
	return false, false
}

31
conn/features_linux.go Normal file
View file

@ -0,0 +1,31 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
// supportsUDPOffload probes conn's file descriptor for UDP segmentation
// (UDP_SEGMENT) and receive coalescing (UDP_GRO) support. Either probe
// failing simply reports that capability as unavailable.
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
	rc, err := conn.SyscallConn()
	if err != nil {
		return false, false
	}
	probe := func(fd uintptr) {
		// A successful getsockopt means the kernel understands UDP_SEGMENT.
		if _, e := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT); e == nil {
			txOffload = true
		}
		// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android,
		// so use a setsockopt workaround to detect UDP_GRO support.
		if e := unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1); e == nil {
			rxOffload = true
		}
	}
	if err := rc.Control(probe); err != nil {
		return false, false
	}
	return txOffload, rxOffload
}

21
conn/gso_default.go Normal file
View file

@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size
// data. GRO is Linux-only, so this stub always reports no GSO data.
func getGSOSize(_ []byte) (int, error) {
	return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It is a
// no-op on platforms without UDP GSO support.
func setGSOSize(_ *[]byte, _ uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data. Offload is unavailable on this platform, so no
// control-message space is needed.
const gsoControlSize = 0

65
conn/gso_linux.go Normal file
View file

@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
// It walks every cmsg in control; a malformed message aborts the scan with an
// error, and absence of a UDP_GRO cmsg yields (0, nil).
func getGSOSize(control []byte) (int, error) {
	var (
		hdr  unix.Cmsghdr
		data []byte
		rem  = control
		err  error
	)

	for len(rem) > unix.SizeofCmsghdr {
		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return 0, fmt.Errorf("error parsing socket control message: %w", err)
		}
		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
			var gso uint16
			// Copy the native-endian uint16 payload out of the cmsg data.
			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
			return int(gso), nil
		}
	}
	return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched: the new cmsg is appended after len(*control), and
// the write is silently skipped if spare capacity is insufficient.
func setGSOSize(control *[]byte, gsoSize uint16) {
	existingLen := len(*control)
	avail := cap(*control) - existingLen
	space := unix.CmsgSpace(sizeOfGSOData)
	if avail < space {
		return
	}
	// Extend into the slice's spare capacity and build the cmsg in place.
	*control = (*control)[:cap(*control)]
	gsoControl := (*control)[existingLen:]
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
	hdr.Level = unix.SOL_UDP
	hdr.Type = unix.UDP_SEGMENT
	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
	// Payload: the native-endian uint16 segment size.
	copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
	*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data: enough space for one UDP_GRO/UDP_SEGMENT cmsg.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)

View file

@ -1,12 +1,12 @@
// +build !linux,!openbsd,!freebsd //go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
func (bind *StdNetBind) SetMark(mark uint32) error { func (s *StdNetBind) SetMark(mark uint32) error {
return nil return nil
} }

View file

@ -1,8 +1,8 @@
// +build linux openbsd freebsd //go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
@ -26,13 +26,13 @@ func init() {
} }
} }
func (bind *StdNetBind) SetMark(mark uint32) error { func (s *StdNetBind) SetMark(mark uint32) error {
var operr error var operr error
if fwmarkIoctl == 0 { if fwmarkIoctl == 0 {
return nil return nil
} }
if bind.ipv4 != nil { if s.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn() fd, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err return err
} }
} }
if bind.ipv6 != nil { if s.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn() fd, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }

42
conn/sticky_default.go Normal file
View file

@ -0,0 +1,42 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
// SrcIP returns the sticky source address of e. Sticky sockets are not
// supported on this platform, so the zero Addr is always returned.
func (e *StdNetEndpoint) SrcIP() netip.Addr {
	var none netip.Addr
	return none
}
// SrcIfidx returns the sticky source interface index of e. Sticky sockets
// are not supported on this platform, so 0 is always returned.
func (e *StdNetEndpoint) SrcIfidx() int32 {
	var none int32
	return none
}
// SrcToString returns the textual form of e's sticky source address.
// Sticky sockets are not supported on this platform, so the empty string
// is always returned.
func (e *StdNetEndpoint) SrcToString() string {
	var none string
	return none
}
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found. Sticky sockets are not supported on this
// platform, so this is a no-op.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl encodes the source address/interface held by ep into control.
// (The previous comment was copy-pasted from getSrcFromControl; this function
// writes control, it does not parse it.) Sticky sockets are not supported on
// this platform, so this is a no-op.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data. Sticky sockets are unsupported here, so no
// control-message space is needed.
const stickyControlSize = 0

// StdNetSupportsStickySockets reports that sticky sockets are unavailable
// on this platform.
const StdNetSupportsStickySockets = false

112
conn/sticky_linux.go Normal file
View file

@ -0,0 +1,112 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// SrcIP returns the sticky source address stored in e.src, which holds one raw
// pktinfo cmsg (IP_PKTINFO or IPV6_PKTINFO). The address family is inferred
// from the buffer's length; an empty or unrecognized buffer yields the zero
// Addr.
func (e *StdNetEndpoint) SrcIP() netip.Addr {
	switch len(e.src) {
	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return netip.AddrFrom4(info.Spec_dst)
	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		// TODO: set zone. in order to do so we need to check if the address is
		// link local, and if it is perform a syscall to turn the ifindex into a
		// zone string because netip uses string zones.
		return netip.AddrFrom16(info.Addr)
	}
	return netip.Addr{}
}
// SrcIfidx returns the interface index recorded in e.src's pktinfo cmsg, or 0
// when no source is set. As in SrcIP, the cmsg family is inferred from the
// buffer's length.
func (e *StdNetEndpoint) SrcIfidx() int32 {
	switch len(e.src) {
	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return info.Ifindex
	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
		return int32(info.Ifindex)
	}
	return 0
}
// SrcToString returns the textual form of e's sticky source address.
func (e *StdNetEndpoint) SrcToString() string {
	ip := e.SrcIP()
	return ip.String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found. The matching cmsg (header plus payload) is
// stored verbatim in ep.src so it can later be replayed by setSrcControl.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
	ep.ClearSrc()

	var (
		hdr  unix.Cmsghdr
		data []byte
		rem  []byte = control
		err  error
	)

	for len(rem) > unix.SizeofCmsghdr {
		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			// Malformed control data: leave ep with no source set.
			return
		}

		if hdr.Level == unix.IPPROTO_IP &&
			hdr.Type == unix.IP_PKTINFO {

			// Reuse ep.src's backing storage when it is large enough.
			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
			}
			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]

			// Store header and payload back-to-back, reconstructing the cmsg.
			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
			copy(ep.src, hdrBuf)
			copy(ep.src[unix.CmsgLen(0):], data)
			return
		}

		if hdr.Level == unix.IPPROTO_IPV6 &&
			hdr.Type == unix.IPV6_PKTINFO {

			// Same as the IPv4 branch, sized for an Inet6Pktinfo payload.
			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
			}
			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]

			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
			copy(ep.src, hdrBuf)
			copy(ep.src[unix.CmsgLen(0):], data)
			return
		}
	}
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value. ep.src already holds a complete cmsg (see
// getSrcFromControl), so copying it verbatim is sufficient.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
	if cap(*control) < len(ep.src) {
		// Not enough room for the stored cmsg; leave control untouched.
		return
	}
	*control = (*control)[:0]
	*control = append(*control, ep.src...)
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data. Sized for the larger (IPv6) pktinfo cmsg so one
// buffer accommodates either address family.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)

// StdNetSupportsStickySockets reports that sticky sockets are available on
// this platform.
const StdNetSupportsStickySockets = true

266
conn/sticky_linux_test.go Normal file
View file

@ -0,0 +1,266 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
// setSrc is a test helper that hand-builds the raw pktinfo cmsg for addr and
// ifidx and stores it in ep.src, mirroring what getSrcFromControl would
// produce for a received control message.
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
	var buf []byte
	if addr.Is4() {
		buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
		hdr := unix.Cmsghdr{
			Level: unix.IPPROTO_IP,
			Type:  unix.IP_PKTINFO,
		}
		hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
		// Header first, then the Inet4Pktinfo payload at the aligned offset.
		copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
		info := unix.Inet4Pktinfo{
			Ifindex:  ifidx,
			Spec_dst: addr.As4(),
		}
		copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
	} else {
		buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
		hdr := unix.Cmsghdr{
			Level: unix.IPPROTO_IPV6,
			Type:  unix.IPV6_PKTINFO,
		}
		hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
		copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
		info := unix.Inet6Pktinfo{
			Ifindex: uint32(ifidx),
			Addr:    addr.As16(),
		}
		copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
	}
	ep.src = buf
}
// Test_setSrcControl verifies that setSrcControl emits a well-formed
// IP_PKTINFO / IPV6_PKTINFO cmsg for an endpoint with a sticky source, and
// that it truncates control to zero length when the endpoint has none.
func Test_setSrcControl(t *testing.T) {
	t.Run("IPv4", func(t *testing.T) {
		ep := &StdNetEndpoint{
			AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
		}
		setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)

		control := make([]byte, stickyControlSize)

		setSrcControl(&control, ep)

		// Reinterpret the emitted bytes as a cmsg and check every field.
		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		if hdr.Level != unix.IPPROTO_IP {
			t.Errorf("unexpected level: %d", hdr.Level)
		}
		if hdr.Type != unix.IP_PKTINFO {
			t.Errorf("unexpected type: %d", hdr.Type)
		}
		if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
			t.Errorf("unexpected length: %d", hdr.Len)
		}
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
		if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
			t.Errorf("unexpected address: %v", info.Spec_dst)
		}
		if info.Ifindex != 5 {
			t.Errorf("unexpected ifindex: %d", info.Ifindex)
		}
	})

	t.Run("IPv6", func(t *testing.T) {
		ep := &StdNetEndpoint{
			AddrPort: netip.MustParseAddrPort("[::1]:1234"),
		}
		setSrc(ep, netip.MustParseAddr("::1"), 5)

		control := make([]byte, stickyControlSize)

		setSrcControl(&control, ep)

		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		if hdr.Level != unix.IPPROTO_IPV6 {
			t.Errorf("unexpected level: %d", hdr.Level)
		}
		if hdr.Type != unix.IPV6_PKTINFO {
			t.Errorf("unexpected type: %d", hdr.Type)
		}
		if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
			t.Errorf("unexpected length: %d", hdr.Len)
		}
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
		if info.Addr != ep.SrcIP().As16() {
			t.Errorf("unexpected address: %v", info.Addr)
		}
		if info.Ifindex != 5 {
			t.Errorf("unexpected ifindex: %d", info.Ifindex)
		}
	})

	t.Run("ClearOnNoSrc", func(t *testing.T) {
		// Pre-populate garbage header values to prove they are discarded.
		control := make([]byte, stickyControlSize)
		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		hdr.Level = 1
		hdr.Type = 2
		hdr.Len = 3

		setSrcControl(&control, &StdNetEndpoint{})

		if len(control) != 0 {
			t.Errorf("unexpected control: %v", control)
		}
	})
}
// Test_getSrcFromControl verifies that getSrcFromControl extracts the source
// address and ifindex from IP_PKTINFO / IPV6_PKTINFO cmsgs, clears the
// endpoint when control is empty, and skips unrelated cmsgs that precede the
// pktinfo entry.
func Test_getSrcFromControl(t *testing.T) {
	t.Run("IPv4", func(t *testing.T) {
		// Hand-build an IP_PKTINFO cmsg for 127.0.0.1 / ifindex 5.
		control := make([]byte, stickyControlSize)
		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		hdr.Level = unix.IPPROTO_IP
		hdr.Type = unix.IP_PKTINFO
		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
		info.Spec_dst = [4]byte{127, 0, 0, 1}
		info.Ifindex = 5

		ep := &StdNetEndpoint{}
		getSrcFromControl(control, ep)

		if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
			t.Errorf("unexpected address: %v", ep.SrcIP())
		}
		if ep.SrcIfidx() != 5 {
			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
		}
	})
	t.Run("IPv6", func(t *testing.T) {
		control := make([]byte, stickyControlSize)
		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		hdr.Level = unix.IPPROTO_IPV6
		hdr.Type = unix.IPV6_PKTINFO
		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
		info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
		info.Ifindex = 5

		ep := &StdNetEndpoint{}
		getSrcFromControl(control, ep)

		if ep.SrcIP() != netip.MustParseAddr("::1") {
			t.Errorf("unexpected address: %v", ep.SrcIP())
		}
		if ep.SrcIfidx() != 5 {
			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
		}
	})
	t.Run("ClearOnEmpty", func(t *testing.T) {
		// A pre-set source must be cleared when control carries no pktinfo.
		var control []byte
		ep := &StdNetEndpoint{}
		setSrc(ep, netip.MustParseAddr("::1"), 5)

		getSrcFromControl(control, ep)
		if ep.SrcIP().IsValid() {
			t.Errorf("unexpected address: %v", ep.SrcIP())
		}
		if ep.SrcIfidx() != 0 {
			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
		}
	})
	t.Run("Multiple", func(t *testing.T) {
		// A zero-length cmsg ahead of the pktinfo cmsg must be skipped.
		zeroControl := make([]byte, unix.CmsgSpace(0))
		zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
		zeroHdr.SetLen(unix.CmsgLen(0))

		control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
		hdr.Level = unix.IPPROTO_IP
		hdr.Type = unix.IP_PKTINFO
		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
		info.Spec_dst = [4]byte{127, 0, 0, 1}
		info.Ifindex = 5

		combined := make([]byte, 0)
		combined = append(combined, zeroControl...)
		combined = append(combined, control...)

		ep := &StdNetEndpoint{}
		getSrcFromControl(combined, ep)

		if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
			t.Errorf("unexpected address: %v", ep.SrcIP())
		}
		if ep.SrcIfidx() != 5 {
			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
		}
	})
}
// Test_listenConfig verifies that sockets produced by listenConfig() have the
// packet-info options required for sticky sockets enabled on Linux:
// IP_PKTINFO for udp4 and IPV6_RECVPKTINFO for udp6. On other platforms the
// check is skipped with a log message.
//
// Fixes relative to the previous version: the IPv4 fallback log wrongly named
// IPV6_RECVPKTINFO; the IPv6 subtest leaked its socket (missing Close); the
// IPv6 failure message named IPV6_PKTINFO although IPV6_RECVPKTINFO is what
// is queried.
func Test_listenConfig(t *testing.T) {
	t.Run("IPv4", func(t *testing.T) {
		conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
		if err != nil {
			t.Fatal(err)
		}
		defer conn.Close()
		sc, err := conn.(*net.UDPConn).SyscallConn()
		if err != nil {
			t.Fatal(err)
		}

		if runtime.GOOS == "linux" {
			var i int
			sc.Control(func(fd uintptr) {
				i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
			})
			if err != nil {
				t.Fatal(err)
			}
			if i != 1 {
				t.Error("IP_PKTINFO not set!")
			}
		} else {
			// This branch inspects the IPv4 option, not the IPv6 one.
			t.Logf("listenConfig() does not set IP_PKTINFO on %s", runtime.GOOS)
		}
	})
	t.Run("IPv6", func(t *testing.T) {
		conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
		if err != nil {
			t.Fatal(err)
		}
		// Previously missing: without Close the test leaked the socket.
		defer conn.Close()
		sc, err := conn.(*net.UDPConn).SyscallConn()
		if err != nil {
			t.Fatal(err)
		}

		if runtime.GOOS == "linux" {
			var i int
			sc.Control(func(fd uintptr) {
				i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
			})
			if err != nil {
				t.Fatal(err)
			}
			if i != 1 {
				// Name the option that was actually queried.
				t.Error("IPV6_RECVPKTINFO not set!")
			}
		} else {
			t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
		}
	})
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package winrio package winrio
@ -84,8 +84,10 @@ type iocpNotificationCompletion struct {
overlapped *windows.Overlapped overlapped *windows.Overlapped
} }
var initialized sync.Once var (
var available bool initialized sync.Once
available bool
)
func Initialize() bool { func Initialize() bool {
initialized.Do(func() { initialized.Do(func() {
@ -108,7 +110,7 @@ func Initialize() bool {
return return
} }
defer windows.CloseHandle(socket) defer windows.CloseHandle(socket)
var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
ob := uint32(0) ob := uint32(0)
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,

View file

@ -1,68 +1,55 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list" "container/list"
"encoding/binary"
"errors" "errors"
"math/bits" "math/bits"
"net" "net"
"net/netip"
"sync" "sync"
"unsafe" "unsafe"
) )
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct { type trieEntry struct {
child [2]*trieEntry peer *Peer
peer *Peer child [2]*trieEntry
bits net.IP parent parentIndirection
cidr uint cidr uint8
bit_at_byte uint bitAtByte uint8
bit_at_shift uint bitAtShift uint8
perPeerElem *list.Element bits []byte
perPeerElem *list.Element
} }
func isLittleEndian() bool { func commonBits(ip1, ip2 []byte) uint8 {
one := uint32(1)
return *(*byte)(unsafe.Pointer(&one)) != 0
}
func swapU32(i uint32) uint32 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes32(i)
}
func swapU64(i uint64) uint64 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0])) a := binary.BigEndian.Uint32(ip1)
b := (*uint32)(unsafe.Pointer(&ip2[0])) b := binary.BigEndian.Uint32(ip2)
x := *a ^ *b x := a ^ b
return uint(bits.LeadingZeros32(swapU32(x))) return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len { } else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0])) a := binary.BigEndian.Uint64(ip1)
b := (*uint64)(unsafe.Pointer(&ip2[0])) b := binary.BigEndian.Uint64(ip2)
x := *a ^ *b x := a ^ b
if x != 0 { if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x))) return uint8(bits.LeadingZeros64(x))
} }
a = (*uint64)(unsafe.Pointer(&ip1[8])) a = binary.BigEndian.Uint64(ip1[8:])
b = (*uint64)(unsafe.Pointer(&ip2[8])) b = binary.BigEndian.Uint64(ip2[8:])
x = *a ^ *b x = a ^ b
return 64 + uint(bits.LeadingZeros64(swapU64(x))) return 64 + uint8(bits.LeadingZeros64(x))
} else { } else {
panic("Wrong size bit string") panic("Wrong size bit string")
} }
@ -79,32 +66,8 @@ func (node *trieEntry) removeFromPeerEntries() {
} }
} }
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { func (node *trieEntry) choose(ip []byte) byte {
if node == nil { return (ip[node.bitAtByte] >> node.bitAtShift) & 1
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
} }
func (node *trieEntry) maskSelf() { func (node *trieEntry) maskSelf() {
@ -114,86 +77,125 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
// at leaf func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
if node == nil { func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{ node := &trieEntry{
bits: ip, peer: peer,
peer: peer, parent: trie,
cidr: cidr, bits: ip,
bit_at_byte: cidr / 8, cidr: cidr,
bit_at_shift: 7 - (cidr % 8), bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
} }
node.maskSelf() node.maskSelf()
node.addToPeerEntries() node.addToPeerEntries()
return node *trie.parentBit = node
return
} }
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
// traverse deeper if exact {
node.removeFromPeerEntries()
common := commonBits(node.bits, ip) node.peer = peer
if node.cidr <= cidr && common >= node.cidr { node.addToPeerEntries()
if node.cidr == cidr { return
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
} }
// split node
newNode := &trieEntry{ newNode := &trieEntry{
bits: ip, peer: peer,
peer: peer, bits: ip,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
cidr = min(cidr, common) var down *trieEntry
if node == nil {
// check for shorter prefix down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(node.bits) bit := newNode.choose(down.bits)
newNode.child[bit] = node down.parent = parentIndirection{&newNode.child[bit], bit}
return newNode newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
} }
// create new parent for node & newNode node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
parent := &trieEntry{ cidr: cidr,
bits: append([]byte{}, ip...), bitAtByte: cidr / 8,
peer: nil, bitAtShift: 7 - (cidr % 8),
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
} }
parent.maskSelf() node.maskSelf()
bit := parent.choose(ip) bit := node.choose(down.bits)
parent.child[bit] = newNode down.parent = parentIndirection{&node.child[bit], bit}
parent.child[bit^1] = node node.child[bit] = down
bit = node.choose(newNode.bits)
return parent newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
} }
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer var found *Peer
size := uint(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bit_at_byte == size { if node.bitAtByte == size {
break break
} }
bit := node.choose(ip) bit := node.choose(ip)
@ -208,13 +210,14 @@ type AllowedIPs struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry) node := elem.Value.(*trieEntry)
if !cb(node.bits, node.cidr) { a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return return
} }
} }
@ -224,32 +227,68 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer) var next *list.Element
table.IPv6 = table.IPv6.removeByPeer(peer) for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
switch len(ip) { if prefix.Addr().Is6() {
case net.IPv6len: ip := prefix.Addr().As16()
table.IPv6 = table.IPv6.insert(ip, cidr, peer) parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
case net.IPv4len: } else if prefix.Addr().Is4() {
table.IPv4 = table.IPv4.insert(ip, cidr, peer) ip := prefix.Addr().As4()
default: parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
} }
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.IPv4.lookup(address) switch len(ip) {
} case net.IPv6len:
return table.IPv6.lookup(ip)
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { case net.IPv4len:
table.mutex.RLock() return table.IPv4.lookup(ip)
defer table.mutex.RUnlock() default:
return table.IPv6.lookup(address) panic(errors.New("looking up unknown address type"))
}
} }

View file

@ -1,25 +1,28 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net"
"net/netip"
"sort" "sort"
"testing" "testing"
) )
const ( const (
NumberOfPeers = 100 NumberOfPeers = 100
NumberOfAddresses = 250 NumberOfPeerRemovals = 4
NumberOfTests = 10000 NumberOfAddresses = 250
NumberOfTests = 10000
) )
type SlowNode struct { type SlowNode struct {
peer *Peer peer *Peer
cidr uint cidr uint8
bits []byte bits []byte
} }
@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil return nil
} }
func TestTrieRandomIPv4(t *testing.T) { func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
var trie *trieEntry n := 0
var slow SlowRouter for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n++ { for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr4 [4]byte
rand.Read(addr[:]) rand.Read(addr4[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Intn(32) + 1)
index := rand.Int() % NumberOfPeers index := rand.Intn(NumberOfPeers)
trie = trie.insert(addr[:], cidr, peers[index]) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }
for n := 0; n < NumberOfTests; n++ { var p int
var addr [AddressLength]byte for p = 0; ; p++ {
rand.Read(addr[:]) for n := 0; n < NumberOfTests; n++ {
peer1 := slow.Lookup(addr[:]) var addr4 [4]byte
peer2 := trie.lookup(addr[:]) rand.Read(addr4[:])
if peer1 != peer2 { peer1 := slow4.Lookup(addr4[:])
t.Error("Trie did not match naive implementation, for:", addr) peer2 := allowedIPs.Lookup(addr4[:])
} if peer1 != peer2 {
} t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
} }
func TestTrieRandomIPv6(t *testing.T) { var addr6 [16]byte
var trie *trieEntry rand.Read(addr6[:])
var slow SlowRouter peer1 = slow6.Lookup(addr6[:])
var peers []*Peer peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
rand.Seed(1) t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
const AddressLength = 16
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
} }
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
} }
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,20 +8,17 @@ package device
import ( import (
"math/rand" "math/rand"
"net" "net"
"net/netip"
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint match uint8
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{ tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@ -42,9 +39,10 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
@ -57,9 +55,9 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) root.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@ -97,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@ -153,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@ -161,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
trie = nil allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@ -184,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@ -195,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@ -210,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := trie.lookup(addr) p := allowedIPs.Lookup(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,7 +8,7 @@ package device
import ( import (
"errors" "errors"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
) )
type DummyDatagram struct { type DummyDatagram struct {
@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil return nil
} }
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6 datagram, ok := <-b.in6
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buff, datagram.msg) copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4 datagram, ok := <-b.in4
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buff, datagram.msg) copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil return nil
} }
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil return nil
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference. // call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed. // When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct { type outboundQueue struct {
c chan *QueueOutboundElement c chan *QueueOutboundElementsContainer
wg sync.WaitGroup wg sync.WaitGroup
} }
func newOutboundQueue() *outboundQueue { func newOutboundQueue() *outboundQueue {
q := &outboundQueue{ q := &outboundQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize), c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
} }
q.wg.Add(1) q.wg.Add(1)
go func() { go func() {
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs. // A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct { type inboundQueue struct {
c chan *QueueInboundElement c chan *QueueInboundElementsContainer
wg sync.WaitGroup wg sync.WaitGroup
} }
func newInboundQueue() *inboundQueue { func newInboundQueue() *inboundQueue {
q := &inboundQueue{ q := &inboundQueue{
c: make(chan *QueueInboundElement, QueueInboundSize), c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
} }
q.wg.Add(1) q.wg.Add(1)
go func() { go func() {
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
} }
type autodrainingInboundQueue struct { type autodrainingInboundQueue struct {
c chan *QueueInboundElement c chan *QueueInboundElementsContainer
} }
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values. // some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{ q := &autodrainingInboundQueue{
c: make(chan *QueueInboundElement, QueueInboundSize), c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
} }
runtime.SetFinalizer(q, device.flushInboundQueue) runtime.SetFinalizer(q, device.flushInboundQueue)
return q return q
@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for { for {
select { select {
case elem := <-q.c: case elemsContainer := <-q.c:
elem.Lock() elemsContainer.Lock()
device.PutMessageBuffer(elem.buffer) for _, elem := range elemsContainer.elems {
device.PutInboundElement(elem) device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
default: default:
return return
} }
@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
} }
type autodrainingOutboundQueue struct { type autodrainingOutboundQueue struct {
c chan *QueueOutboundElement c chan *QueueOutboundElementsContainer
} }
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers. // All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{ q := &autodrainingOutboundQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize), c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
} }
runtime.SetFinalizer(q, device.flushOutboundQueue) runtime.SetFinalizer(q, device.flushOutboundQueue)
return q return q
@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for { for {
select { select {
case elem := <-q.c: case elemsContainer := <-q.c:
elem.Lock() elemsContainer.Lock()
device.PutMessageBuffer(elem.buffer) for _, elem := range elemsContainer.elems {
device.PutOutboundElement(elem) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
default: default:
return return
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -35,7 +35,6 @@ const (
/* Implementation constants */ /* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers MaxPeers = 1 << 16 // maximum number of configured peers
) )

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32, recv uint32,
src []byte, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.RLock() st.RLock()
// refresh cookie secret // refresh cookie secret
@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
if err != nil { if err != nil {
return false return false
} }
@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
} }
func (st *CookieGenerator) AddMacs(msg []byte) { func (st *CookieGenerator) AddMacs(msg []byte) {
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -10,7 +10,6 @@ import (
) )
func TestCookieMAC1(t *testing.T) { func TestCookieMAC1(t *testing.T) {
// setup generator / checker // setup generator / checker
var ( var (
@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1 := []byte{192, 168, 13, 37, 40, 01} srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2 := []byte{192, 168, 13, 38, 40, 01} srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -11,10 +11,12 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/ipc"
"golang.zx2c4.com/wireguard/rwcancel" "github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"golang.zx2c4.com/wireguard/tun" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
) )
type Device struct { type Device struct {
@ -30,7 +32,7 @@ type Device struct {
// will become the actual state; Up can fail. // will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use. // The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only. // Unsynchronized uses of state must therefore be advisory/best-effort only.
state uint32 // actually a deviceState, but typed uint32 for convenience state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed. // stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup stopping sync.WaitGroup
// mu protects state changes. // mu protects state changes.
@ -44,6 +46,7 @@ type Device struct {
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
brokenRoaming bool
} }
staticIdentity struct { staticIdentity struct {
@ -52,24 +55,26 @@ type Device struct {
publicKey NoisePublicKey publicKey NoisePublicKey
} }
rate struct {
underLoadUntil int64
limiter ratelimiter.Ratelimiter
}
peers struct { peers struct {
sync.RWMutex // protects keyMap sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer keyMap map[NoisePublicKey]*Peer
} }
rate struct {
underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}
allowedips AllowedIPs allowedips AllowedIPs
indexTable IndexTable indexTable IndexTable
cookieChecker CookieChecker cookieChecker CookieChecker
pool struct { pool struct {
messageBuffers *WaitPool inboundElementsContainer *WaitPool
inboundElements *WaitPool outboundElementsContainer *WaitPool
outboundElements *WaitPool messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
} }
queue struct { queue struct {
@ -80,22 +85,39 @@ type Device struct {
tun struct { tun struct {
device tun.Device device tun.Device
mtu int32 mtu atomic.Int32
} }
ipcMutex sync.RWMutex ipcMutex sync.RWMutex
closed chan struct{} closed chan struct{}
log *Logger log *Logger
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
}
type aSecCfgType struct {
isSet bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
} }
// deviceState represents the state of a Device. // deviceState represents the state of a Device.
// There are three states: down, up, closed. // There are three states: down, up, closed.
// Transitions: // Transitions:
// //
// down -----+ // down -----+
// ↑↓ ↓ // ↑↓ ↓
// up -> closed // up -> closed
//
type deviceState uint32 type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@ -108,7 +130,7 @@ const (
// deviceState returns device.state.state as a deviceState // deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value. // See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState { func (device *Device) deviceState() deviceState {
return deviceState(atomic.LoadUint32(&device.state.state)) return deviceState(device.state.state.Load())
} }
// isClosed reports whether the device is closed (or is closing). // isClosed reports whether the device is closed (or is closing).
@ -147,20 +169,21 @@ func (device *Device) changeState(want deviceState) (err error) {
case old: case old:
return nil return nil
case deviceStateUp: case deviceStateUp:
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked() err = device.upLocked()
if err == nil { if err == nil {
break break
} }
fallthrough // up failed; bring the device all the way back down fallthrough // up failed; bring the device all the way back down
case deviceStateDown: case deviceStateDown:
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked() errDown := device.downLocked()
if err == nil { if err == nil {
err = errDown err = errDown
} }
} }
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return return
} }
@ -172,10 +195,15 @@ func (device *Device) upLocked() error {
return err return err
} }
// The IPC set operation waits for peers to be created before calling Start() on them,
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Start() peer.Start()
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@ -210,13 +238,13 @@ func (device *Device) Down() error {
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad { if underLoad {
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true return true
} }
// check if recently under load // check if recently under load
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() return device.rate.underLoadUntil.Load() > now.UnixNano()
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@ -260,7 +288,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
@ -276,7 +304,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.state.state = uint32(deviceStateDown) device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{}) device.closed = make(chan struct{})
device.log = logger device.log = logger
device.net.bind = bind device.net.bind = bind
@ -286,10 +314,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err) device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU mtu = DefaultMTU
} }
device.tun.mtu = int32(mtu) device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init() device.rate.limiter.Init()
device.indexTable.Init() device.indexTable.Init()
device.PopulatePools() device.PopulatePools()
// create queues // create queues
@ -304,9 +333,9 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.state.stopping.Wait() device.state.stopping.Wait()
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
for i := 0; i < cpus; i++ { for i := 0; i < cpus; i++ {
go device.RoutineEncryption() go device.RoutineEncryption(i + 1)
go device.RoutineDecryption() go device.RoutineDecryption(i + 1)
go device.RoutineHandshake() go device.RoutineHandshake(i + 1)
} }
device.state.stopping.Add(1) // RoutineReadFromTUN device.state.stopping.Add(1) // RoutineReadFromTUN
@ -317,6 +346,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device return device
} }
// BatchSize returns the BatchSize for the device as a whole which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
size := device.net.bind.BatchSize()
dSize := device.tun.device.BatchSize()
if size < dSize {
size = dSize
}
return size
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock() device.peers.RLock()
defer device.peers.RUnlock() defer device.peers.RUnlock()
@ -349,10 +391,12 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() { func (device *Device) Close() {
device.state.Lock() device.state.Lock()
defer device.state.Unlock() defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() { if device.isClosed() {
return return
} }
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing") device.log.Verbosef("Device closing")
device.tun.device.Close() device.tun.device.Close()
@ -372,6 +416,8 @@ func (device *Device) Close() {
device.rate.limiter.Close() device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed") device.log.Verbosef("Device closed")
close(device.closed) close(device.closed)
} }
@ -438,11 +484,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses // clear cached source addresses
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.markEndpointSrcForClearing()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
} }
device.peers.RUnlock() device.peers.RUnlock()
@ -467,11 +509,13 @@ func (device *Device) BindUpdate() error {
var err error var err error
var recvFns []conn.ReceiveFunc var recvFns []conn.ReceiveFunc
netc := &device.net netc := &device.net
recvFns, netc.port, err = netc.bind.Open(netc.port) recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil { if err != nil {
netc.port = 0 netc.port = 0
return err return err
} }
netc.netlinkCancel, err = device.startRouteListener(netc.bind) netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil { if err != nil {
netc.bind.Close() netc.bind.Close()
@ -490,11 +534,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses // clear cached source addresses
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.markEndpointSrcForClearing()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
} }
device.peers.RUnlock() device.peers.RUnlock()
@ -502,8 +542,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns)) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns { for _, fn := range recvFns {
go device.RoutineReceiveIncoming(fn) go device.RoutineReceiveIncoming(batchSize, fn)
} }
device.log.Verbosef("UDP bind has been updated") device.log.Verbosef("UDP bind has been updated")
@ -516,3 +557,251 @@ func (device *Device) BindClose() error {
device.net.Unlock() device.net.Unlock()
return err return err
} }
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
// handlePostConfig validates the advanced-security settings collected during
// a UAPI IpcSet and, where valid, applies them to device.aSecCfg and the
// package-level message type variables / size maps.
//
// Validation does NOT stop at the first problem: every check runs, later
// errors wrap earlier ones via %w, and the single accumulated error is
// returned at the end. Fields that fail validation are generally not copied
// into device.aSecCfg (their else-branches perform the copy), while valid
// fields are applied even when an unrelated field errored.
//
// NOTE(review): a negative junkPacketCount sets err but the invalid value is
// still copied into device.aSecCfg.junkPacketCount below — confirm intended.
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
	// Nothing to do if no advanced-security keys appeared in the config.
	if !tempASecCfg.isSet {
		return err
	}
	isASecOn := false
	device.aSecMux.Lock()
	if tempASecCfg.junkPacketCount < 0 {
		err = ipcErrorf(
			ipc.IpcErrorInvalid,
			"JunkPacketCount should be non negative",
		)
	}
	device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
	if tempASecCfg.junkPacketCount != 0 {
		isASecOn = true
	}

	device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
	if tempASecCfg.junkPacketMinSize != 0 {
		isASecOn = true
	}

	// A zero-width size range would make the random size generator divide
	// by zero, so widen it by one when junk packets are actually in use.
	if device.aSecCfg.junkPacketCount > 0 &&
		tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {

		tempASecCfg.junkPacketMaxSize++ // to make rand gen work
	}

	if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
		// Fall back to a harmless 0..1 range when the requested max is
		// too large to fit in a UDP segment.
		device.aSecCfg.junkPacketMinSize = 0
		device.aSecCfg.junkPacketMaxSize = 1
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
				tempASecCfg.junkPacketMaxSize,
				MaxSegmentSize,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
				tempASecCfg.junkPacketMaxSize,
				MaxSegmentSize,
			)
		}
	} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				"maxSize: %d; should be greater than minSize: %d; %w",
				tempASecCfg.junkPacketMaxSize,
				tempASecCfg.junkPacketMinSize,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				"maxSize: %d; should be greater than minSize: %d",
				tempASecCfg.junkPacketMaxSize,
				tempASecCfg.junkPacketMinSize,
			)
		}
	} else {
		device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
	}

	if tempASecCfg.junkPacketMaxSize != 0 {
		isASecOn = true
	}

	// Junk prepended to handshake initiation must still leave the whole
	// message under the segment size.
	if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
				tempASecCfg.initPacketJunkSize,
				MaxSegmentSize,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
				tempASecCfg.initPacketJunkSize,
				MaxSegmentSize,
			)
		}
	} else {
		device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
	}

	if tempASecCfg.initPacketJunkSize != 0 {
		isASecOn = true
	}

	// Same constraint for junk prepended to the handshake response.
	if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
				tempASecCfg.responsePacketJunkSize,
				MaxSegmentSize,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
				tempASecCfg.responsePacketJunkSize,
				MaxSegmentSize,
			)
		}
	} else {
		device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
	}

	if tempASecCfg.responsePacketJunkSize != 0 {
		isASecOn = true
	}

	// Custom magic headers replace the default message type values 1..4;
	// anything <= 4 falls back to the protocol default for that type.
	if tempASecCfg.initPacketMagicHeader > 4 {
		isASecOn = true
		device.log.Verbosef("UAPI: Updating init_packet_magic_header")
		device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
		MessageInitiationType = device.aSecCfg.initPacketMagicHeader
	} else {
		device.log.Verbosef("UAPI: Using default init type")
		MessageInitiationType = 1
	}

	if tempASecCfg.responsePacketMagicHeader > 4 {
		isASecOn = true
		device.log.Verbosef("UAPI: Updating response_packet_magic_header")
		device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
		MessageResponseType = device.aSecCfg.responsePacketMagicHeader
	} else {
		device.log.Verbosef("UAPI: Using default response type")
		MessageResponseType = 2
	}

	if tempASecCfg.underloadPacketMagicHeader > 4 {
		isASecOn = true
		device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
		device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
		MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
	} else {
		device.log.Verbosef("UAPI: Using default underload type")
		MessageCookieReplyType = 3
	}

	if tempASecCfg.transportPacketMagicHeader > 4 {
		isASecOn = true
		device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
		device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
		MessageTransportType = device.aSecCfg.transportPacketMagicHeader
	} else {
		device.log.Verbosef("UAPI: Using default transport type")
		MessageTransportType = 4
	}

	// The four message types must be pairwise distinct; a set of the four
	// values shrinks below 4 exactly when there is a duplicate.
	isSameMap := map[uint32]bool{}
	isSameMap[MessageInitiationType] = true
	isSameMap[MessageResponseType] = true
	isSameMap[MessageCookieReplyType] = true
	isSameMap[MessageTransportType] = true

	// size will be different if same values
	if len(isSameMap) != 4 {
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
				MessageInitiationType,
				MessageResponseType,
				MessageCookieReplyType,
				MessageTransportType,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
				MessageInitiationType,
				MessageResponseType,
				MessageCookieReplyType,
				MessageTransportType,
			)
		}
	}

	// On-the-wire sizes are used to classify incoming packets, so the two
	// junk-padded handshake sizes must not collide; only then are the
	// size->type and type->junk-size lookup maps rebuilt.
	newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
	newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize

	if newInitSize == newResponseSize {
		if err != nil {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`new init size:%d; and new response size:%d; should differ; %w`,
				newInitSize,
				newResponseSize,
				err,
			)
		} else {
			err = ipcErrorf(
				ipc.IpcErrorInvalid,
				`new init size:%d; and new response size:%d; should differ`,
				newInitSize,
				newResponseSize,
			)
		}
	} else {
		packetSizeToMsgType = map[int]uint32{
			newInitSize:            MessageInitiationType,
			newResponseSize:        MessageResponseType,
			MessageCookieReplySize: MessageCookieReplyType,
			MessageTransportSize:   MessageTransportType,
		}

		msgTypeToJunkSize = map[uint32]int{
			MessageInitiationType:  device.aSecCfg.initPacketJunkSize,
			MessageResponseType:    device.aSecCfg.responsePacketJunkSize,
			MessageCookieReplyType: 0,
			MessageTransportType:   0,
		}
	}

	device.isASecOn.SetTo(isASecOn)
	// Re-seed the junk creator; NOTE(review): this overwrites any earlier
	// validation error with NewJunkCreator's result — confirm intended.
	device.junkCreator, err = NewJunkCreator(device)
	device.aSecMux.Unlock()

	return err
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -11,7 +11,8 @@ import (
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net" "net/netip"
"os"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
@ -19,9 +20,10 @@ import (
"testing" "testing"
"time" "time"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/conn/bindtest" "github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest" "github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
// uapiCfg returns a string that contains cfg formatted use with IpcSet. // uapiCfg returns a string that contains cfg formatted use with IpcSet.
@ -48,7 +50,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other. // genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports. // The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) { func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:]) _, err := rand.Read(key1[:])
if err != nil { if err != nil {
@ -89,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
return return
} }
// genASecurityConfigs generates a pair of UAPI configs with advanced-security
// (junk packet and magic header) parameters enabled, mirroring genConfigs:
// each side lists the other's public key as its peer, and endpointCfgs carry
// a "%d" placeholder to be filled with the peer's actual listen port.
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
	var key1, key2 NoisePrivateKey
	_, err := rand.Read(key1[:])
	if err != nil {
		tb.Errorf("unable to generate private key random bytes: %v", err)
	}
	_, err = rand.Read(key2[:])
	if err != nil {
		tb.Errorf("unable to generate private key random bytes: %v", err)
	}
	pub1, pub2 := key1.publicKey(), key2.publicKey()

	// Device 1: 5 junk packets of 500-1000 bytes, 30/40 bytes of handshake
	// junk, and custom magic headers (h1-h4).
	cfgs[0] = uapiCfg(
		"private_key", hex.EncodeToString(key1[:]),
		"listen_port", "0",
		"replace_peers", "true",
		"jc", "5",
		"jmin", "500",
		"jmax", "1000",
		"s1", "30",
		"s2", "40",
		"h1", "123456",
		"h2", "67543",
		"h4", "32345",
		"h3", "123123",
		"public_key", hex.EncodeToString(pub2[:]),
		"protocol_version", "1",
		"replace_allowed_ips", "true",
		"allowed_ip", "1.0.0.2/32",
	)
	endpointCfgs[0] = uapiCfg(
		"public_key", hex.EncodeToString(pub2[:]),
		"endpoint", "127.0.0.1:%d",
	)
	// Device 2: identical obfuscation parameters so the two sides agree on
	// the customized wire format.
	cfgs[1] = uapiCfg(
		"private_key", hex.EncodeToString(key2[:]),
		"listen_port", "0",
		"replace_peers", "true",
		"jc", "5",
		"jmin", "500",
		"jmax", "1000",
		"s1", "30",
		"s2", "40",
		"h1", "123456",
		"h2", "67543",
		"h4", "32345",
		"h3", "123123",
		"public_key", hex.EncodeToString(pub1[:]),
		"protocol_version", "1",
		"replace_allowed_ips", "true",
		"allowed_ip", "1.0.0.1/32",
	)
	endpointCfgs[1] = uapiCfg(
		"public_key", hex.EncodeToString(pub1[:]),
		"endpoint", "127.0.0.1:%d",
	)
	return
}
// A testPair is a pair of testPeers. // A testPair is a pair of testPeers.
type testPair [2]testPeer type testPair [2]testPeer
@ -96,7 +157,7 @@ type testPair [2]testPeer
type testPeer struct { type testPeer struct {
tun *tuntest.ChannelTUN tun *tuntest.ChannelTUN
dev *Device dev *Device
ip net.IP ip netip.Addr
} }
type SendDirection bool type SendDirection bool
@ -113,7 +174,11 @@ func (d SendDirection) String() string {
return "pong" return "pong"
} }
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
tb.Helper() tb.Helper()
p0, p1 := pair[0], pair[1] p0, p1 := pair[0], pair[1]
if !ping { if !ping {
@ -147,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
} }
// genTestPair creates a testPair. // genTestPair creates a testPair.
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { func genTestPair(
cfg, endpointCfg := genConfigs(tb) tb testing.TB,
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
var binds [2]conn.Bind var binds [2]conn.Bind
if realSocket { if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@ -159,7 +232,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair { for i := range pair {
p := &pair[i] p := &pair[i]
p.tun = tuntest.NewChannelTUN() p.tun = tuntest.NewChannelTUN()
p.ip = net.IPv4(1, 0, 0, byte(i+1)) p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() { if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError level = LogLevelError
@ -192,7 +265,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
func TestTwoDevicePing(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t, true) pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil) pair.Send(t, Ping, nil)
}) })
@ -207,7 +291,7 @@ func TestUpDown(t *testing.T) {
const otrials = 10 const otrials = 10
for n := 0; n < otrials; n++ { for n := 0; n < otrials; n++ {
pair := genTestPair(t, false) pair := genTestPair(t, false, false)
for i := range pair { for i := range pair {
for k := range pair[i].dev.peers.keyMap { for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@ -241,7 +325,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use. // TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races. // It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) { func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true) pair := genTestPair(t, true, false)
done := make(chan struct{}) done := make(chan struct{})
const warmupIters = 10 const warmupIters = 10
@ -307,11 +391,22 @@ func TestConcurrencySafety(t *testing.T) {
} }
}) })
// Perform bind updates and keepalive sends concurrently with tunnel use.
t.Run("bindUpdate and keepalive", func(t *testing.T) {
const iters = 10
for i := 0; i < iters; i++ {
for _, peer := range pair {
peer.dev.BindUpdate()
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
}
})
close(done) close(done)
} }
func BenchmarkLatency(b *testing.B) { func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@ -325,7 +420,7 @@ func BenchmarkLatency(b *testing.B) {
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@ -333,7 +428,7 @@ func BenchmarkThroughput(b *testing.B) {
// Measure how long it takes to receive b.N packets, // Measure how long it takes to receive b.N packets,
// starting when we receive the first packet. // starting when we receive the first packet.
var recv uint64 var recv atomic.Uint64
var elapsed time.Duration var elapsed time.Duration
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -342,7 +437,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time var start time.Time
for { for {
<-pair[0].tun.Inbound <-pair[0].tun.Inbound
new := atomic.AddUint64(&recv, 1) new := recv.Add(1)
if new == 1 { if new == 1 {
start = time.Now() start = time.Now()
} }
@ -358,7 +453,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip) ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound pingc := pair[1].tun.Outbound
var sent uint64 var sent uint64
for atomic.LoadUint64(&recv) != uint64(b.N) { for recv.Load() != uint64(b.N) {
sent++ sent++
pingc <- ping pingc <- ping
} }
@ -369,7 +464,7 @@ func BenchmarkThroughput(b *testing.B) {
} }
func BenchmarkUAPIGet(b *testing.B) { func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil) pair.Send(b, Pong, nil)
b.ReportAllocs() b.ReportAllocs()
@ -405,3 +500,73 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
}) })
} }
// fakeBindSized is a conn.Bind stub whose only meaningful behavior is
// reporting a fixed batch size; every other method is a success no-op.
// Used by TestBatchSize to exercise Device.BatchSize in isolation.
type fakeBindSized struct {
	size int
}

// Open implements conn.Bind; no sockets are opened.
func (b *fakeBindSized) Open(
	port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	return nil, 0, nil
}

// Close implements conn.Bind; nothing to release.
func (b *fakeBindSized) Close() error { return nil }

// SetMark implements conn.Bind; the mark is ignored.
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }

// Send implements conn.Bind; packets are silently dropped.
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }

// ParseEndpoint implements conn.Bind; always returns a nil endpoint.
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }

// BatchSize reports the fixed size this stub was constructed with.
func (b *fakeBindSized) BatchSize() int { return b.size }
// fakeTUNDeviceSized is a tun.Device stub whose only meaningful behavior is
// reporting a fixed batch size; every other method is a success no-op.
// Used by TestBatchSize to exercise Device.BatchSize in isolation.
type fakeTUNDeviceSized struct {
	size int
}

// File implements tun.Device; there is no backing file.
func (t *fakeTUNDeviceSized) File() *os.File { return nil }

// Read implements tun.Device; never produces packets.
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
	return 0, nil
}

// Write implements tun.Device; packets are silently dropped.
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }

// MTU implements tun.Device; reports zero.
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }

// Name implements tun.Device; reports an empty name.
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }

// Events implements tun.Device; no events are ever emitted.
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }

// Close implements tun.Device; nothing to release.
func (t *fakeTUNDeviceSized) Close() error { return nil }

// BatchSize reports the fixed size this stub was constructed with.
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
// TestBatchSize verifies that Device.BatchSize reports the larger of the
// bind's and the TUN device's batch sizes, using stubs with fixed sizes.
// The four hand-unrolled cases of the original are folded into a table.
func TestBatchSize(t *testing.T) {
	cases := []struct {
		bind, tun int // batch sizes reported by the two stubs
		want      int // expected Device.BatchSize result (the max)
	}{
		{1, 1, 1},
		{1, 128, 128},
		{128, 1, 128},
		{128, 128, 128},
	}

	d := Device{}
	for _, tc := range cases {
		d.net.bind = &fakeBindSized{tc.bind}
		d.tun.device = &fakeTUNDeviceSized{tc.tun}
		if got := d.BatchSize(); got != tc.want {
			t.Errorf("bind=%d tun=%d: expected batch size %d, got %d",
				tc.bind, tc.tun, tc.want, got)
		}
	}
}

View file

@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net" "net/netip"
) )
type DummyEndpoint struct { type DummyEndpoint struct {
src [16]byte src, dst netip.Addr
dst [16]byte
} }
func CreateDummyEndpoint() (*DummyEndpoint, error) { func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint var src, dst [16]byte
if _, err := rand.Read(end.src[:]); err != nil { if _, err := rand.Read(src[:]); err != nil {
return nil, err return nil, err
} }
_, err := rand.Read(end.dst[:]) _, err := rand.Read(dst[:])
return &end, err return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
} }
func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string { func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.SrcIP(), 1000).String()
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) DstToString() string { func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.DstIP(), 1000).String()
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) SrcToBytes() []byte { func (e *DummyEndpoint) DstToBytes() []byte {
return e.src[:] out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
} }
func (e *DummyEndpoint) DstIP() net.IP { func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst[:] return e.dst
} }
func (e *DummyEndpoint) SrcIP() net.IP { func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src[:] return e.src
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

69
device/junk_creator.go Normal file
View file

@ -0,0 +1,69 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
// junkCreator produces random "junk" payloads for the advanced-security
// obfuscation features. Randomness comes from a ChaCha8 stream seeded once
// from crypto/rand at construction time: fast, deterministic per seed, and
// not re-seeded per packet.
type junkCreator struct {
	device   *Device      // source of aSecCfg junk-packet parameters
	cha8Rand *v2.ChaCha8  // seeded pseudo-random byte/size source
}
// NewJunkCreator returns a junkCreator bound to d, seeding its ChaCha8
// generator with 32 bytes drawn from crypto/rand. Fails only if the
// system's secure random source fails.
func NewJunkCreator(d *Device) (junkCreator, error) {
	var seed [32]byte
	if _, err := crand.Read(seed[:]); err != nil {
		return junkCreator{}, err
	}
	return junkCreator{device: d, cha8Rand: v2.NewChaCha8(seed)}, nil
}
// createJunkPackets builds junkPacketCount random packets with sizes drawn
// from the configured [junkPacketMinSize, junkPacketMaxSize) range, or
// (nil, nil) when junk packets are disabled (count == 0).
// Should be called with aSecMux RLocked.
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
	if jc.device.aSecCfg.junkPacketCount == 0 {
		return nil, nil
	}

	junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
	for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
		packetSize := jc.randomPacketSize()
		junk, err := jc.randomJunkWithSize(packetSize)
		if err != nil {
			// Fixed: lowercase error string per Go convention and %w so
			// callers can unwrap the underlying cause.
			return nil, fmt.Errorf("creating junk packet: %w", err)
		}
		junks = append(junks, junk)
	}
	return junks, nil
}
// randomPacketSize returns a pseudo-random size in
// [junkPacketMinSize, junkPacketMaxSize).
// Should be called with aSecMux RLocked.
func (jc *junkCreator) randomPacketSize() int {
	lo := jc.device.aSecCfg.junkPacketMinSize
	hi := jc.device.aSecCfg.junkPacketMaxSize
	// Fixed: guard a zero-width range — Uint64()%0 panics. handlePostConfig
	// normally widens hi when it equals lo, but be defensive in case this is
	// reached with an unvalidated or hand-assembled config.
	if hi <= lo {
		return lo
	}
	return int(jc.cha8Rand.Uint64()%uint64(hi-lo)) + lo
}
// appendJunk generates size pseudo-random bytes and appends them to writer.
// Should be called with aSecMux RLocked.
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
	junk, err := jc.randomJunkWithSize(size)
	if err != nil {
		return fmt.Errorf("failed to create header junk: %v", err)
	}
	if _, err = writer.Write(junk); err != nil {
		return fmt.Errorf("failed to write header junk: %v", err)
	}
	return nil
}
// randomJunkWithSize returns a slice of size pseudo-random bytes drawn from
// the creator's ChaCha8 stream.
// Should be called with aSecMux RLocked.
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
	buf := make([]byte, size)
	_, err := jc.cha8Rand.Read(buf)
	return buf, err
}

124
device/junk_creator_test.go Normal file
View file

@ -0,0 +1,124 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
// setUpJunkCreator builds a Device from the first advanced-security test
// config and returns a junkCreator for it. On failure the device is closed,
// the error is reported via t, and a zero junkCreator is returned.
// NOTE(review): on success the caller owns jc.device and should Close it.
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
	cfg, _ := genASecurityConfigs(t)
	tun := tuntest.NewChannelTUN()
	binds := bindtest.NewChannelBinds()
	level := LogLevelVerbose
	dev := NewDevice(
		tun.TUN(),
		binds[0],
		NewLogger(level, ""),
	)

	if err := dev.IpcSet(cfg[0]); err != nil {
		t.Errorf("failed to configure device %v", err)
		dev.Close()
		return junkCreator{}, err
	}

	jc, err := NewJunkCreator(dev)
	if err != nil {
		t.Errorf("failed to create junk creator %v", err)
		dev.Close()
		return junkCreator{}, err
	}

	return jc, nil
}
// Test_junkCreator_createJunkPackets checks that a batch of junk packets
// contains no duplicates (the ChaCha8 stream should make collisions
// vanishingly unlikely for 500+ byte packets).
func Test_junkCreator_createJunkPackets(t *testing.T) {
	jc, err := setUpJunkCreator(t)
	if err != nil {
		return
	}
	// Fixed: the device was leaked on the success path.
	defer jc.device.Close()

	t.Run("", func(t *testing.T) {
		got, err := jc.createJunkPackets()
		if err != nil {
			t.Errorf(
				"junkCreator.createJunkPackets() = %v; failed",
				err,
			)
			return
		}
		seen := make(map[string]bool)
		for _, junk := range got {
			key := string(junk)
			if seen[key] {
				t.Errorf(
					"junkCreator.createJunkPackets() = %v, duplicate key: %v",
					got,
					junk,
				)
				return
			}
			seen[key] = true
		}
	})
}
// Test_junkCreator_randomJunkWithSize checks that two successive draws of
// the same size differ, i.e. the generator advances between calls.
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
	t.Run("", func(t *testing.T) {
		jc, err := setUpJunkCreator(t)
		if err != nil {
			return
		}
		// Fixed: the device was only closed on the failure branch.
		defer jc.device.Close()

		r1, _ := jc.randomJunkWithSize(10)
		r2, _ := jc.randomJunkWithSize(10)
		if bytes.Equal(r1, r2) {
			// Fixed: the original logged the (always nil) err instead of
			// the colliding values; also dropped the debug Printf noise.
			t.Errorf("same junks: %v == %v", r1, r2)
		}
	})
}
// Test_junkCreator_randomPacketSize checks that 30 generated sizes all fall
// within the configured [junkPacketMinSize, junkPacketMaxSize] bounds.
func Test_junkCreator_randomPacketSize(t *testing.T) {
	jc, err := setUpJunkCreator(t)
	if err != nil {
		return
	}
	// Fixed: the device was leaked.
	defer jc.device.Close()

	for range [30]struct{}{} {
		t.Run("", func(t *testing.T) {
			if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
				got > jc.device.aSecCfg.junkPacketMaxSize {
				t.Errorf(
					"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
					got,
					jc.device.aSecCfg.junkPacketMinSize,
					jc.device.aSecCfg.junkPacketMaxSize,
				)
			}
		})
	}
}
// Test_junkCreator_appendJunk checks that appendJunk grows the buffer by
// exactly the requested number of junk bytes.
func Test_junkCreator_appendJunk(t *testing.T) {
	jc, err := setUpJunkCreator(t)
	if err != nil {
		return
	}
	// Fixed: the device was leaked.
	defer jc.device.Close()

	t.Run("", func(t *testing.T) {
		s := "apple"
		buffer := bytes.NewBuffer([]byte(s))
		// Fixed: the original only failed when err != nil AND the length
		// mismatched, so a silent error or a wrong length alone passed.
		if err := jc.appendJunk(buffer, 30); err != nil {
			t.Fatalf("appendJunk() failed: %v", err)
		}
		if buffer.Len() != len(s)+30 {
			t.Errorf("appendWithJunk() size don't match")
		}
	})
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -20,7 +20,7 @@ type KDFTest struct {
t2 string t2 string
} }
func assertEquals(t *testing.T, a string, b string) { func assertEquals(t *testing.T, a, b string) {
if a != b { if a != b {
t.Fatal("expected", a, "=", b) t.Fatal("expected", a, "=", b)
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -10,9 +10,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"golang.zx2c4.com/wireguard/replay" "github.com/amnezia-vpn/amneziawg-go/replay"
) )
/* Due to limitations in Go and /x/crypto there is currently /* Due to limitations in Go and /x/crypto there is currently
@ -23,7 +22,7 @@ import (
*/ */
type Keypair struct { type Keypair struct {
sendNonce uint64 // accessed atomically sendNonce atomic.Uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter replay.Filter replayFilter replay.Filter
@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next *Keypair next atomic.Pointer[Keypair]
}
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -16,8 +16,8 @@ import (
// They do not require a trailing newline in the format. // They do not require a trailing newline in the format.
// If nil, that level of logging will be silent. // If nil, that level of logging will be silent.
type Logger struct { type Logger struct {
Verbosef func(format string, args ...interface{}) Verbosef func(format string, args ...any)
Errorf func(format string, args ...interface{}) Errorf func(format string, args ...any)
} }
// Log levels for use with NewLogger. // Log levels for use with NewLogger.
@ -28,14 +28,14 @@ const (
) )
// Function for use in Logger for discarding logged lines. // Function for use in Logger for discarding logged lines.
func DiscardLogf(format string, args ...interface{}) {} func DiscardLogf(format string, args ...any) {}
// NewLogger constructs a Logger that writes to stdout. // NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above. // It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend. // It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger { func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf} logger := &Logger{DiscardLogf, DiscardLogf}
logf := func(prefix string) func(string, ...interface{}) { logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
} }
if level >= LogLevelVerbose { if level >= LogLevelVerbose {

View file

@ -1,48 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.int32) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.int32, flag)
}
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

View file

@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to deal with it, and race maybe, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.endpoint.Lock()
peer.disableRoaming = peer.endpoint != nil peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.Unlock() peer.endpoint.Unlock()
} }
device.peers.RUnlock() device.peers.RUnlock()
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -9,6 +9,7 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"errors"
"hash" "hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return return
} }
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { var errInvalidPublicKey = errors.New("invalid public key")
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk) curve25519.ScalarMult(&ss, ask, apk)
return ss if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "github.com/amnezia-vpn/amneziawg-go/tai64n"
) )
type handshakeState int type handshakeState int
@ -52,11 +52,11 @@ const (
WGLabelCookie = "cookie--" WGLabelCookie = "cookie--"
) )
const ( var (
MessageInitiationType = 1 MessageInitiationType uint32 = 1
MessageResponseType = 2 MessageResponseType uint32 = 2
MessageCookieReplyType = 3 MessageCookieReplyType uint32 = 3
MessageTransportType = 4 MessageTransportType uint32 = 4
) )
const ( const (
@ -75,6 +75,10 @@ const (
MessageTransportOffsetContent = 16 MessageTransportOffsetContent = 16
) )
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
/* Type is an 8-bit field, followed by 3 nul bytes, /* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder * by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now) * we can treat these as a 32-bit unsigned int (for now)
@ -138,11 +142,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data) KDF1(dst, c[:], data)
} }
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hash.Write(h[:]) hash.Write(h[:])
hash.Write(data) hash.Write(data)
@ -175,8 +179,6 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -195,18 +197,20 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) { if err != nil {
return nil, errZeroECDHResult return nil, err
} }
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
KDF2( KDF2(
@ -221,7 +225,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult return nil, errInvalidPublicKey
} }
KDF2( KDF2(
&handshake.chainKey, &handshake.chainKey,
@ -252,9 +256,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.aSecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -264,11 +271,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) { if err != nil {
return nil return nil
} }
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
@ -282,7 +288,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer // lookup peer
peer := device.LookupPeer(peerPK) peer := device.LookupPeer(peerPK)
if peer == nil { if peer == nil || !peer.isRunning.Load() {
return nil return nil
} }
@ -370,7 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.aSecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -384,12 +392,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) if err != nil {
handshake.mixKey(ss[:]) return nil, err
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
// add preshared key // add preshared key
@ -406,11 +418,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
func() { aead, _ := chacha20poly1305.New(key[:])
aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.mixHash(msg.Empty[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = handshakeResponseCreated handshake.state = handshakeResponseCreated
@ -418,9 +428,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver
@ -436,7 +449,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
) )
ok := func() bool { ok := func() bool {
// lock handshake state // lock handshake state
handshake.mutex.RLock() handshake.mutex.RLock()
@ -456,17 +468,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() { ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
// add preshared key (psk) // add preshared key (psk)
@ -484,7 +498,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript // authenticate transcript
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return false return false
} }
@ -582,12 +596,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.loadNext() next := keypairs.next.Load()
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.storeNext(nil) keypairs.next.Store(nil)
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@ -596,7 +610,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.storeNext(keypair) keypairs.next.Store(keypair)
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@ -608,18 +622,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.loadNext() != receivedKeypair { if keypairs.next.Load() != receivedKeypair {
return false return false
} }
keypairs.Lock() keypairs.Lock()
defer keypairs.Unlock() defer keypairs.Unlock()
if keypairs.loadNext() != receivedKeypair { if keypairs.next.Load() != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.loadNext() keypairs.current = keypairs.next.Load()
keypairs.storeNext(nil) keypairs.next.Store(nil)
return true return true
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -10,8 +10,8 @@ import (
"encoding/binary" "encoding/binary"
"testing" "testing"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/tun/tuntest" "github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
func TestCurveWrappers(t *testing.T) { func TestCurveWrappers(t *testing.T) {
@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey() pk1 := sk1.publicKey()
pk2 := sk2.publicKey() pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2) ss1, err1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1) ss2, err2 := sk2.sharedSecret(pk1)
if ss1 != ss2 { if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet") t.Fatal("Failed to compute shared secet")
} }
} }
@ -71,6 +71,8 @@ func TestNoiseHandshake(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer1.Start()
peer2.Start()
assertEqual( assertEqual(
t, t,
@ -146,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.loadNext() key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test

View file

@ -1,53 +1,46 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list" "container/list"
"encoding/base64"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
) )
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning atomic.Bool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs
keypairs Keypairs handshake Handshake
handshake Handshake device *Device
device *Device stopping sync.WaitGroup // routines pending stop
endpoint conn.Endpoint txBytes atomic.Uint64 // bytes send to peer (endpoint)
stopping sync.WaitGroup // routines pending stop rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
// These fields are accessed with atomic operations, which must be endpoint struct {
// 64-bit aligned even on 32-bit platforms. Go guarantees that an sync.Mutex
// allocated struct will be 64-bit aligned. So we place val conn.Endpoint
// atomically-accessed fields up front, so that they can share in clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
// this alignment before smaller fields throw it off. disableRoaming bool
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
} }
disableRoaming bool
timers struct { timers struct {
retransmitHandshake *Timer retransmitHandshake *Timer
sendKeepalive *Timer sendKeepalive *Timer
newHandshake *Timer newHandshake *Timer
zeroKeyMaterial *Timer zeroKeyMaterial *Timer
persistentKeepalive *Timer persistentKeepalive *Timer
handshakeAttempts uint32 handshakeAttempts atomic.Uint32
needAnotherKeepalive AtomicBool needAnotherKeepalive atomic.Bool
sentLastMinuteHandshake AtomicBool sentLastMinuteHandshake atomic.Bool
} }
state struct { state struct {
@ -55,14 +48,14 @@ type Peer struct {
} }
queue struct { queue struct {
staged chan *QueueOutboundElement // staged packets before a handshake is available staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing inbound *autodrainingInboundQueue // sequential ordering of tun writing
} }
cookieGenerator CookieGenerator cookieGenerator CookieGenerator
trieEntries list.List trieEntries list.List
persistentKeepaliveInterval uint32 // accessed atomically persistentKeepaliveInterval atomic.Uint32
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@ -84,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer // create peer
peer := new(Peer) peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk) peer.cookieGenerator.Init(pk)
peer.device = device peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key // map public key
_, ok := device.peers.keyMap[pk] _, ok := device.peers.keyMap[pk]
@ -102,26 +93,27 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH // pre-compute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
peer.endpoint = nil peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers
peer.timersInit()
// add // add
device.peers.keyMap[pk] = peer device.peers.keyMap[pk] = peer
// start peer
peer.timersInit()
if peer.device.isUp() {
peer.Start()
}
return peer, nil return peer, nil
} }
func (peer *Peer) SendBuffer(buffer []byte) error { func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock() peer.device.net.RLock()
defer peer.device.net.RUnlock() defer peer.device.net.RUnlock()
@ -129,27 +121,53 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return nil return nil
} }
peer.RLock() peer.endpoint.Lock()
defer peer.RUnlock() endpoint := peer.endpoint.val
if endpoint == nil {
if peer.endpoint == nil { peer.endpoint.Unlock()
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
} }
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffer, peer.endpoint) err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil { if err == nil {
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) var totalLen uint64
for _, b := range buffers {
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
} }
return err return err
} }
func (peer *Peer) String() string { func (peer *Peer) String() string {
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) // The awful goo that follows is identical to:
abbreviatedKey := "invalid" //
if len(base64Key) == 44 { // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
//
// except that it is considerably more efficient.
src := peer.handshake.remoteStatic
b64 := func(input byte) byte {
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
} }
return fmt.Sprintf("peer(%s)", abbreviatedKey) b := []byte("peer(____…____)")
const first = len("peer(")
const second = len("peer(____…")
b[first+0] = b64((src[0] >> 2) & 63)
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
b[first+3] = b64(src[2] & 63)
b[second+0] = b64(src[29] & 63)
b[second+1] = b64((src[30] >> 2) & 63)
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
b[second+3] = b64((src[31] << 2) & 63)
return string(b)
} }
func (peer *Peer) Start() { func (peer *Peer) Start() {
@ -162,12 +180,12 @@ func (peer *Peer) Start() {
peer.state.Lock() peer.state.Lock()
defer peer.state.Unlock() defer peer.state.Unlock()
if peer.isRunning.Get() { if peer.isRunning.Load() {
return return
} }
device := peer.device device := peer.device
device.log.Verbosef("%v - Starting...", peer) device.log.Verbosef("%v - Starting", peer)
// reset routine state // reset routine state
peer.stopping.Wait() peer.stopping.Wait()
@ -183,10 +201,14 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound) device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound) device.flushOutboundQueue(peer.queue.outbound)
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
peer.isRunning.Set(true) // Use the device batch size, not the bind batch size, as the device size is
// the size of the batch pools.
batchSize := peer.device.BatchSize()
go peer.RoutineSequentialSender(batchSize)
go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
} }
func (peer *Peer) ZeroAndFlushAll() { func (peer *Peer) ZeroAndFlushAll() {
@ -198,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.loadNext()) device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.storeNext(nil) keypairs.next.Store(nil)
keypairs.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@ -226,11 +248,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.Lock() keypairs.Lock()
if keypairs.current != nil { if keypairs.current != nil {
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) keypairs.current.sendNonce.Store(RejectAfterMessages)
} }
if keypairs.next != nil { if next := keypairs.next.Load(); next != nil {
next := keypairs.loadNext() next.sendNonce.Store(RejectAfterMessages)
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
} }
keypairs.Unlock() keypairs.Unlock()
} }
@ -243,7 +264,7 @@ func (peer *Peer) Stop() {
return return
} }
peer.device.log.Verbosef("%v - Stopping...", peer) peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop() peer.timersStop()
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
@ -256,10 +277,20 @@ func (peer *Peer) Stop() {
} }
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if peer.disableRoaming { peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return return
} }
peer.Lock() peer.endpoint.clearSrcOnTx = false
peer.endpoint = endpoint peer.endpoint.val = endpoint
peer.Unlock() }
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -14,49 +14,85 @@ type WaitPool struct {
pool sync.Pool pool sync.Pool
cond sync.Cond cond sync.Cond
lock sync.Mutex lock sync.Mutex
count uint32 count atomic.Uint32
max uint32 max uint32
} }
func NewWaitPool(max uint32, new func() interface{}) *WaitPool { func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max} p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock} p.cond = sync.Cond{L: &p.lock}
return p return p
} }
func (p *WaitPool) Get() interface{} { func (p *WaitPool) Get() any {
if p.max != 0 { if p.max != 0 {
p.lock.Lock() p.lock.Lock()
for atomic.LoadUint32(&p.count) >= p.max { for p.count.Load() >= p.max {
p.cond.Wait() p.cond.Wait()
} }
atomic.AddUint32(&p.count, 1) p.count.Add(1)
p.lock.Unlock() p.lock.Unlock()
} }
return p.pool.Get() return p.pool.Get()
} }
func (p *WaitPool) Put(x interface{}) { func (p *WaitPool) Put(x any) {
p.pool.Put(x) p.pool.Put(x)
if p.max == 0 { if p.max == 0 {
return return
} }
atomic.AddUint32(&p.count, ^uint32(0)) p.count.Add(^uint32(0))
p.cond.Signal() p.cond.Signal()
} }
func (device *Device) PopulatePools() { func (device *Device) PopulatePools() {
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte) return new([MaxMessageSize]byte)
}) })
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement) return new(QueueInboundElement)
}) })
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement) return new(QueueOutboundElement)
}) })
} }
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -17,29 +17,31 @@ import (
func TestWaitPool(t *testing.T) { func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled") t.Skip("Currently disabled")
var wg sync.WaitGroup var wg sync.WaitGroup
trials := int32(100000) var trials atomic.Int32
startTrials := int32(100000)
if raceEnabled { if raceEnabled {
// This test can be very slow with -race. // This test can be very slow with -race.
trials /= 10 startTrials /= 10
} }
trials.Store(startTrials)
workers := runtime.NumCPU() + 2 workers := runtime.NumCPU() + 2
if workers-4 <= 0 { if workers-4 <= 0 {
t.Skip("Not enough cores") t.Skip("Not enough cores")
} }
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers) wg.Add(workers)
max := uint32(0) var max atomic.Uint32
updateMax := func() { updateMax := func() {
count := atomic.LoadUint32(&p.count) count := p.count.Load()
if count > p.max { if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max) t.Errorf("count (%d) > max (%d)", count, p.max)
} }
for { for {
old := atomic.LoadUint32(&max) old := max.Load()
if count <= old { if count <= old {
break break
} }
if atomic.CompareAndSwapUint32(&max, old, count) { if max.CompareAndSwap(old, count) {
break break
} }
} }
@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
for atomic.AddInt32(&trials, -1) > 0 { for trials.Add(-1) > 0 {
updateMax() updateMax()
x := p.Get() x := p.Get()
updateMax() updateMax()
@ -59,25 +61,74 @@ func TestWaitPool(t *testing.T) {
}() }()
} }
wg.Wait() wg.Wait()
if max != p.max { if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
} }
} }
func BenchmarkWaitPool(b *testing.B) { func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup var wg sync.WaitGroup
trials := int32(b.N) var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2 workers := runtime.NumCPU() + 2
if workers-4 <= 0 { if workers-4 <= 0 {
b.Skip("Not enough cores") b.Skip("Not enough cores")
} }
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers) wg.Add(workers)
b.ResetTimer() b.ResetTimer()
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
for atomic.AddInt32(&trials, -1) > 0 { for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkWaitPoolEmpty(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkSyncPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := sync.Pool{New: func() any { return make([]byte, 16) }}
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get() x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x) p.Put(x)

View file

@ -1,17 +1,19 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */ /* Reduce memory consumption for Android */
const ( const (
QueueStagedSize = 128 QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024
MaxSegmentSize = 2200 MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096 PreallocatedBuffersPerPool = 4096
) )

View file

@ -1,14 +1,16 @@
// +build !android,!ios,!windows //go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
const ( const (
QueueStagedSize = 128 QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024

View file

@ -1,19 +1,21 @@
// +build ios //go:build ios
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ // Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
const ( // them further.
QueueStagedSize = 128 var (
QueueOutboundSize = 1024 QueueStagedSize = 128
QueueInboundSize = 1024 QueueOutboundSize = 1024
QueueHandshakeSize = 1024 QueueInboundSize = 1024
MaxSegmentSize = 1700 QueueHandshakeSize = 1024
PreallocatedBuffersPerPool = 1024 PreallocatedBuffersPerPool uint32 = 1024
) )
const MaxSegmentSize = 1700

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View file

@ -1,8 +1,8 @@
//+build !race //go:build !race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View file

@ -1,8 +1,8 @@
//+build race //go:build race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -11,14 +11,12 @@ import (
"errors" "errors"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
) )
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
@ -29,7 +27,6 @@ type QueueHandshakeElement struct {
} }
type QueueInboundElement struct { type QueueInboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
@ -37,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint endpoint conn.Endpoint
} }
type QueueInboundElementsContainer struct {
sync.Mutex
elems []*QueueInboundElement
}
// clearPointers clears elem fields that contain pointers. // clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and // This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily. // avoids accidentally keeping other objects around unnecessarily.
@ -53,12 +55,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver! * NOTE: Not thread safe, but called by sequential receiver!
*/ */
func (peer *Peer) keepKeyFreshReceiving() { func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Get() { if peer.timers.sentLastMinuteHandshake.Load() {
return return
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Set(true) peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@ -68,7 +70,10 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { func (device *Device) RoutineReceiveIncoming(
maxBatchSize int,
recv conn.ReceiveFunc,
) {
recvName := recv.PrettyName() recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@ -81,168 +86,226 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed // receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var ( var (
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
bufs = make([][]byte, maxBatchSize)
err error err error
size int sizes = make([]int, maxBatchSize)
endpoint conn.Endpoint count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int deathSpiral int
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
) )
for { for i := range bufsArrs {
size, endpoint, err = recv(buffer[:]) bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if bufsArrs[i] != nil {
device.PutMessageBuffer(bufsArrs[i])
}
}
}()
for {
count, err = recv(bufs, sizes, endpoints)
if err != nil { if err != nil {
device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
} }
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return return
} }
device.log.Errorf("Failed to receive packet: %v", err)
if deathSpiral < 10 { if deathSpiral < 10 {
deathSpiral++ deathSpiral++
time.Sleep(time.Second / 3) time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue continue
} }
return return
} }
deathSpiral = 0 deathSpiral = 0
if size < MinMessageSize { device.aSecMux.RLock()
continue // handle each packet in the batch
} for i, size := range sizes[:count] {
if size < MinMessageSize {
// check size of packet
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue continue
} }
// lookup key pair // check size of packet
receiver := binary.LittleEndian.Uint32( packet := bufsArrs[i][:size]
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], var msgType uint32
) if device.isAdvancedSecurityOn() {
value := device.indexTable.Lookup(receiver) if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
keypair := value.keypair junkSize := msgTypeToJunkSize[assumedMsgType]
if keypair == nil { // transport size can align with other header types;
continue // making sure we have the right msgType
} msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
// check keypair expiry packet = packet[junkSize:]
} else {
if keypair.created.Add(RejectAfterTime).Before(time.Now()) { device.log.Verbosef("Transport packet lined up with another msg type")
continue msgType = binary.LittleEndian.Uint32(packet[:4])
} }
} else {
// create work element msgType = binary.LittleEndian.Uint32(packet[:4])
peer := value.peer if msgType != MessageTransportType {
elem := device.GetInboundElement() device.log.Verbosef("ASec: Received message with unknown type")
elem.packet = packet continue
elem.buffer = buffer }
elem.keypair = keypair }
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
// add to decryption queues
if peer.isRunning.Get() {
peer.queue.inbound.c <- elem
device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer()
} else { } else {
device.PutInboundElement(elem) msgType = binary.LittleEndian.Uint32(packet[:4])
} }
continue switch msgType {
// otherwise it is a fixed size & handshake related packet // check if transport
case MessageInitiationType: case MessageTransportType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType: // check size
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType: if len(packet) < MessageTransportSize {
okay = len(packet) == MessageCookieReplySize continue
}
default: // lookup key pair
device.log.Verbosef("Received message with unknown type")
} receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = bufsArrs[i]
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}
default:
device.log.Verbosef("Received message with unknown type")
continue
}
if okay {
select { select {
case device.queue.handshake.c <- QueueHandshakeElement{ case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType, msgType: msgType,
buffer: buffer, buffer: bufsArrs[i],
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoints[i],
}: }:
buffer = device.GetMessageBuffer() bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
default: default:
} }
} }
device.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
device.queue.decryption.c <- elemsContainer
} else {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
} }
} }
func (device *Device) RoutineDecryption() { func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker - stopped") defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker - started") device.log.Verbosef("Routine: decryption worker %d - started", id)
for elem := range device.queue.decryption.c { for elemsContainer := range device.queue.decryption.c {
// split message into fields for _, elem := range elemsContainer.elems {
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] // split message into fields
content := elem.packet[MessageTransportOffsetContent:] counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer // decrypt and release to consumer
var err error var err error
elem.counter = binary.LittleEndian.Uint64(counter) elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce // copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open( elem.packet, err = elem.keypair.receive.Open(
content[:0], content[:0],
nonce[:], nonce[:],
content, content,
nil, nil,
) )
if err != nil { if err != nil {
elem.packet = nil elem.packet = nil
}
} }
elem.Unlock() elemsContainer.Unlock()
} }
} }
/* Handles incoming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake() { func (device *Device) RoutineHandshake(id int) {
defer func() { defer func() {
device.log.Verbosef("Routine: handshake worker - stopped") device.log.Verbosef("Routine: handshake worker %d - stopped", id)
device.queue.encryption.wg.Done() device.queue.encryption.wg.Done()
}() }()
device.log.Verbosef("Routine: handshake worker - started") device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
switch elem.msgType { switch elem.msgType {
@ -269,10 +332,15 @@ func (device *Device) RoutineHandshake() {
// consume reply // consume reply
if peer := entry.peer; peer.isRunning.Get() { if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) device.log.Verbosef(
"Receiving cookie response from %s",
elem.endpoint.DstToString(),
)
if !peer.cookieGenerator.ConsumeReply(&reply) { if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response") device.log.Verbosef(
"Could not decrypt invalid cookie response",
)
} }
} }
@ -314,9 +382,7 @@ func (device *Device) RoutineHandshake() {
switch elem.msgType { switch elem.msgType {
case MessageInitiationType: case MessageInitiationType:
// unmarshal // unmarshal
var msg MessageInitiation var msg MessageInitiation
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg) err := binary.Read(reader, binary.LittleEndian, &msg)
@ -326,7 +392,6 @@ func (device *Device) RoutineHandshake() {
} }
// consume initiation // consume initiation
peer := device.ConsumeMessageInitiation(&msg) peer := device.ConsumeMessageInitiation(&msg)
if peer == nil { if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
@ -342,7 +407,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer) device.log.Verbosef("%v - Received handshake initiation", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
@ -370,7 +435,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer) device.log.Verbosef("%v - Received handshake response", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers // update timers
@ -391,11 +456,12 @@ func (device *Device) RoutineHandshake() {
peer.SendKeepalive() peer.SendKeepalive()
} }
skip: skip:
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
} }
} }
func (peer *Peer) RoutineSequentialReceiver() { func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device device := peer.device
defer func() { defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@ -403,89 +469,109 @@ func (peer *Peer) RoutineSequentialReceiver() {
}() }()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer) device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
for elem := range peer.queue.inbound.c { bufs := make([][]byte, 0, maxBatchSize)
if elem == nil {
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
return return
} }
var err error elemsContainer.Lock()
elem.Lock() validTailPacket := -1
if elem.packet == nil { dataPacketReceived := false
// decryption failed rxBytesLen := uint64(0)
goto skip for i, elem := range elemsContainer.elems {
} if elem.packet == nil {
// decryption failed
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue
goto skip
}
peer.SetEndpointFromPacket(elem.endpoint)
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
goto skip
}
peer.timersDataReceived()
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
goto skip
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip
} }
case ipv6.Version: if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
if len(elem.packet) < ipv6.HeaderLen { continue
goto skip
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip
} }
default: validTailPacket = i
device.log.Verbosef("Packet with invalid IP version from %v", peer) if peer.ReceivedWithKeypair(elem.keypair) {
goto skip peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
}
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
}
default:
device.log.Verbosef(
"Packet with invalid IP version from %v",
peer,
)
continue
}
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
} }
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) peer.rxBytes.Add(rxBytesLen)
if err != nil && !device.isClosed() { if validTailPacket >= 0 {
device.log.Errorf("Failed to write packet to TUN device: %v", err) peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
} }
if len(peer.queue.inbound.c) == 0 { if dataPacketReceived {
err = device.tun.device.Flush() peer.timersDataReceived()
if err != nil { }
peer.device.log.Errorf("Unable to flush packets: %v", err) if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
} }
} }
skip: for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem) device.PutInboundElement(elem)
}
bufs = bufs[:0]
device.PutInboundElementsContainer(elemsContainer)
} }
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,11 +8,14 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"os"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -43,7 +46,6 @@ import (
*/ */
type QueueOutboundElement struct { type QueueOutboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
@ -51,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer peer *Peer // related peer
} }
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement() elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer() elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0 elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers. // keypair and peer were cleared (if necessary) by clearPointers.
return elem return elem
@ -74,14 +80,17 @@ func (elem *QueueOutboundElement) clearPointers() {
/* Queues a keepalive if no packets are queued for peer /* Queues a keepalive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepalive() { func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Get() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
elemsContainer := peer.device.GetOutboundElementsContainer()
elemsContainer.elems = append(elemsContainer.elems, elem)
select { select {
case peer.queue.staged <- elem: case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer) peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default: default:
peer.device.PutMessageBuffer(elem.buffer) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem) peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsContainer(elemsContainer)
} }
} }
peer.SendStagedPackets() peer.SendStagedPackets()
@ -89,7 +98,7 @@ func (peer *Peer) SendKeepalive() {
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry { if !isRetry {
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
} }
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
@ -114,17 +123,56 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err return err
} }
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
var buff [MessageInitiationSize]byte if err != nil {
writer := bytes.NewBuffer(buff[:0]) peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet) sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendBuffers(sendBuffer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
} }
@ -145,12 +193,29 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err return err
} }
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
var buff [MessageResponseSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, response) binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes() packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
@ -162,27 +227,35 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet) // TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
} }
return err return err
} }
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
)
if err != nil { if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err) device.log.Errorf("Failed to create cookie reply: %v", err)
return err return err
} }
var buff [MessageCookieReplySize]byte var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) // TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil return nil
} }
@ -191,17 +264,12 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil { if keypair == nil {
return return
} }
nonce := atomic.LoadUint64(&keypair.sendNonce) nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
/* Reads packets from the TUN and inserts
* into staged queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN() { func (device *Device) RoutineReadFromTUN() {
defer func() { defer func() {
device.log.Verbosef("Routine: TUN reader - stopped") device.log.Verbosef("Routine: TUN reader - stopped")
@ -211,80 +279,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started") device.log.Verbosef("Routine: TUN reader - started")
var elem *QueueOutboundElement var (
batchSize = device.BatchSize()
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
)
for i := range elems {
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
defer func() {
for _, elem := range elems {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
}
}()
for { for {
if elem != nil { // read packets
device.PutMessageBuffer(elem.buffer) count, readErr = device.tun.device.Read(bufs, sizes, offset)
device.PutOutboundElement(elem) for i := 0; i < count; i++ {
if sizes[i] < 1 {
continue
}
elem := elems[i]
elem.packet = bufs[i][offset : offset+sizes[i]]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
} }
elem = device.NewOutboundElement()
// read packet for peer, elemsForPeer := range elemsByPeer {
if peer.isRunning.Load() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
offset := MessageTransportHeaderSize if readErr != nil {
size, err := device.tun.device.Read(elem.buffer[:], offset) if errors.Is(readErr, tun.ErrTooManySegments) {
// TODO: record stat for this
if err != nil { // This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
continue
}
if !device.isClosed() { if !device.isClosed() {
device.log.Errorf("Failed to read packet from TUN device: %v", err) if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
}
go device.Close() go device.Close()
} }
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return return
} }
if size == 0 || size > MaxContentSize {
continue
}
elem.packet = elem.buffer[offset : offset+size]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
if peer.isRunning.Get() {
peer.StagePacket(elem)
elem = nil
peer.SendStagedPackets()
}
} }
} }
func (peer *Peer) StagePacket(elem *QueueOutboundElement) { func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for { for {
select { select {
case peer.queue.staged <- elem: case peer.queue.staged <- elems:
return return
default: default:
} }
select { select {
case tooOld := <-peer.queue.staged: case tooOld := <-peer.queue.staged:
peer.device.PutMessageBuffer(tooOld.buffer) for _, elem := range tooOld.elems {
peer.device.PutOutboundElement(tooOld) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(tooOld)
default: default:
} }
} }
@ -297,32 +408,59 @@ top:
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
return return
} }
for { for {
var elemsContainerOOO *QueueOutboundElementsContainer
select { select {
case elem := <-peer.queue.staged: case elemsContainer := <-peer.queue.staged:
elem.peer = peer i := 0
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 for _, elem := range elemsContainer.elems {
if elem.nonce >= RejectAfterMessages { elem.peer = peer
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) elem.nonce = keypair.sendNonce.Add(1) - 1
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
}
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(elemsContainer.elems) == 0 {
peer.device.PutOutboundElementsContainer(elemsContainer)
goto top goto top
} }
elem.keypair = keypair
elem.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
if peer.isRunning.Get() { if peer.isRunning.Load() {
peer.queue.outbound.c <- elem peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elem peer.device.queue.encryption.c <- elemsContainer
} else { } else {
peer.device.PutMessageBuffer(elem.buffer) for _, elem := range elemsContainer.elems {
peer.device.PutOutboundElement(elem) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsContainerOOO != nil {
goto top
} }
default: default:
return return
@ -333,9 +471,12 @@ top:
func (peer *Peer) FlushStagedPackets() { func (peer *Peer) FlushStagedPackets() {
for { for {
select { select {
case elem := <-peer.queue.staged: case elemsContainer := <-peer.queue.staged:
peer.device.PutMessageBuffer(elem.buffer) for _, elem := range elemsContainer.elems {
peer.device.PutOutboundElement(elem) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
default: default:
return return
} }
@ -362,48 +503,45 @@ func calculatePaddingSize(packetSize, mtu int) int {
* *
* Obs. One instance per core * Obs. One instance per core
*/ */
func (device *Device) RoutineEncryption() { func (device *Device) RoutineEncryption(id int) {
var paddingZeros [PaddingMultiple]byte var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: encryption worker - stopped") defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker - started") device.log.Verbosef("Routine: encryption worker %d - started", id)
for elem := range device.queue.encryption.c { for elemsContainer := range device.queue.encryption.c {
// populate header fields for _, elem := range elemsContainer.elems {
header := elem.buffer[:MessageTransportHeaderSize] // populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4] fieldType := header[0:4]
fieldReceiver := header[4:8] fieldReceiver := header[4:8]
fieldNonce := header[8:16] fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType) binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16 // pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer // encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal( elem.packet = elem.keypair.send.Seal(
header, header,
nonce[:], nonce[:],
elem.packet, elem.packet,
nil, nil,
) )
elem.Unlock() }
elemsContainer.Unlock()
} }
} }
/* Sequentially reads packets from queue and sends to endpoint func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
defer func() { defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@ -411,36 +549,57 @@ func (peer *Peer) RoutineSequentialSender() {
}() }()
device.log.Verbosef("%v - Routine: sequential sender - started", peer) device.log.Verbosef("%v - Routine: sequential sender - started", peer)
for elem := range peer.queue.outbound.c { bufs := make([][]byte, 0, maxBatchSize)
if elem == nil {
for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
if elemsContainer == nil {
return return
} }
elem.Lock() if !peer.isRunning.Load() {
if !peer.isRunning.Get() {
// peer has been stopped; return re-usable elems to the shared pool. // peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped // This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed. // immediately after this check, in which case, elem will get processed.
// The timers and SendBuffer code are resilient to a few stragglers. // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure // TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary. // that we never accidentally keep timers alive longer than necessary.
device.PutMessageBuffer(elem.buffer) elemsContainer.Lock()
device.PutOutboundElement(elem) for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
continue continue
} }
dataSent := false
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
bufs = append(bufs, elem.packet)
}
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
// send message and return buffer to pool err := peer.SendBuffers(bufs)
if dataSent {
err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent() peer.timersDataSent()
} }
device.PutMessageBuffer(elem.buffer) for _, elem := range elemsContainer.elems {
device.PutOutboundElement(elem) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
if err != nil { if err != nil {
device.log.Errorf("%v - Failed to send data packet: %v", peer, err) var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
}
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue continue
} }

View file

@ -1,10 +1,10 @@
// +build !linux //go:build !linux
package device package device
import ( import (
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/rwcancel" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
) )
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* *
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
@ -20,12 +20,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/rwcancel" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
) )
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if _, ok := bind.(*conn.LinuxSocketBind); !ok { if !conn.StdNetSupportsStickySockets {
return nil, nil
}
if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil return nil, nil
} }
@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok { if !ok {
break break
} }
pePtr.peer.Lock() pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint { if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
break break
} }
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
break break
} }
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
} }
attr = attr[attrhdr.Len:] attr = attr[attrhdr.Len:]
} }
@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock() device.peers.RLock()
i := uint32(1) i := uint32(1)
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.RLock() peer.endpoint.Lock()
if peer.endpoint == nil { if peer.endpoint.val == nil {
peer.RUnlock() peer.endpoint.Unlock()
continue continue
} }
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil { if nativeEP == nil {
peer.RUnlock() peer.endpoint.Unlock()
continue continue
} }
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock() peer.endpoint.Unlock()
break break
} }
nlmsg := struct { nlmsg := struct {
@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8, Len: 8,
Type: unix.RTA_DST, Type: unix.RTA_DST,
}, },
nativeEP.Dst4().Addr, nativeEP.DstIP().As4(),
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: unix.RTA_SRC, Type: unix.RTA_SRC,
}, },
nativeEP.Src4().Src, nativeEP.SrcIP().As4(),
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: unix.RTA_MARK, Type: unix.RTA_MARK,
@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock() reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{ reqPeer[i] = peerEndpointPtr{
peer: peer, peer: peer,
endpoint: &peer.endpoint, endpoint: &peer.endpoint.val,
} }
reqPeerLock.Unlock() reqPeerLock.Unlock()
peer.RUnlock() peer.endpoint.Unlock()
i++ i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil { if err != nil {
@ -204,7 +207,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
} }
func createNetlinkRouteSocket() (int, error) { func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil { if err != nil {
return -1, err return -1, err
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* *
* This is based heavily on timers.c from the kernel implementation. * This is based heavily on timers.c from the kernel implementation.
*/ */
@ -8,12 +8,14 @@
package device package device
import ( import (
"math/rand"
"sync" "sync"
"sync/atomic"
"time" "time"
_ "unsafe"
) )
//go:linkname fastrandn runtime.fastrandn
func fastrandn(n uint32) uint32
// A Timer manages time-based aspects of the WireGuard protocol. // A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list. // Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct { type Timer struct {
@ -71,11 +73,11 @@ func (timer *Timer) IsPending() bool {
} }
func (peer *Peer) timersActive() bool { func (peer *Peer) timersActive() bool {
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
} }
func expiredRetransmitHandshake(peer *Peer) { func expiredRetransmitHandshake(peer *Peer) {
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() { if peer.timersActive() {
@ -94,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
} }
} else { } else {
atomic.AddUint32(&peer.timers.handshakeAttempts, 1) peer.timers.handshakeAttempts.Add(1)
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(true) peer.SendHandshakeInitiation(true)
} }
@ -110,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) { func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive() peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Get() { if peer.timers.needAnotherKeepalive.Load() {
peer.timers.needAnotherKeepalive.Set(false) peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} }
@ -121,13 +119,8 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) { func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
@ -136,7 +129,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
} }
func expiredPersistentKeepalive(peer *Peer) { func expiredPersistentKeepalive(peer *Peer) {
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@ -144,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -154,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() { if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else { } else {
peer.timers.needAnotherKeepalive.Set(true) peer.timers.needAnotherKeepalive.Store(true)
} }
} }
} }
@ -176,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */ /* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() { func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -185,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Del() peer.timers.retransmitHandshake.Del()
} }
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Set(false) peer.timers.sentLastMinuteHandshake.Store(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) peer.lastHandshakeNano.Store(time.Now().UnixNano())
} }
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@ -199,7 +192,7 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) keepalive := peer.persistentKeepaliveInterval.Load()
if keepalive > 0 && peer.timersActive() { if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
} }
@ -214,9 +207,9 @@ func (peer *Peer) timersInit() {
} }
func (peer *Peer) timersStart() { func (peer *Peer) timersStart() {
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Set(false) peer.timers.sentLastMinuteHandshake.Store(false)
peer.timers.needAnotherKeepalive.Set(false) peer.timers.needAnotherKeepalive.Store(false)
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {

View file

@ -1,15 +1,14 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"fmt" "fmt"
"sync/atomic"
"golang.zx2c4.com/wireguard/tun" "github.com/amnezia-vpn/amneziawg-go/tun"
) )
const DefaultMTU = 1420 const DefaultMTU = 1420
@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize mtu = MaxContentSize
} }
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu)) old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu { if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -12,13 +12,13 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
) )
type IPCError struct { type IPCError struct {
@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
return s.code return s.code
} }
func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)} return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
} }
var byteBufferPool = &sync.Pool{ var byteBufferPool = &sync.Pool{
New: func() interface{} { return new(bytes.Buffer) }, New: func() any { return new(bytes.Buffer) },
} }
// IpcGetOperation implements the WireGuard configuration protocol "get" operation. // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
buf := byteBufferPool.Get().(*bytes.Buffer) buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
defer byteBufferPool.Put(buf) defer byteBufferPool.Put(buf)
sendf := func(format string, args ...interface{}) { sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...) fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n') buf.WriteByte('\n')
} }
@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
} }
func() { func() {
// lock required resources // lock required resources
device.net.RLock() device.net.RLock()
@ -98,31 +97,61 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark) sendf("fwmark=%d", device.net.fwmark)
} }
// serialize each peer state if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.RLock() // Serialize peer state.
defer peer.RUnlock() peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1") sendf("protocol_version=1")
if peer.endpoint != nil { peer.endpoint.Lock()
sendf("endpoint=%s", peer.endpoint.DstToString()) if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
} }
peer.endpoint.Unlock()
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds() secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds() nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs) sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano) sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)) sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr) sendf("allowed_ip=%s", prefix.String())
return true return true
}) })
} }
@ -151,19 +180,27 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer) peer := new(ipcSetPeer)
deviceConfig := true deviceConfig := true
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if line == "" { if line == "" {
// Blank line means terminate operation. // Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
return nil return nil
} }
parts := strings.Split(line, "=") key, value, ok := strings.Cut(line, "=")
if len(parts) != 2 { if !ok {
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts)) return ipcErrorf(
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
} }
key := parts[0]
value := parts[1]
if key == "public_key" { if key == "public_key" {
if deviceConfig { if deviceConfig {
@ -180,7 +217,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error var err error
if deviceConfig { if deviceConfig {
err = device.handleDeviceLine(key, value) err = device.handleDeviceLine(key, value, &tempASecCfg)
} else { } else {
err = device.handlePeerLine(peer, key, value) err = device.handlePeerLine(peer, key, value)
} }
@ -188,6 +225,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err return err
} }
} }
err = device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig() peer.handlePostConfig()
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
@ -196,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil return nil
} }
func (device *Device) handleDeviceLine(key, value string) error { func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
@ -242,6 +283,83 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.log.Verbosef("UAPI: Removing all peers") device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
} }
@ -254,15 +372,29 @@ type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on *Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer created bool // new reports whether this is a newly created peer
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
} }
func (peer *ipcSetPeer) handlePostConfig() { func (peer *ipcSetPeer) handlePostConfig() {
if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() { if peer.Peer == nil || peer.dummy {
return
}
if peer.created {
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
if peer.pkaOn {
peer.SendKeepalive()
}
peer.SendStagedPackets() peer.SendStagedPackets()
} }
} }
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
// Load/create the peer we are configuring. // Load/create the peer we are configuring.
var publicKey NoisePublicKey var publicKey NoisePublicKey
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
@ -292,7 +424,10 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
return nil return nil
} }
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
switch key { switch key {
case "update_only": case "update_only":
// allow disabling of creation // allow disabling of creation
@ -334,9 +469,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
} }
peer.Lock() peer.endpoint.Lock()
defer peer.Unlock() defer peer.endpoint.Unlock()
peer.endpoint = endpoint peer.endpoint.val = endpoint
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
@ -346,17 +481,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
} }
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on. // Send immediate keepalive if we're turning it on and before it wasn't on.
if old == 0 && secs != 0 { peer.pkaOn = old == 0 && secs != 0
if err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
}
if device.isUp() && !peer.dummy {
peer.SendKeepalive()
}
}
case "replace_allowed_ips": case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
@ -370,16 +498,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip": case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
_, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
} }
if peer.dummy { if peer.dummy {
return nil return nil
} }
ones, _ := network.Mask.Size() device.allowedips.Insert(prefix, peer.Peer)
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {

51
format_test.go Normal file
View file

@ -0,0 +1,51 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"go/format"
"io/fs"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestFormatting(t *testing.T) {
var wg sync.WaitGroup
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Errorf("unable to walk %s: %v", path, err)
return nil
}
if d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
wg.Add(1)
go func(path string) {
defer wg.Done()
src, err := os.ReadFile(path)
if err != nil {
t.Errorf("unable to read %s: %v", path, err)
return
}
if runtime.GOOS == "windows" {
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
}
formatted, err := format.Source(src)
if err != nil {
t.Errorf("unable to format %s: %v", path, err)
return
}
if !bytes.Equal(src, formatted) {
t.Errorf("unformatted code: %s", path)
}
}(path)
return nil
})
wg.Wait()
}

18
go.mod
View file

@ -1,9 +1,17 @@
module golang.zx2c4.com/wireguard module github.com/amnezia-vpn/amneziawg-go
go 1.16 go 1.24
require ( require (
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 github.com/tevino/abool/v2 v2.1.0
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/crypto v0.36.0
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
require (
github.com/google/btree v1.1.3 // indirect
golang.org/x/time v0.9.0 // indirect
) )

36
go.sum
View file

@ -1,16 +1,20 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 h1:fpeMGRM6A+XFcw4RPCO8s8hH7ppgrGR22pSIjwM7YUI= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=

View file

@ -1,12 +1,11 @@
// +build windows // Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/* SPDX-License-Identifier: MIT //go:build windows
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe package namedpipe
import ( import (
"io" "io"
@ -22,8 +21,10 @@ import (
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var (
var ioCompletionPort windows.Handle ioInitOnce sync.Once
ioCompletionPort windows.Handle
)
// ioResult contains the result of an asynchronous IO operation // ioResult contains the result of an asynchronous IO operation
type ioResult struct { type ioResult struct {
@ -52,7 +53,7 @@ type file struct {
handle windows.Handle handle windows.Handle
wg sync.WaitGroup wg sync.WaitGroup
wgLock sync.RWMutex wgLock sync.RWMutex
closing uint32 // used as atomic boolean closing atomic.Bool
socket bool socket bool
readDeadline deadlineHandler readDeadline deadlineHandler
writeDeadline deadlineHandler writeDeadline deadlineHandler
@ -63,7 +64,7 @@ type deadlineHandler struct {
channel timeoutChan channel timeoutChan
channelLock sync.RWMutex channelLock sync.RWMutex
timer *time.Timer timer *time.Timer
timedout uint32 // used as atomic boolean timedout atomic.Bool
} }
// makeFile makes a new file from an existing file handle // makeFile makes a new file from an existing file handle
@ -87,7 +88,7 @@ func makeFile(h windows.Handle) (*file, error) {
func (f *file) closeHandle() { func (f *file) closeHandle() {
f.wgLock.Lock() f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once. // Atomically set that we are closing, releasing the resources only once.
if atomic.SwapUint32(&f.closing, 1) == 0 { if f.closing.Swap(true) == false {
f.wgLock.Unlock() f.wgLock.Unlock()
// cancel all IO and wait for it to complete // cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil) windows.CancelIoEx(f.handle, nil)
@ -110,7 +111,7 @@ func (f *file) Close() error {
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) { func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock() f.wgLock.RLock()
if atomic.LoadUint32(&f.closing) == 1 { if f.closing.Load() {
f.wgLock.RUnlock() f.wgLock.RUnlock()
return nil, os.ErrClosed return nil, os.ErrClosed
} }
@ -142,7 +143,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
return int(bytes), err return int(bytes), err
} }
if atomic.LoadUint32(&f.closing) == 1 { if f.closing.Load() {
windows.CancelIoEx(f.handle, &c.o) windows.CancelIoEx(f.handle, &c.o)
} }
@ -158,7 +159,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == windows.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
if atomic.LoadUint32(&f.closing) == 1 { if f.closing.Load() {
err = os.ErrClosed err = os.ErrClosed
} }
} else if err != nil && f.socket { } else if err != nil && f.socket {
@ -190,7 +191,7 @@ func (f *file) Read(b []byte) (int, error) {
} }
defer f.wg.Done() defer f.wg.Done()
if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { if f.readDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
@ -217,7 +218,7 @@ func (f *file) Write(b []byte) (int, error) {
} }
defer f.wg.Done() defer f.wg.Done()
if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { if f.writeDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
@ -254,7 +255,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
d.timer = nil d.timer = nil
} }
atomic.StoreUint32(&d.timedout, 0) d.timedout.Store(false)
select { select {
case <-d.channel: case <-d.channel:
@ -269,7 +270,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
timeoutIO := func() { timeoutIO := func() {
atomic.StoreUint32(&d.timedout, 1) d.timedout.Store(true)
close(d.channel) close(d.channel)
} }

View file

@ -1,13 +1,12 @@
// +build windows // Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/* SPDX-License-Identifier: MIT //go:build windows
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. // Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package winpipe package namedpipe
import ( import (
"context" "context"
@ -15,6 +14,7 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"sync/atomic"
"time" "time"
"unsafe" "unsafe"
@ -28,7 +28,7 @@ type pipe struct {
type messageBytePipe struct { type messageBytePipe struct {
pipe pipe
writeClosed bool writeClosed atomic.Bool
readEOF bool readEOF bool
} }
@ -50,25 +50,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
// CloseWrite closes the write side of a message pipe in byte mode. // CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error { func (f *messageBytePipe) CloseWrite() error {
if f.writeClosed { if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
err := f.file.Flush() err := f.file.Flush()
if err != nil { if err != nil {
f.writeClosed.Store(false)
return err return err
} }
_, err = f.file.Write(nil) _, err = f.file.Write(nil)
if err != nil { if err != nil {
f.writeClosed.Store(false)
return err return err
} }
f.writeClosed = true
return nil return nil
} }
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite. // they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) { func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed { if f.writeClosed.Load() {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
if len(b) == 0 { if len(b) == 0 {
@ -142,30 +143,24 @@ type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
} }
// Dial connects to the specified named pipe by path, timing out if the connection // DialTimeout connects to the specified named pipe by path, timing out if the
// takes longer than the specified duration. If timeout is nil, then we use // connection takes longer than the specified duration. If timeout is zero, then
// a default timeout of 2 seconds. // we use a default timeout of 2 seconds.
func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
var absTimeout time.Time if timeout == 0 {
if timeout != nil { timeout = time.Second * 2
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(2 * time.Second)
} }
absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialContext(ctx, path, config) conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
return conn, err return conn, err
} }
// DialContext attempts to connect to the specified named pipe by path // DialContext attempts to connect to the specified named pipe by path.
// cancellation or timeout. func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
if config == nil {
config = &DialConfig{}
}
var err error var err error
var h windows.Handle var h windows.Handle
h, err = tryDialPipe(ctx, &path) h, err = tryDialPipe(ctx, &path)
@ -213,6 +208,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
return &pipe{file: f, path: path}, nil return &pipe{file: f, path: path}, nil
} }
var defaultDialer DialConfig
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(path, timeout)
}
// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
return defaultDialer.DialContext(ctx, path)
}
type acceptResponse struct { type acceptResponse struct {
f *file f *file
err error err error
@ -222,12 +229,12 @@ type pipeListener struct {
firstHandle windows.Handle firstHandle windows.Handle
path string path string
config ListenConfig config ListenConfig
acceptCh chan (chan acceptResponse) acceptCh chan chan acceptResponse
closeCh chan int closeCh chan int
doneCh chan int doneCh chan int
} }
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path) path16, err := windows.UTF16PtrFromString(path)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
@ -247,7 +254,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
oa.ObjectName = &ntPath oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if isFirstPipe {
if sd != nil { if sd != nil {
oa.SecurityDescriptor = sd oa.SecurityDescriptor = sd
} else { } else {
@ -257,7 +264,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
return 0, err return 0, err
} }
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
sd, err := windows.NewSecurityDescriptor() sd, err = windows.NewSecurityDescriptor()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -275,11 +282,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
disposition := uint32(windows.FILE_OPEN) disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first { if isFirstPipe {
disposition = windows.FILE_CREATE disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE access = windows.SYNCHRONIZE
} }
@ -395,10 +402,7 @@ type ListenConfig struct {
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
// The pipe must not already exist. // The pipe must not already exist.
func Listen(path string, c *ListenConfig) (net.Listener, error) { func (c *ListenConfig) Listen(path string) (net.Listener, error) {
if c == nil {
c = &ListenConfig{}
}
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -407,12 +411,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
firstHandle: h, firstHandle: h,
path: path, path: path,
config: *c, config: *c,
acceptCh: make(chan (chan acceptResponse)), acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int), closeCh: make(chan int),
doneCh: make(chan int), doneCh: make(chan int),
} }
// The first connection is swallowed on Windows 7 & 8, so synthesize it. // The first connection is swallowed on Windows 7 & 8, so synthesize it.
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path) path16, err := windows.UTF16PtrFromString(path)
if err == nil { if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
@ -425,6 +429,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
return l, nil return l, nil
} }
var defaultListener ListenConfig
// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
return defaultListener.Listen(path)
}
func connectPipe(p *file) error { func connectPipe(p *file) error {
c, err := p.prepareIo() c, err := p.prepareIo()
if err != nil { if err != nil {

View file

@ -1,12 +1,11 @@
// +build windows // Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/* SPDX-License-Identifier: MIT //go:build windows
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe_test package namedpipe_test
import ( import (
"bufio" "bufio"
@ -21,8 +20,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe"
) )
func randomPipePath() string { func randomPipePath() string {
@ -30,7 +29,7 @@ func randomPipePath() string {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return `\\.\PIPE\go-winpipe-test-` + guid.String() return `\\.\PIPE\go-namedpipe-test-` + guid.String()
} }
func TestPingPong(t *testing.T) { func TestPingPong(t *testing.T) {
@ -39,7 +38,7 @@ func TestPingPong(t *testing.T) {
pong = 24 pong = 24
) )
pipePath := randomPipePath() pipePath := randomPipePath()
listener, err := winpipe.Listen(pipePath, nil) listener, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatalf("unable to listen on pipe: %v", err) t.Fatalf("unable to listen on pipe: %v", err)
} }
@ -64,11 +63,12 @@ func TestPingPong(t *testing.T) {
t.Fatalf("unable to write pong to pipe: %v", err) t.Fatalf("unable to write pong to pipe: %v", err)
} }
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatalf("unable to dial pipe: %v", err) t.Fatalf("unable to dial pipe: %v", err)
} }
defer client.Close() defer client.Close()
client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte var data [1]byte
data[0] = ping data[0] = ping
_, err = client.Write(data[:]) _, err = client.Write(data[:])
@ -85,7 +85,7 @@ func TestPingPong(t *testing.T) {
} }
func TestDialUnknownFailsImmediately(t *testing.T) { func TestDialUnknownFailsImmediately(t *testing.T) {
_, err := winpipe.Dial(randomPipePath(), nil, nil) _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) { if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err) t.Fatalf("expected ENOENT got %v", err)
} }
@ -93,13 +93,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
func TestDialListenerTimesOut(t *testing.T) { func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
d := 10 * time.Millisecond pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
_, err = winpipe.Dial(pipePath, &d, nil) if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded { if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
} }
@ -107,14 +109,17 @@ func TestDialListenerTimesOut(t *testing.T) {
func TestDialContextListenerTimesOut(t *testing.T) { func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
d := 10 * time.Millisecond d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d) ctx, _ := context.WithTimeout(context.Background(), d)
_, err = winpipe.DialContext(ctx, pipePath, nil) pipe, err := namedpipe.DialContext(ctx, pipePath)
if err == nil {
pipe.Close()
}
if err != context.DeadlineExceeded { if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err) t.Fatalf("expected context.DeadlineExceeded, got %v", err)
} }
@ -123,14 +128,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
func TestDialListenerGetsCancelled(t *testing.T) { func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ch := make(chan error)
defer l.Close() defer l.Close()
ch := make(chan error)
go func(ctx context.Context, ch chan error) { go func(ctx context.Context, ch chan error) {
_, err := winpipe.DialContext(ctx, pipePath, nil) _, err := namedpipe.DialContext(ctx, pipePath)
ch <- err ch <- err
}(ctx, ch) }(ctx, ch)
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
@ -147,23 +152,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
} }
pipePath := randomPipePath() pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:") sd, _ := windows.SecurityDescriptorFromString("D:")
c := winpipe.ListenConfig{ l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd, SecurityDescriptor: sd,
} }).Listen(pipePath)
l, err := winpipe.Listen(pipePath, &c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
_, err = winpipe.Dial(pipePath, nil, nil) pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
pipe.Close()
}
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
} }
} }
func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, cfg) if cfg == nil {
cfg = &namedpipe.ListenConfig{}
}
l, err := cfg.Listen(pipePath)
if err != nil { if err != nil {
return return
} }
@ -179,7 +189,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
ch <- response{c, err} ch <- response{c, err}
}() }()
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
return return
} }
@ -236,7 +246,7 @@ func server(l net.Listener, ch chan int) {
func TestFullListenDialReadWrite(t *testing.T) { func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -245,7 +255,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
ch := make(chan int) ch := make(chan int)
go server(l, ch) go server(l, ch)
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -275,7 +285,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
func TestCloseAbortsListen(t *testing.T) { func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -328,7 +338,7 @@ func TestCloseServerEOFClient(t *testing.T) {
} }
func TestCloseWriteEOF(t *testing.T) { func TestCloseWriteEOF(t *testing.T) {
cfg := &winpipe.ListenConfig{ cfg := &namedpipe.ListenConfig{
MessageMode: true, MessageMode: true,
} }
c, s, err := getConnection(cfg) c, s, err := getConnection(cfg)
@ -356,7 +366,7 @@ func TestCloseWriteEOF(t *testing.T) {
func TestAcceptAfterCloseFails(t *testing.T) { func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -369,12 +379,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
func TestDialTimesOutByDefault(t *testing.T) { func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
_, err = winpipe.Dial(pipePath, nil, nil) pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded { if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
} }
@ -382,7 +395,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
func TestTimeoutPendingRead(t *testing.T) { func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -400,7 +413,7 @@ func TestTimeoutPendingRead(t *testing.T) {
close(serverDone) close(serverDone)
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -430,7 +443,7 @@ func TestTimeoutPendingRead(t *testing.T) {
func TestTimeoutPendingWrite(t *testing.T) { func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -448,7 +461,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
close(serverDone) close(serverDone)
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -480,13 +493,12 @@ type CloseWriter interface {
} }
func TestEchoWithMessaging(t *testing.T) { func TestEchoWithMessaging(t *testing.T) {
c := winpipe.ListenConfig{ pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536, OutputBufferSize: 65536,
} }).Listen(pipePath)
pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, &c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) {
clientDone := make(chan bool) clientDone := make(chan bool)
go func() { go func() {
// server echo // server echo
conn, e := l.Accept() conn, err := l.Accept()
if e != nil { if err != nil {
t.Fatal(e) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
io.Copy(conn, conn) _, err = io.Copy(conn, conn)
if err != nil {
t.Fatal(err)
}
conn.(CloseWriter).CloseWrite() conn.(CloseWriter).CloseWrite()
close(listenerDone) close(listenerDone)
}() }()
timeout := 1 * time.Second client, err := namedpipe.DialTimeout(pipePath, time.Second)
client, err := winpipe.Dial(pipePath, &timeout, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) {
if e != nil { if e != nil {
t.Fatal(e) t.Fatal(e)
} }
if n != 2 { if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n) t.Fatalf("expected 2 bytes, got %v", n)
} }
close(clientDone) close(clientDone)
@ -545,7 +559,7 @@ func TestEchoWithMessaging(t *testing.T) {
func TestConnectRace(t *testing.T) { func TestConnectRace(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) {
}() }()
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -580,7 +594,7 @@ func TestMessageReadMode(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) {
s.Close() s.Close()
}() }()
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -643,13 +657,13 @@ func TestListenConnectRace(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil { if err == nil {
c.Close() c.Close()
} }
wg.Done() wg.Done()
}() }()
s, err := winpipe.Listen(pipePath, nil) s, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Error(i, err) t.Error(i, err)
} else { } else {

View file

@ -1,8 +1,8 @@
// +build darwin freebsd openbsd //go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@ -54,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -104,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err l.connErr <- err
return return
} }
if kerr != nil || n != 1 { if (kerr != nil || n != 1) && kerr != unix.EINTR {
if kerr != nil { if kerr != nil {
l.connErr <- kerr l.connErr <- kerr
} else { } else {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@ -9,8 +9,8 @@ import (
"net" "net"
"os" "os"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
type UAPIListener struct { type UAPIListener struct {
@ -51,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -97,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
} }
go func(l *UAPIListener) { go func(l *UAPIListener) {
var buff [0]byte var buf [0]byte
for { for {
defer uapi.inotifyRWCancel.Close() defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition // start with lstat to avoid race condition
@ -105,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err l.connErr <- err
return return
} }
_, err := uapi.inotifyRWCancel.Read(buff[:]) _, err := uapi.inotifyRWCancel.Read(buf[:])
if err != nil { if err != nil {
l.connErr <- err l.connErr <- err
return return

View file

@ -1,8 +1,8 @@
// +build linux darwin freebsd openbsd //go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@ -26,14 +26,14 @@ const (
// socketDirectory is variable because it is modified by a linker // socketDirectory is variable because it is modified by a linker
// flag in wireguard-android. // flag in wireguard-android.
var socketDirectory = "/var/run/wireguard" var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string { func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
} }
func UAPIOpen(name string) (*os.File, error) { func UAPIOpen(name string) (*os.File, error) {
if err := os.MkdirAll(socketDirectory, 0755); err != nil { if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
return nil, err return nil, err
} }
@ -43,7 +43,7 @@ func UAPIOpen(name string) (*os.File, error) {
return nil, err return nil, err
} }
oldUmask := unix.Umask(0077) oldUmask := unix.Umask(0o077)
defer unix.Umask(oldUmask) defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr) listener, err := net.ListenUnix("unix", addr)

15
ipc/uapi_wasm.go Normal file
View file

@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
// Made up sentinel error codes for {js,wasip1}/wasm.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2
IpcErrorPortInUse = 3
IpcErrorUnknown = 4
IpcErrorProtocol = 5
)

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@ -8,9 +8,8 @@ package ipc
import ( import (
"net" "net"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe"
) )
// TODO: replace these with actual standard windows error numbers from the win package // TODO: replace these with actual standard windows error numbers from the win package
@ -54,18 +53,16 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() { func init() {
var err error var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */ UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winpipe.ListenConfig{ listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor, SecurityDescriptor: UAPISecurityDescriptor,
} }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

49
main.go
View file

@ -1,8 +1,8 @@
// +build !windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@ -13,12 +13,12 @@ import (
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/device" "github.com/amnezia-vpn/amneziawg-go/device"
"golang.zx2c4.com/wireguard/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
"golang.zx2c4.com/wireguard/tun" "github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/sys/unix"
) )
const ( const (
@ -46,20 +46,20 @@ func warning() {
return return
} }
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐") fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │") fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │") fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │") fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │") fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘") fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
} }
func main() { func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" { if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH) fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
return return
} }
@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd) // open TUN device (or use supplied fd)
tun, err := func() (tun.Device, error) { tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU) return tun.CreateTUN(interfaceName, device.DefaultMTU)
@ -124,7 +124,7 @@ func main() {
return nil, err return nil, err
} }
err = syscall.SetNonblock(int(fd), true) err = unix.SetNonblock(int(fd), true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +134,7 @@ func main() {
}() }()
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tdev.Name()
if err2 == nil { if err2 == nil {
interfaceName = realInterfaceName interfaceName = realInterfaceName
} }
@ -145,7 +145,7 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Verbosef("Starting wireguard-go version %s", Version) logger.Verbosef("Starting amneziawg-go version %s", Version)
if err != nil { if err != nil {
logger.Errorf("Failed to create TUN device: %v", err) logger.Errorf("Failed to create TUN device: %v", err)
@ -169,7 +169,6 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil return os.NewFile(uintptr(fd), ""), nil
}() }()
if err != nil { if err != nil {
logger.Errorf("UAPI listen error: %v", err) logger.Errorf("UAPI listen error: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
@ -197,7 +196,7 @@ func main() {
files[0], // stdin files[0], // stdin
files[1], // stdout files[1], // stdout
files[2], // stderr files[2], // stderr
tun.File(), tdev.File(),
fileUAPI, fileUAPI,
}, },
Dir: ".", Dir: ".",
@ -223,7 +222,7 @@ func main() {
return return
} }
device := device.NewDevice(tun, conn.NewDefaultBind(), logger) device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started") logger.Verbosef("Device started")
@ -251,7 +250,7 @@ func main() {
// wait for program to terminate // wait for program to terminate
signal.Notify(term, syscall.SIGTERM) signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
select { select {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@ -9,13 +9,14 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"syscall"
"golang.zx2c4.com/wireguard/conn" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
) )
const ( const (
@ -29,13 +30,13 @@ func main() {
} }
interfaceName := os.Args[1] interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.") fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
logger := device.NewLogger( logger := device.NewLogger(
device.LogLevelVerbose, device.LogLevelVerbose,
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Verbosef("Starting wireguard-go version %s", Version) logger.Verbosef("Starting amneziawg-go version %s", Version)
tun, err := tun.CreateTUN(interfaceName, 0) tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil { if err == nil {
@ -81,7 +82,7 @@ func main() {
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill) signal.Notify(term, os.Kill)
signal.Notify(term, syscall.SIGTERM) signal.Notify(term, windows.SIGTERM)
select { select {
case <-term: case <-term:

View file

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"sync" "sync"
"time" "time"
) )
@ -30,8 +30,7 @@ type Ratelimiter struct {
timeNow func() time.Time timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry table map[netip.Addr]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() {
} }
rate.stopReset = make(chan struct{}) rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.table = make(map[netip.Addr]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again. stopReset := rate.stopReset // store in case Init is called again.
@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock() rate.mu.Lock()
defer rate.mu.Unlock() defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 { for key, entry := range rate.table {
entry.mu.Lock() entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.table, key)
} }
entry.mu.Unlock() entry.mu.Unlock()
} }
for key, entry := range rate.tableIPv6 { return len(rate.table) == 0
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mu.Unlock()
}
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
} }
func (rate *Ratelimiter) Allow(ip net.IP) bool { func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
// lookup entry // lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.mu.RLock() rate.mu.RLock()
entry = rate.table[ip]
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
}
rate.mu.RUnlock() rate.mu.RUnlock()
// make new entry if not found // make new entry if not found
if entry == nil { if entry == nil {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow() entry.lastTime = rate.timeNow()
rate.mu.Lock() rate.mu.Lock()
if IPv4 != nil { rate.table[ip] = entry
rate.tableIPv4[keyIPv4] = entry if len(rate.table) == 1 {
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { rate.stopReset <- struct{}{}
rate.stopReset <- struct{}{}
}
} else {
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
} }
rate.mu.Unlock() rate.mu.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.mu.Lock() entry.mu.Lock()
now := rate.timeNow() now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
} }
// subtract cost of packet // subtract cost of packet
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.mu.Unlock() entry.mu.Unlock()

View file

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
) )
@ -71,21 +71,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst", text: "packet following 2 packet burst",
}) })
ips := []net.IP{ ips := []netip.Addr{
net.ParseIP("127.0.0.1"), netip.MustParseAddr("127.0.0.1"),
net.ParseIP("192.168.1.1"), netip.MustParseAddr("192.168.1.1"),
net.ParseIP("172.167.2.3"), netip.MustParseAddr("172.167.2.3"),
net.ParseIP("97.231.252.215"), netip.MustParseAddr("97.231.252.215"),
net.ParseIP("248.97.91.167"), netip.MustParseAddr("248.97.91.167"),
net.ParseIP("188.208.233.47"), netip.MustParseAddr("188.208.233.47"),
net.ParseIP("104.2.183.179"), netip.MustParseAddr("104.2.183.179"),
net.ParseIP("72.129.46.120"), netip.MustParseAddr("72.129.46.120"),
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
} }
now := time.Now() now := time.Now()

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
@ -34,7 +34,7 @@ func (f *Filter) Reset() {
// ValidateCounter checks if the counter should be accepted. // ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected. // Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool { func (f *Filter) ValidateCounter(counter, limit uint64) bool {
if counter >= limit { if counter >= limit {
return false return false
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package replay package replay

View file

@ -1,24 +0,0 @@
// +build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
type fdSet struct {
unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.Bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
}

View file

@ -1,8 +1,8 @@
// +build !windows //go:build !windows && !wasm
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package rwcancel implements cancelable read/write operations on // Package rwcancel implements cancelable read/write operations on
@ -17,13 +17,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func max(a, b int) int {
if a > b {
return a
}
return b
}
type RWCancel struct { type RWCancel struct {
fd int fd int
closingReader *os.File closingReader *os.File
@ -50,13 +43,12 @@ func RetryAfterError(err error) bool {
} }
func (rw *RWCancel) ReadyRead() bool { func (rw *RWCancel) ReadyRead() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{}
fdset.set(rw.fd) pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@ -64,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) { if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) ReadyWrite() bool { func (rw *RWCancel) ReadyWrite() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{} pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
fdset.set(rw.fd)
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@ -85,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) {
if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) Read(p []byte) (n int, err error) { func (rw *RWCancel) Read(p []byte) (n int, err error) {
@ -98,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyRead() { if !rw.ReadyRead() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }
@ -110,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyWrite() { if !rw.ReadyWrite() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }

View file

@ -1,8 +1,9 @@
//go:build windows || wasm
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
package rwcancel package rwcancel
type RWCancel struct { type RWCancel struct{}
}
func (*RWCancel) Cancel() {} func (*RWCancel) Cancel() {}

View file

@ -1,15 +0,0 @@
// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
_, err := unix.Select(nfd, r, w, e, timeout)
return err
}

View file

@ -1,13 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) {
_, err = unix.Select(nfd, r, w, e, timeout)
return
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n
@ -11,9 +11,11 @@ import (
"time" "time"
) )
const TimestampSize = 12 const (
const base = uint64(0x400000000000000a) TimestampSize = 12
const whitenerMask = uint32(0x1000000 - 1) base = uint64(0x400000000000000a)
whitenerMask = uint32(0x1000000 - 1)
)
type Timestamp [TimestampSize]byte type Timestamp [TimestampSize]byte

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n

View file

@ -1,9 +1,9 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package tun
import ( import (
"reflect" "reflect"
@ -18,15 +18,15 @@ func checkAlignment(t *testing.T, name string, offset uintptr) {
} }
} }
// TestPeerAlignment checks that atomically-accessed fields are // TestRateJugglerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package. // aligned to 64-bit boundaries, as required by the atomic package.
// //
// Unfortunately, violating this rule on 32-bit platforms results in a // Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime. // hard segfault at runtime.
func TestPeerAlignment(t *testing.T) { func TestRateJugglerAlignment(t *testing.T) {
var p Peer var r rateJuggler
typ := reflect.TypeOf(&p).Elem() typ := reflect.TypeOf(&r).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size()) t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
@ -38,20 +38,21 @@ func TestPeerAlignment(t *testing.T) {
) )
} }
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
} }
// TestDeviceAlignment checks that atomically-accessed fields are // TestNativeTunAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package. // aligned to 64-bit boundaries, as required by the atomic package.
// //
// Unfortunately, violating this rule on 32-bit platforms results in a // Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime. // hard segfault at runtime.
func TestDeviceAlignment(t *testing.T) { func TestNativeTunAlignment(t *testing.T) {
var d Device var tun NativeTun
typ := reflect.TypeOf(&d).Elem() typ := reflect.TypeOf(&tun).Elem()
t.Logf("Device type size: %d, with fields:", typ.Size()) t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
@ -61,5 +62,6 @@ func TestDeviceAlignment(t *testing.T) {
field.Type.Align(), field.Type.Align(),
) )
} }
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
} }

118
tun/checksum.go Normal file
View file

@ -0,0 +1,118 @@
package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
// checksumNoFold accumulates the Internet checksum (RFC 1071) over b,
// starting from the running sum initial, without folding the result
// down to 16 bits. Words are summed big-endian; a trailing odd byte is
// treated as the high byte of a final 16-bit word.
func checksumNoFold(b []byte, initial uint64) uint64 {
	sum := initial
	// Hot loop: fold four 32-bit words (16 bytes) per iteration. The
	// words are consumed in the same order as a fully unrolled version,
	// so the accumulated value is bit-for-bit identical.
	for len(b) >= 16 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4]))
		sum += uint64(binary.BigEndian.Uint32(b[4:8]))
		sum += uint64(binary.BigEndian.Uint32(b[8:12]))
		sum += uint64(binary.BigEndian.Uint32(b[12:16]))
		b = b[16:]
	}
	// Tail: at most 15 bytes remain; peel off 8, 4, 2, then 1.
	if len(b) >= 8 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4]))
		sum += uint64(binary.BigEndian.Uint32(b[4:8]))
		b = b[8:]
	}
	if len(b) >= 4 {
		sum += uint64(binary.BigEndian.Uint32(b))
		b = b[4:]
	}
	if len(b) >= 2 {
		sum += uint64(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		// Odd trailing byte pads the low byte with zero (RFC 1071).
		sum += uint64(b[0]) << 8
	}
	return sum
}

// checksum returns the 16-bit Internet checksum of b combined with the
// running sum initial. The 64-bit accumulator is folded with
// end-around carry; four folds are sufficient to reduce any 64-bit
// value to 16 bits.
func checksum(b []byte, initial uint64) uint16 {
	sum := checksumNoFold(b, initial)
	for i := 0; i < 4; i++ {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

// pseudoHeaderChecksumNoFold returns the unfolded checksum of the
// IPv4/IPv6 pseudo-header formed by the source address, destination
// address, a zero-padded protocol byte, and the 16-bit total length.
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
	sum := checksumNoFold(srcAddr, 0)
	sum = checksumNoFold(dstAddr, sum)
	sum = checksumNoFold([]byte{0, protocol}, sum)
	// Big-endian encoding of totalLen, equivalent to binary.BigEndian.PutUint16.
	return checksumNoFold([]byte{byte(totalLen >> 8), byte(totalLen)}, sum)
}

35
tun/checksum_test.go Normal file
View file

@ -0,0 +1,35 @@
package tun
import (
"fmt"
"math/rand"
"testing"
)
// BenchmarkChecksum measures checksum throughput across a spread of
// buffer sizes, from small frames through jumbo-frame sizes (including
// the odd-length 9001 case, which exercises the trailing-byte path).
func BenchmarkChecksum(b *testing.B) {
	sizes := []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 9000, 9001}
	for _, size := range sizes {
		b.Run(fmt.Sprintf("%d", size), func(b *testing.B) {
			// Deterministic pseudo-random payload so runs are comparable.
			rng := rand.New(rand.NewSource(1))
			data := make([]byte, size)
			rng.Read(data)
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				checksum(data, 0)
			}
		})
	}
}

12
tun/errors.go Normal file
View file

@ -0,0 +1,12 @@
package tun
import (
"errors"
)
// Sentinel errors reported by the TUN device implementation.
var (
	// ErrTooManySegments is returned by Device.Read() when segmentation
	// overflows the length of supplied buffers. This error should not cause
	// reads to cease.
	ErrTooManySegments = errors.New("too many segments")
)

View file

@ -1,8 +1,8 @@
// +build ignore //go:build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@ -10,27 +10,27 @@ package main
import ( import (
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/netip"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/device" "github.com/amnezia-vpn/amneziawg-go/device"
"golang.zx2c4.com/wireguard/tun/netstack" "github.com/amnezia-vpn/amneziawg-go/tun/netstack"
) )
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.28")},
[]net.IP{net.ParseIP("8.8.8.8")}, []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420) 1420)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0 allowed_ip=0.0.0.0/0
endpoint=127.0.0.1:58120
`) `)
err = dev.Up() err = dev.Up()
if err != nil { if err != nil {
@ -42,7 +42,7 @@ allowed_ip=0.0.0.0/0
DialContext: tnet.DialContext, DialContext: tnet.DialContext,
}, },
} }
resp, err := client.Get("https://www.zx2c4.com/ip") resp, err := client.Get("http://192.168.4.29/")
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }

View file

@ -1,8 +1,8 @@
// +build ignore //go:build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@ -12,26 +12,27 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/netip"
"golang.zx2c4.com/wireguard/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/device" "github.com/amnezia-vpn/amneziawg-go/device"
"golang.zx2c4.com/wireguard/tun/netstack" "github.com/amnezia-vpn/amneziawg-go/tun/netstack"
) )
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420, 1420,
) )
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b listen_port=58120
endpoint=163.172.161.0:12912 public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
allowed_ip=0.0.0.0/0 allowed_ip=192.168.4.28/32
persistent_keepalive_interval=25 persistent_keepalive_interval=25
`) `)
dev.Up() dev.Up()

Some files were not shown because too many files have changed in this diff Show more