Compare commits

..

737 commits

Author SHA1 Message Date
pokamest
27e661d68e
Merge pull request #70 from marko1777/junk-improvements
Junk improvements
2025-04-07 15:31:41 +01:00
Mark Puha
71be0eb3a6 faster and more secure junk creation 2025-03-18 08:34:23 +01:00
pokamest
e3f1273f8a
Merge pull request #64 from drkivi/master
Patch for golang crypto and net submodules
2025-02-18 11:50:35 +00:00
drkivi
c97b5b7615
Update go.sum
Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:58 +03:30
drkivi
668ddfd455
Update go.mod
Submodules Version Up

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:17 +03:30
drkivi
b8da08c106
Update Dockerfile
golang -> 1.23.6
AWGTOOLS_RELEASE -> 1.0.20241018

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:43:02 +03:30
Iurii Egorov
2e3f7d122c Update Go version in Dockerfile 2024-07-01 13:47:44 +03:00
Iurii Egorov
2e7780471a
Remove GetOffloadInfo() (#32)
* Remove GetOffloadInfo()
* Remove GetOffloadInfo() from bind_windows as well
* Allow lightweight tags to be used in the version
2024-05-24 16:18:23 +01:00
albexk
87d8c00f86 Up go to 1.22.3, up crypto to 0.21.0 2024-05-21 08:09:58 -07:00
albexk
c00bda9200 Fix output of the version command 2024-05-14 03:51:01 -07:00
albexk
d2b0fc9789 Add resetting of message types when closing the device 2024-05-14 03:51:01 -07:00
albexk
77d39ff3b9 Minor naming changes 2024-05-14 03:51:01 -07:00
albexk
e433d13df6 Add disabling UDP GSO when an error occurs due to inconsistent peer mtu 2024-05-14 03:51:01 -07:00
RomikB
3ddf952973 unsafe rebranding: change pipe name 2024-05-13 11:10:42 -07:00
albexk
3f0a3bcfa0 Fix wg reconnection problem after awg connection 2024-03-16 14:16:13 +00:00
AlexanderGalkov
4dddf62e57 Update Dockerfile
add wg and wg-quick symlinks

Signed-off-by: AlexanderGalkov <143902290+AlexanderGalkov@users.noreply.github.com>
2024-02-20 20:32:38 +07:00
tiaga
827ec6e14b
Merge pull request #21 from amnezia-vpn/fix-dockerfile
Fix Dockerfile
2024-02-13 21:47:55 +07:00
tiaga
92e28a0d14 Fix Dockerfile
Fix AmneziaWG tools installation.
2024-02-13 21:44:41 +07:00
tiaga
52fed4d362
Merge pull request #20 from amnezia-vpn/update_dockerfile
Update Dockerfile
2024-02-13 21:28:17 +07:00
tiaga
9c6b3ff332 Update Dockerfile
- rename `wg` and `wg-quick` to `awg` and `awg-quick` accordingly
- add iptables
- update AmneziaWG tools version
2024-02-13 21:27:34 +07:00
pokamest
7de7a9a754
Merge pull request #19 from amnezia-vpn/fix/go_sum
Fix go.sum
2024-02-12 05:31:57 -08:00
albexk
0c347529b8 Fix go.sum 2024-02-12 16:27:56 +03:00
albexk
6705978fc8 Add debug udp offload info 2024-02-12 04:40:23 -08:00
albexk
032e33f577 Fix Android UDP GRO check 2024-02-12 04:40:23 -08:00
albexk
59101fd202 Bump crypto, net, sys modules to the latest versions 2024-02-12 04:40:23 -08:00
tiaga
8bcfbac230
Merge pull request #17 from amnezia-vpn/fix-pipeline
Fix pipeline
2024-02-07 19:19:11 +07:00
tiaga
f0dfb5eacc Fix pipeline
Fix path to GitHub Actions workflow.
2024-02-07 18:53:55 +07:00
tiaga
9195025d8f
Merge pull request #16 from amnezia-vpn/pipeline
Add pipeline
2024-02-07 18:48:12 +07:00
tiaga
cbd414dfec Add pipeline
Build and push Docker image on a tag push.
2024-02-07 18:44:59 +07:00
tiaga
7155d20913
Merge pull request #14 from amnezia-vpn/docker
Update Dockerfile
2024-02-02 23:40:39 +07:00
tiaga
bfeb3954f6 Update Dockerfile
- update Alpine version
- improve `Dockerfile` to use pre-built AmneziaWG tools
2024-02-02 22:56:00 +07:00
Iurii Egorov
e3c9ec8012 Naming unify 2024-01-20 23:18:10 +03:00
pokamest
ce9d3866a3
Merge pull request #10 from amnezia-vpn/upstream-merge
Merge upstream changes
2024-01-14 13:37:00 -05:00
Iurii Egorov
e5f355e843 Fix incorrect configuration handling for zero-valued Jc 2024-01-14 18:22:02 +03:00
Iurii Egorov
c05b2ee2a3 Merge remote-tracking branch 'upstream/master' into upstream-merge 2024-01-09 21:34:30 +03:00
tiaga
180c9284f3
Merge pull request #7 from amnezia-vpn/docker
Add Dockerfile
2023-12-19 18:50:48 +07:00
tiaga
015e11875d Add Dockerfile
Build Docker image with the corresponding wg-tools version.
2023-12-19 18:46:04 +07:00
Martin Basovnik
12269c2761 device: fix possible deadlock in close method
There is a possible deadlock in `device.Close()` when you try to close
the device very soon after its start. The problem is that two different
methods acquire the same locks in different order:

1. device.Close()
 - device.ipcMutex.Lock()
 - device.state.Lock()

2. device.changeState(deviceState)
 - device.state.Lock()
 - device.ipcMutex.Lock()

Reproducer:

    func TestDevice_deadlock(t *testing.T) {
    	d := randDevice(t)
    	d.Close()
    }

Problem:

    $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex
    sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6
    golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c
    testing.tRunner(0xc00014c000, 0x131d7b0)
    --
    sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    sync.(*RWMutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45
    golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72
    golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1)

Signed-off-by: Martin Basovnik <martin.basovnik@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:38:47 +01:00
Jason A. Donenfeld
542e565baa device: do atomic 64-bit add outside of vector loop
Only bother updating the rxBytes counter once we've processed a whole
vector, since additions are atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:35:57 +01:00
Jordan Whited
7c20311b3d device: reduce redundant per-packet overhead in RX path
Peer.RoutineSequentialReceiver() deals with packet vectors and does not
need to perform timer and endpoint operations for every packet in a
given vector. Changing these per-packet operations to per-vector
improves throughput by as much as 10% in some environments.

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
4ffa9c2032 device: change Peer.endpoint locking to reduce contention
Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
d0bc03c707 tun: implement UDP GSO/GRO for Linux
Implement UDP GSO and GRO for the Linux tun.Device, which is made
possible by virtio extensions in the kernel's TUN driver starting in
v6.2.

secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is
used to demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly ~13us of round trip
latency between them. secnetperf was invoked with the following command
line options:
-stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0

The first result is from commit 2e0774f without UDP GSO/GRO on the TUN.

[conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us
SendTotalPackets=55859 SendSuspectedLostPackets=61
SendSpuriousLostPackets=59 SendCongestionCount=27
SendEcnCongestionCount=0 RecvTotalPackets=2779122
RecvReorderedPackets=0 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms).

The second result is with UDP GSO/GRO on the TUN.

[conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us
SendTotalPackets=165033 SendSuspectedLostPackets=64
SendSpuriousLostPackets=61 SendCongestionCount=53
SendEcnCongestionCount=0 RecvTotalPackets=11845268
RecvReorderedPackets=25267 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:27:22 +01:00
Jordan Whited
1cf89f5339 tun: fix Device.Read() buf length assumption on Windows
The length of a packet read from the underlying TUN device may exceed
the length of a supplied buffer when MTU exceeds device.MaxMessageSize.

Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:20:49 +01:00
pokamest
b43118018e
Merge pull request #4 from amnezia-vpn/upstream-merge
Upstream merge
2023-11-30 10:49:31 -08:00
Iurii Egorov
7af55a3e6f Merge remote-tracking branch 'upstream/master' 2023-11-17 22:42:13 +03:00
pokamest
c493b95f66
Update README.md
Signed-off-by: pokamest <pokamest@gmail.com>
2023-10-25 22:41:33 +01:00
Jason A. Donenfeld
2e0774f246 device: ratchet up max segment size on android
GRO requires big allocations to be efficient. This isn't great, as there
might be Android memory usage issues. So we should revisit this commit.
But at least it gets things working again.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-22 02:12:13 +02:00
Jason A. Donenfeld
b3df23dcd4 conn: set unused OOB to zero length
Otherwise in the event that we're using GSO without sticky sockets, we
pass garbage OOB buffers to sendmmsg, making a EINVAL, when GSO doesn't
set its header.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:32:07 +02:00
Jason A. Donenfeld
f502ec3fad conn: fix cmsg data padding calculation for gso
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 19:06:38 +02:00
Jason A. Donenfeld
5d37bd24e1 conn: separate gso and sticky control
Android wants GSO but not sticky.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-21 18:44:01 +02:00
Jason A. Donenfeld
24ea13351e conn: harmonize GOOS checks between "linux" and "android"
Otherwise GRO gets enabled on Android, but the conn doesn't use it,
resulting in bundled packets being discarded.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:14:13 +02:00
Jason A. Donenfeld
177caa7e44 conn: simplify supportsUDPOffload
This allows a kernel to support UDP_GRO while not supporting
UDP_SEGMENT.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-18 21:02:52 +02:00
Mazay B
b81ca925db peer.device.aSecMux.RLock added 2023-10-14 11:42:30 +01:00
James Tucker
42ec952ead go.mod,tun/netstack: bump gvisor
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
James Tucker
ec8f6f82c2 tun: fix crash when ForceMTU is called after close
Close closes the events channel, resulting in a panic from send on
closed channel.

Reported-By: Brad Fitzpatrick <brad@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:37:17 +02:00
Jordan Whited
1ec454f253 device: move Queue{In,Out}boundElement Mutex to container type
Queue{In,Out}boundElement locking can contribute to significant
overhead via sync.Mutex.lockSlow() in some environments. These types
are passed throughout the device package as elements in a slice, so
move the per-element Mutex to a container around the slice.

Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
8a015f7c76 tun: reduce redundant checksumming in tcpGRO()
IPv4 header and pseudo header checksums were being computed on every
merge operation. Additionally, virtioNetHdr was being written at the
same time. This delays those operations until after all coalescing has
occurred.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
895d6c23cd tun: unwind summing loop in checksumNoFold()
$ benchstat old.txt new.txt
goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/tun
cpu: 12th Gen Intel(R) Core(TM) i5-12400
                 │   old.txt    │               new.txt               │
                 │    sec/op    │   sec/op     vs base                │
Checksum/64-12     10.670n ± 2%   4.769n ± 0%  -55.30% (p=0.000 n=10)
Checksum/128-12    19.665n ± 2%   8.032n ± 0%  -59.16% (p=0.000 n=10)
Checksum/256-12     37.68n ± 1%   16.06n ± 0%  -57.37% (p=0.000 n=10)
Checksum/512-12     76.61n ± 3%   32.13n ± 0%  -58.06% (p=0.000 n=10)
Checksum/1024-12   160.55n ± 4%   64.25n ± 0%  -59.98% (p=0.000 n=10)
Checksum/1500-12   231.05n ± 7%   94.12n ± 0%  -59.26% (p=0.000 n=10)
Checksum/2048-12    309.5n ± 3%   128.5n ± 0%  -58.48% (p=0.000 n=10)
Checksum/4096-12    603.8n ± 4%   257.2n ± 0%  -57.41% (p=0.000 n=10)
Checksum/8192-12   1185.0n ± 3%   515.5n ± 0%  -56.50% (p=0.000 n=10)
Checksum/9000-12   1328.5n ± 5%   564.8n ± 0%  -57.49% (p=0.000 n=10)
Checksum/9001-12   1340.5n ± 3%   564.8n ± 0%  -57.87% (p=0.000 n=10)
geomean             185.3n        77.99n       -57.92%

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
4201e08f1d device: distribute crypto work as slice of elements
After reducing UDP stack traversal overhead via GSO and GRO,
runtime.chanrecv() began to account for a high percentage (20% in one
environment) of perf samples during a throughput benchmark. The
individual packet channel ops with the crypto goroutines was the primary
contributor to this overhead.

Updating these channels to pass vectors, which the device package
already handles at its ends, reduced this overhead substantially, and
improved throughput.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is with UDP GSO and GRO, and with single element
channels.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

The second result is with channels updated to pass a slice of
elements.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   sender
[  5]   0.00-10.04  sec  13.2 GBytes  11.3 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
Jordan Whited
6a84778f2c conn, device: use UDP GSO and GRO on Linux
StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is
dependent on checksum offload support on the egress netdev. UDP GSO
will be disabled in the event sendmmsg() returns EIO, which is a strong
signal that the egress netdev does not support checksum offload.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is from commit 052af4a without UDP GSO or GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139   3.01 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139  sender
[  5]   0.00-10.04  sec  9.85 GBytes  8.42 Gbits/sec        receiver

The second result is with UDP GSO and GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-10-10 15:07:36 +02:00
pokamest
b34974c476
Merge pull request #3 from amnezia-vpn/bugfix/uapi_adv_sec_onoff
Manage advanced security via uapi
2023-10-09 06:07:20 -07:00
Mazay B
f30419e0d1 Manage advanced sec via uapi 2023-10-09 13:22:49 +01:00
Mark Puha
8f1a6a10b2
Advanced security (#2)
* Advanced security header layer & config
2023-10-05 21:41:27 +01:00
Dimitri Papadopoulos Orfanos
469159ecf7 netstack: fix typo
Signed-off-by: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-07-04 15:56:30 +02:00
Brad Fitzpatrick
6e755e132a all: adjust build tags for wasip1/wasm
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-07-04 15:54:42 +02:00
springhack
1f25eac395 conn: windows: add missing return statement in DstToString AF_INET
Signed-off-by: SpringHack <springhack@live.cn>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 18:02:50 +02:00
James Tucker
25eb973e00 conn: store IP_PKTINFO cmsg in StdNetendpoint src
Replace the src storage inside StdNetEndpoint with a copy of the raw
control message buffer, to reduce allocation and perform less work on a
per-packet basis.

Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 17:48:32 +02:00
James Tucker
b7cd547315 device: wait for and lock ipc operations during close
If an IPC operation is in flight while close starts, it is possible for
both processes to deadlock. Prevent this by taking the IPC lock at the
start of close and for the duration.

Signed-off-by: James Tucker <jftucker@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-06-27 17:43:35 +02:00
Jordan Whited
052af4a807 tun: use correct IP header comparisons in tcpGRO() and tcpPacketsCanCoalesce()
tcpGRO() was using an incorrect IPv4 more fragments bit mask.

tcpPacketsCanCoalesce() was not distinguishing tcp6 from tcp4, and TTL
values were not compared. TTL values should be equal at the IP layer,
otherwise the packets should not coalesce. This tracks with the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-25 23:13:38 +01:00
Jordan Whited
aad7fca9c5 tun: disqualify tcp4 packets w/IP options from coalescing
IP options were not being compared prior to coalescing. They are not
commonly used. Disqualification due to nonzero options is in line with
the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-25 23:13:26 +01:00
Jason A. Donenfeld
6f895be10d conn: move booleans to bottom of StdNetBind struct
This results in a more compact structure.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 17:05:07 +01:00
Jason A. Donenfeld
6a07b2a355 conn: use ipv6 message pool for ipv6 receiving
Looks like a simple copy&paste error.

Fixes: 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 16:20:16 +01:00
Jordan Whited
334b605e72 conn: fix StdNetEndpoint data race by dynamically allocating endpoints
In 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux"), the
Linux-specific Bind implementation was collapsed into StdNetBind. This
introduced a race on StdNetEndpoint from getSrcFromControl() and
setSrcControl().

Remove the sync.Pool involved in the race, and simplify StdNetBind's
receive path to allocate StdNetEndpoint on the heap instead, with the
intent for it to be cleaned up by the GC, later. This essentially
reverts ef5c587 ("conn: remove the final alloc per packet receive"),
adding back that allocation, unfortunately.

This does slightly increase resident memory usage in higher throughput
scenarios. StdNetBind is the only Bind implementation that was using
this Endpoint recycling technique prior to this commit.

This is considered a stop-gap solution, and there are plans to replace
the allocation with a better mechanism.

Reported-by: lsc <lsc@lv6.tw>
Link: https://lore.kernel.org/wireguard/ac87f86f-6837-4e0e-ec34-1df35f52540e@lv6.tw/
Fixes: 9e2f386 ("conn, device, tun: implement vectorized I/O on Linux")
Cc: Josh Bleecher Snyder <josharian@gmail.com>
Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-24 14:37:13 +01:00
Jason A. Donenfeld
3a9e75374f conn: disable sticky sockets on Android
We can't have the netlink listener socket, so it's not possible to
support it. Plus, android networking stack complexity makes it a bit
tricky anyway, so best to leave it disabled.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-23 18:39:00 +01:00
Jason A. Donenfeld
cc20c08c96 global: remove old style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-23 18:34:09 +01:00
Jordan Whited
1417a47c8f tun: replace ErrorBatch() with errors.Join()
Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-17 15:18:04 +01:00
Jordan Whited
7f511c3bb1 go.mod: bump to Go 1.20
Reviewed-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-17 15:18:04 +01:00
Jordan Whited
07a1e55270 conn: fix getSrcFromControl() iteration
We only expect a single control message in the normal case, but this
would loop infinitely if there were more.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-16 17:45:41 +01:00
Jordan Whited
fff53afca7 conn: use CmsgSpace() for ancillary data buf sizing
CmsgLen() does not account for data alignment.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-16 17:45:41 +01:00
Jason A. Donenfeld
0ad14a89f5 global: buff -> buf
This always struck me as kind of weird and non-standard.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-13 17:55:53 +01:00
Jason A. Donenfeld
7d327ed35a conn: use right cmsghdr len types on 32-bit in sticky test
Cmsghdr uses uint32 and uint64 on 32-bit and 64-bit respectively for the
Len member, which makes assignments and comparisons slightly more
irksome than usual.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 16:19:18 +01:00
Jordan Whited
f41f474466 conn: make StdNetBind.BatchSize() return 1 for non-Linux
This commit updates StdNetBind.BatchSize() to return 1 instead of
IdealBatchSize for non-Linux platforms. Non-Linux platforms do not
yet benefit from values > 1, which only serves to increase memory
consumption.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:53:09 +01:00
Jordan Whited
5819c6af28 tun/netstack: enable TCP Selective Acknowledgements
Enable TCP SACK for the gVisor Stack used in tun/netstack. This can
improve throughput by an order of magnitude in the presence of packet
loss.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:39 +01:00
Jordan Whited
6901984f6a conn: ensure control message size is respected in StdNetBind
This commit re-slices received control messages in StdNetBind to the
value the OS reports on a successful read. Previously, the len of this
slice would always be srcControlSize, which could result in control
message values leaking through a sync.Pool round trip. This is
unlikely with the IP_PKTINFO socket option set successfully, but
should be guarded against.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:32 +01:00
Jordan Whited
2fcdaf9799 conn: fix StdNetBind fallback on Windows
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind.
StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving
datagrams, specifically via the {Read,Write}Batch methods.
These methods are unimplemented on Windows and will return runtime
errors as a result. Additionally, only Linux benefits from these
x/net types for reading and writing, so we update StdNetBind to fall
back to the standard library net package for all platforms other than
Linux.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:24 +01:00
Jason A. Donenfeld
dbd949307e conn: inch BatchSize toward being non-dynamic
There's not really a use at the moment for making this configurable, and
once bind_windows.go behaves like bind_std.go, we'll be able to use
constants everywhere. So begin that simplification now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:22 +01:00
Jordan Whited
f26efb65f2 conn: set SO_{SND,RCV}BUF to 7MB on the Bind UDP socket
The conn.Bind UDP sockets' send and receive buffers are now being sized
to 7MB, whereas they were previously inheriting the system defaults.
The system defaults are considerably small and can result in dropped
packets on high speed links. By increasing the size of these buffers we
are able to achieve higher throughput in the aforementioned case.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with 32-core Xeon Platinum CPUs @ 2.9Ghz. There is
roughly ~125us of round trip latency between them.

The first result is from commit 792b49c which uses the system defaults,
e.g. net.core.{r,w}mem_max = 212992. The TCP retransmits are correlated
with buffer full drops on both sides.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  4.74 GBytes  4.08 Gbits/sec  2742   285 KBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  4.74 GBytes  4.08 Gbits/sec  2742   sender
[  5]   0.00-10.04  sec  4.74 GBytes  4.06 Gbits/sec         receiver

The second result is after increasing SO_{SND,RCV}BUF to 7MB, i.e.
applying this commit.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  6.14 GBytes  5.27 Gbits/sec    0   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  6.14 GBytes  5.27 Gbits/sec    0    sender
[  5]   0.00-10.04  sec  6.14 GBytes  5.25 Gbits/sec         receiver

The specific value of 7MB is chosen as it is the max supported by a
default configuration of macOS. A value greater than 7MB may further
benefit throughput for environments with higher network latency and
lower CPU clocks, but will also increase latency under load
(bufferbloat). Some platforms will silently clamp the value to other
maximums. On Linux, we use SO_{SND,RCV}BUFFORCE in case 7MB is beyond
net.core.{r,w}mem_max.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:20 +01:00
Jason A. Donenfeld
f67c862a2a go.mod: bump deps
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:18 +01:00
Jordan Whited
9e2f386022 conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:17 +01:00
Jordan Whited
3bb8fec7e4 conn, device, tun: implement vectorized I/O plumbing
Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:13 +01:00
Jason A. Donenfeld
21636207a6 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-23 19:12:33 +01:00
Jason A. Donenfeld
c7b76d3d9e device: uniformly check ECDH output for zeros
For some reason, this was omitted for response messages.

Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-16 16:33:14 +01:00
Jordan Whited
1e2c3e5a3c tun: guard Device.Events() against chan writes
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-09 12:35:58 -03:00
Jason A. Donenfeld
ebbd4a4330 global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:39:29 -03:00
Soren L. Hansen
0ae4b3177c tun/netstack: make http examples communicate with each other
This seems like a much better demonstration as it removes the need for
external components.

Signed-off-by: Søren L. Hansen <sorenisanerd@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:38:19 -03:00
Colin Adler
077ce8ecab tun/netstack: bump gvisor
Bump gVisor to a recent known-good version.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:10:52 -03:00
Jason A. Donenfeld
bb719d3a6e global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:21:32 +02:00
Colin Adler
fde0a9525a tun/netstack: ensure (*netTun).incomingPacket chan is closed
Without this, `device.Close()` will deadlock.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:17:29 +02:00
Brad Fitzpatrick
b51010ba13 all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-04 12:57:30 +02:00
Jason A. Donenfeld
d1d08426b2 tun/netstack: remove separate module
Now that the gvisor deps aren't insane, we can just do this in the main
module.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:14:05 -04:00
Shengjing Zhu
3381e21b18 tun/netstack: bump to latest gvisor
To build with go1.19, gvisor needs
99325baf ("Bump gVisor build tags to go1.19").

However gvisor.dev/gvisor/pkg/tcpip/buffer is no longer available,
so refactor to use gvisor.dev/gvisor/pkg/tcpip/link/channel directly.

Signed-off-by: Shengjing Zhu <i@zhsj.me>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:01:05 -04:00
Brad Fitzpatrick
c31a7b1ab4 conn, device, tun: set CLOEXEC on fds
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-07-04 01:42:12 +02:00
Tobias Klauser
6a08d81f6b tun: use ByteSliceToString from golang.org/x/sys/unix
Use unix.ByteSliceToString in (*NativeTun).nameSlice to convert the
TUNGETIFF ioctl result []byte to a string.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-06-01 15:00:07 +02:00
Josh Bleecher Snyder
ef5c587f78 conn: remove the final alloc per packet receive
This does bind_std only; other platforms remain.

The remaining alloc per iteration in the Throughput benchmark
comes from the tuntest package, and should not appear in regular use.

name           old time/op      new time/op      delta
Latency-10         25.2µs ± 1%      25.0µs ± 0%   -0.58%  (p=0.006 n=10+10)
Throughput-10      2.44µs ± 3%      2.41µs ± 2%     ~     (p=0.140 n=10+8)

name           old alloc/op     new alloc/op     delta
Latency-10           854B ± 5%        741B ± 3%  -13.22%  (p=0.000 n=10+10)
Throughput-10        265B ±34%        267B ±39%     ~     (p=0.670 n=10+10)

name           old allocs/op    new allocs/op    delta
Latency-10           16.0 ± 0%        14.0 ± 0%  -12.50%  (p=0.000 n=10+10)
Throughput-10        2.00 ± 0%        1.00 ± 0%  -50.00%  (p=0.000 n=10+10)

name           old packet-loss  new packet-loss  delta
Throughput-10        0.01 ±82%       0.01 ±282%     ~     (p=0.321 n=9+8)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-04-07 03:31:10 +02:00
Jason A. Donenfeld
193cf8d6a5 conn: use netip for std bind
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-17 22:23:02 -06:00
Jason A. Donenfeld
ee1c8e0e87 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 21:32:14 -06:00
Jason A. Donenfeld
95b48cdb39 tun/netstack: bump mod
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 18:01:34 -06:00
Jason A. Donenfeld
5aff28b14c mod: bump packages and remove compat netip
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 17:51:47 -06:00
Josh Bleecher Snyder
46826fc4e5 all: use any in place of interface{}
Enabled by using Go 1.18. A bit less verbose.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:40:24 -07:00
Josh Bleecher Snyder
42c9af45e1 all: update to Go 1.18
Bump go.mod and README.

Switch to upstream net/netip.

Use strings.Cut.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:09:48 -07:00
Alexander Neumann
ae6bc4dd64 tun/netstack: check error returned by SetDeadline()
Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: don't wrap deadline error.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Alexander Neumann
2cec4d1a62 tun/netstack: update to latest wireguard-go
This commit fixes all callsites of netip.AddrFromSlice(), which has
changed its signature and now returns two values.

Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: remove error handling from AddrFromSlice.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Jason A. Donenfeld
3b95c81cc1 tun/netstack: simplify read timeout on ping socket
I'm not 100% sure this is correct, but it certainly is a lot simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:30:31 +01:00
Thomas H. Ptacek
b9669b734e tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:09:37 +01:00
Jason A. Donenfeld
e0b8f11489 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-17 17:37:42 +01:00
Jason A. Donenfeld
114a3db918 ipc: bsd: try again if kqueue returns EINTR
Reported-by: J. Michael McAtee <mmcatee@jumptrading.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-14 16:10:43 +01:00
Jason A. Donenfeld
9c9e7e2724 global: apply gofumpt
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-12-09 23:15:55 +01:00
Jason A. Donenfeld
2dd424e2d8 device: handle peer post config on blank line
We missed a function exit point. This was exacerbated by e3134bf
("device: defer state machine transitions until configuration is
complete"), but the bug existed prior. Minus provided the following
useful reproducer script:

    #!/usr/bin/env bash

    set -eux

    make wireguard-go || exit 125

    ip netns del test-ns || true
    ip netns add test-ns
    ip link add test-kernel type wireguard
    wg set test-kernel listen-port 0 private-key <(echo "QMCfZcp1KU27kEkpcMCgASEjDnDZDYsfMLHPed7+538=") peer "eDPZJMdfnb8ZcA/VSUnLZvLB2k8HVH12ufCGa7Z7rHI=" allowed-ips 10.51.234.10/32
    ip link set test-kernel netns test-ns up
    ip -n test-ns addr add 10.51.234.1/24 dev test-kernel
    port=$(ip netns exec test-ns wg show test-kernel listen-port)

    ip link del test-go || true
    ./wireguard-go test-go
    wg set test-go private-key <(echo "WBM7qimR3vFk1QtWNfH+F4ggy/hmO+5hfIHKxxI4nF4=") peer "+nj9Dkqpl4phsHo2dQliGm5aEiWJJgBtYKbh7XjeNjg=" allowed-ips 0.0.0.0/0 endpoint 127.0.0.1:$port
    ip addr add 10.51.234.10/24 dev test-go
    ip link set test-go up

    ping -c2 -W1 10.51.234.1

Reported-by: minus <minus@mnus.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-29 12:31:54 -05:00
Josh Bleecher Snyder
387f7c461a device: reduce peer lock critical section in UAPI
The deferred RUnlock calls weren't executing until all peers
had been processed. Add an anonymous function so that each
peer may be unlocked as soon as it is completed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Josh Bleecher Snyder
4d87c9e824 device: remove code using unsafe
There is no performance impact.

name                             old time/op  new time/op  delta
TrieIPv4Peers100Addresses1000-8  78.6ns ± 1%  79.4ns ± 3%    ~     (p=0.604 n=10+9)
TrieIPv4Peers10Addresses10-8     29.1ns ± 2%  28.8ns ± 1%  -1.12%  (p=0.014 n=10+9)
TrieIPv6Peers100Addresses1000-8  78.9ns ± 1%  78.6ns ± 1%    ~     (p=0.492 n=10+10)
TrieIPv6Peers10Addresses10-8     29.3ns ± 2%  28.6ns ± 2%  -2.16%  (p=0.000 n=10+10)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
ef8d6804d7 global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
de7c702ace device: only propagate roaming value before peer is referenced elsewhere
A peer.endpoint never becomes nil after being not-nil, so creation is
the only time we actually need to set this. This prevents a race from
when the variable is actually used elsewhere, and allows us to avoid an
expensive atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:16:04 +01:00
Jason A. Donenfeld
fc4f975a4d device: align 64-bit atomic member in Device
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
9d699ba730 device: start peers before running handshake test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
425f7c726b Makefile: don't use test -v because it hides failures in scrollback
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
David Anderson
3cae233d69 device: fix nil pointer dereference in uapi read
Signed-off-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 20:43:26 +01:00
Jason A. Donenfeld
111e0566dc device: make new peers inherit broken mobile semantics
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
e3134bf665 device: defer state machine transitions until configuration is complete
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
63abb5537b device: do not consume handshake messages if not running
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
851efb1bb6 tun: move wintun to its own repo
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:55 +01:00
Jason A. Donenfeld
c07dd60cdb namedpipe: rename from winpipe to keep in sync with CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:52 +01:00
Jason A. Donenfeld
eb6302c7eb device: timers: use pre-seeded per-thread unlocked fastrandn for jitter
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:47:50 +02:00
Jason A. Donenfeld
60683d7361 device: timers: seed unsafe rng before use for jitter
Forgetting to seed the unsafe rng, the jitter before followed a fixed
pattern, which didn't help when a fleet of computers all boot at once.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:34:21 +02:00
Jason A. Donenfeld
e42c6c4bc2 wintun: align 64-bit argument on ARM32
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-26 14:53:40 +02:00
Jason A. Donenfeld
828a885a71 README: raise minimum Go to 1.17
Suggested-by: Adam Bliss <abliss@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-25 17:53:11 +02:00
Mikael Magnusson
f1f626090e tun/netstack: update gvisor
Update gvisor to v0.0.0-20211020211948-f76a604701b6, which requires some
changes to tun.go:

WriteRawPacket: Add function with not implemented error.

CreateNetTUN: Replace stack.AddAddress with stack.AddProtocolAddress, and
fix IPv6 address in error message.

Signed-off-by: Mikael Magnusson <mikma@users.sourceforge.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-22 13:22:29 -06:00
Brad Fitzpatrick
82e0b734e5 ipc, rwcancel: compile on js/wasm
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-10-20 14:50:05 -06:00
Jason A. Donenfeld
fdf57a1fa4 wintun: allow retrieving DLL version
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-20 12:13:44 -06:00
Jason A. Donenfeld
f87e87af0d version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 23:27:13 -06:00
Jason A. Donenfeld
ba9e364dab wintun: remove memmod option for dll loading
Only wireguard-windows used this, and it's moving to wgnt exclusively.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 22:49:38 -06:00
Jason A. Donenfeld
dfd688b6aa global: remove old-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 12:02:10 -06:00
Jason A. Donenfeld
c01d52b66a global: add newer-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 11:46:53 -06:00
Jason A. Donenfeld
82d2aa87aa wintun: use new swdevice-based API for upcoming Wintun 0.14
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 00:26:46 -06:00
Jason A. Donenfeld
982d5d2e84 conn,wintun: use unsafe.Slice instead of unsafeSlice
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:57:53 -06:00
Jason A. Donenfeld
642a56e165 memmod: import from wireguard-windows
We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:53:36 -06:00
Jason A. Donenfeld
bb745b2ea3 rwcancel: use unix.Poll again but bump x/sys so it uses ppoll under the hood
This reverts commit fcc601dbf0 but then
bumps go.mod.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-27 14:19:15 -06:00
Jason A. Donenfeld
fcc601dbf0 rwcancel: use ppoll on Linux for Android
This is a temporary measure while we wait for
https://go-review.googlesource.com/c/sys/+/352310 to land.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-26 17:16:38 -06:00
Tobias Klauser
217ac1016b tun: make operateonfd.go build tags more specific
(*NativeTun).operateOnFd is only used on darwin and freebsd. Adjust the
build tags accordingly.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:54:01 -06:00
Tobias Klauser
eae5e0f3a3 tun: avoid leaking sock fd in CreateTUN error cases
At these points, the socket file descriptor is not yet wrapped in an
*os.File, so it needs to be closed explicitly on error.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:53:49 -06:00
Jason A. Donenfeld
2ef39d4754 global: add new go 1.17 build comments
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-05 16:00:43 +02:00
Jason A. Donenfeld
3957e9b9dd memmod: register exception handler tables
Otherwise recent WDK binaries fail on ARM64, where an exception handler
is used for trapping an illegal instruction when ARMv8.1 atomics are
being tested for functionality.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-08-05 14:56:48 +02:00
Jason A. Donenfeld
bad6caeb82 memmod: fix protected delayed load the right way
The reason this was failing before is that dloadsup.h's
DloadObtainSection was doing a linear search of sections to find which
header corresponds with the IMAGE_DELAYLOAD_DESCRIPTOR section, and we
were stupidly overwriting the VirtualSize field, so the linear search
wound up matching the .text section, which then it found to not be
marked writable and failed with FAST_FAIL_DLOAD_PROTECTION_FAILURE.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:27:40 +02:00
Jason A. Donenfeld
c89f5ca665 memmod: disable protected delayed load for now
Probably a bad idea, but we don't currently support it, and those huge
windows.NewCallback trampolines make juicer targets anyway.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:13:03 +02:00
Jason A. Donenfeld
15b24b6179 ipc: allow admins but require high integrity label
Might be more reasonable.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-24 17:01:02 +02:00
Jason A. Donenfeld
f9b48a961c device: zero out allowedip node pointers when removing
This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-04 16:33:28 +02:00
Jason A. Donenfeld
d0cf96114f device: limit allowedip fuzzer a to 4 times through
Trying this for every peer winds up being very slow and precludes it
from acceptable runtime in the CI, so reduce this to 4.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 18:22:50 +02:00
Jason A. Donenfeld
841756e328 device: simplify allowedips lookup signature
The inliner should handle this for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
c382222eab device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
b41f4cc768 device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 15:08:42 +02:00
Jason A. Donenfeld
4a57024b94 device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 13:51:03 +02:00
Josh Bleecher Snyder
64cb82f2b3 go.mod: bump golang.org/x/sys again
To pick up https://go-review.googlesource.com/c/sys/+/307129.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-25 16:34:54 +02:00
Jason A. Donenfeld
c27ff9b9f6 device: allow reducing queue constants on iOS
Heavier network extensions might require the wireguard-go component to
use less ram, so let users of this reduce these as needed.

At some point we'll put this behind a configuration method of sorts, but
for now, just expose the consts as vars.

Requested-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-22 01:00:51 +02:00
Jason A. Donenfeld
99e8b4ba60 tun: linux: account for interface removal from outside
On Linux we can run `ip link del wg0`, in which case the fd becomes
stale, and we should exit. Since this is an intentional action, don't
treat it as an error.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:26:01 +02:00
Jason A. Donenfeld
bd83f0ac99 conn: linux: protect read fds
The -1 protection was removed and the wrong error was returned, causing
us to read from a bogus fd. As well, remove the useless closures that
aren't doing anything, since this is all synchronized anyway.

Fixes: 10533c3 ("all: make conn.Bind.Open return a slice of receive functions")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:09:55 +02:00
Jason A. Donenfeld
50d779833e rwcancel: use ordinary os.ErrClosed instead of custom error
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:56:36 +02:00
Jason A. Donenfeld
a9b377e9e1 rwcancel: use poll instead of select
Suggested-by: Lennart Poettering <lennart@poettering.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:42:34 +02:00
Jason A. Donenfeld
9087e444e6 device: optimize Peer.String even more
This reduces the allocation, branches, and amount of base64 encoding.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-18 17:43:53 +02:00
Josh Bleecher Snyder
25ad08a591 device: optimize Peer.String
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-05-14 00:37:30 +02:00
Jason A. Donenfeld
5846b62283 conn: windows: set count=0 on retry
When retrying, if count is not 0, we forget to dequeue another request,
and so the ring fills up and errors out.

Reported-by: Sascha Dierberg <dierberg@dresearch-fe.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-11 16:47:17 +02:00
Jason A. Donenfeld
9844c74f67 main: replace crlf on windows in fmt test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 22:23:32 +02:00
Jason A. Donenfeld
4e9e5dad09 main: check that code is formatted in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 17:48:26 +02:00
Jason A. Donenfeld
39e0b6dade tun: format
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:27 +02:00
Jason A. Donenfeld
7121927b87 device: add ID to repeated routines
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:21 +02:00
Jason A. Donenfeld
326aec10af device: remove unusual ... in messages
We dont use ... in any other present progressive messages except these.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:17:41 +02:00
Jason A. Donenfeld
efb8818550 device: avoid verbose log line during ordinary shutdown sequence
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:39:06 +02:00
Jason A. Donenfeld
69b39db0b4 tun: windows: set event before waiting
In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:26:24 +02:00
Jason A. Donenfeld
db733ccd65 tun: windows: rearrange struct to avoid alignment trap on 32bit
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:19:00 +02:00
Jason A. Donenfeld
a7aec4449f tun: windows: check alignment in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:15:50 +02:00
Josh Bleecher Snyder
60a26371f4 device: log all errors received by RoutineReceiveIncoming
When debugging, it's useful to know why a receive func exited.

We were already logging that, but only in the "death spiral" case.
Move the logging up, to capture it always.
Reduce the verbosity, since it is not an error case any more.
Put the receive func name in the log line.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-06 11:22:13 +02:00
Jason A. Donenfeld
a544776d70 tun/netstack: update go mod and remove GSO argument
Reported-by: John Xiong <xiaoyang1258@yeah.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-06 11:07:26 +02:00
Jason A. Donenfeld
69a42a4eef tun: windows: send MTU update when forced MTU changes
Otherwise the padding doesn't get updated.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-05 11:42:45 +02:00
Jason A. Donenfeld
097af6e135 tun: windows: protect reads from closing
The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:22:45 -04:00
Jason A. Donenfeld
8246d251ea conn: windows: do not error out when receiving UDP jumbogram
If we receive a large UDP packet, don't return an error to receive.go,
which then terminates the receive loop. Instead, simply retry.

Considering Winsock's general finickiness, we might consider other
places where an attacker on the wire can generate error conditions like
this.

Reported-by: Sascha Dierberg <sascha.dierberg@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:07:03 -04:00
Jason A. Donenfeld
c9db4b7aaa version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-24 13:07:27 -04:00
Jason A. Donenfeld
3625f8d284 tun: freebsd: avoid OOB writes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:10:23 -06:00
Jason A. Donenfeld
0687dc06c8 tun: freebsd: become controlling process when reopening tun FD
When we pass the TUN FD to the child, we have to call TUNSIFPID;
otherwise when we close the device, we get a splat in dmesg.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:02:44 -06:00
Jason A. Donenfeld
71aefa374d tun: freebsd: restructure and cleanup
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 14:54:59 -06:00
Jason A. Donenfeld
3d3e30beb8 tun: freebsd: remove horrific hack for getting tunnel name
As of FreeBSD 12.1, there's TUNGIFNAME.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 12:03:16 -06:00
Jason A. Donenfeld
b0e5b19969 tun: freebsd: set IFF_MULTICAST for routing daemons
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-18 20:09:04 -06:00
Jason A. Donenfeld
3988821442 main: print kernel warning on OpenBSD and FreeBSD too
More kernels!

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-15 23:35:45 -06:00
Jason A. Donenfeld
c7cd2c9eab device: don't defer unlocking from loop
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 16:19:35 -06:00
Jason A. Donenfeld
54dbe2471f conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 15:35:32 -06:00
Kristupas Antanavičius
d2fd0c0cc0 device: allocate new buffer in receive death spiral
Note: this bug is "hidden" by avoiding "death spiral" code path by
6228659 ("device: handle broader range of errors in RoutineReceiveIncoming").

If the code reached "death spiral" mechanism, there would be multiple
double frees happening. This results in a deadlock on iOS, because the
pools are fixed size and goroutine might stop until somebody makes
space in the pool.

This was almost 100% repro on the new ARM Macbooks:

- Build with 'ios' tag for Mac. This will enable bounded pools.
- Somehow call device.IpcSet at least couple of times (update config)
- device.BindUpdate() would be triggered
- RoutineReceiveIncoming would enter "death spiral".
- RoutineReceiveIncoming would stall on double free (pool is already
  full)
- The stuck routine would deadlock 'device.closeBindLocked()' function
  on line 'netc.stopping.Wait()'

Signed-off-by: Kristupas Antanavičius <kristupas.antanavicius@nordsec.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 11:14:53 -06:00
Jason A. Donenfeld
5f6bbe4ae8 conn: windows: reset ring to starting position after free
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 18:09:41 -06:00
Jason A. Donenfeld
75526d6071 conn: windows: compare head and tail properly
By not comparing these with the modulo, the ring became nearly never
full, resulting in completion queue buffers filling up prematurely.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Jason A. Donenfeld
fbf97502cf winrio: test that IOCP-based RIO is supported
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Josh Bleecher Snyder
10533c3e73 all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.

Beneficial consequences:

* If there's no IPv6 support on a system,
  conn.Bind.Open can choose not to return a receive function for it,
  which is simpler than tracking that state in the bind.
  This simplification removes existing data races from both
  conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
  the conn.Bind no longer needs to add a separate muxing layer.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-04-02 11:07:08 -06:00
Jason A. Donenfeld
8ed83e0427 conn: winrio: pass key parameter into struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-02 10:36:41 -06:00
Josh Bleecher Snyder
6228659a91 device: handle broader range of errors in RoutineReceiveIncoming
RoutineReceiveIncoming exits immediately on net.ErrClosed,
but not on other errors. However, for errors that are known
to be permanent, such as syscall.EAFNOSUPPORT,
we may as well exit immediately instead of retrying.

This considerably speeds up the package device tests right now,
because the Bind sometimes (incorrectly) returns syscall.EAFNOSUPPORT
instead of net.ErrClosed.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:41:43 -07:00
Josh Bleecher Snyder
517f0703f5 conn: document retry loop in StdNetBind.Open
It's not obvious on a first read what the loop is doing.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
204140016a conn: use local ipvN vars in StdNetBind.Open
This makes it clearer that they are fresh on each attempt,
and avoids the bookkeeping required to clearing them on failure.

Also, remove an unnecessary err != nil.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
822f5a6d70 conn: unify code in StdNetBind.Send
The sending code is identical for ipv4 and ipv6;
select the conn, then use it.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:32 -07:00
Josh Bleecher Snyder
02e419ed8a device: rename unsafeCloseBind to closeBindLocked
And document a bit.
This name is more idiomatic.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:07:12 -07:00
Jason A. Donenfeld
bc69a3fa60 version: bump snapshot 2021-03-23 13:07:19 -06:00
Jason A. Donenfeld
12ce53271b tun: freebsd: use broadcast mode instead of PPP mode
It makes the routing configuration simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-23 12:41:34 -06:00
Jason A. Donenfeld
5f0c8b942d device: signal to close device in separate routine
Otherwise we wind up deadlocking.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-11 09:29:10 -07:00
Jason A. Donenfeld
c5f382624e tun: linux: do not spam events every second from hack listener
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-11 09:23:11 -07:00
Kay Diam
6005c573e2 tun: freebsd: allow empty names
This change allows omitting the tun interface name setting. When the
name is not set, the kernel automatically picks up the tun name and
index.

Signed-off-by: Kay Diam <kay.diam@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:32:27 -07:00
Jason A. Donenfeld
82f3e9e2af winpipe: move syscalls into x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:32:27 -07:00
Jason A. Donenfeld
4885e7c954 memmod: use resource functions from x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
497ba95de7 memmod: do not use IsBadReadPtr
It should be enough to check for the trailing zero name.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
0eb7206295 conn: linux: unexport mutex
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
20714ca472 mod: bump x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
c1e09f1927 mod: rename COPYING to LICENSE
Otherwise the netstack module doesn't show up on the package site.

https://github.com/golang/go/issues/43817#issuecomment-764987580

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 09:09:21 -07:00
Jason A. Donenfeld
79611c64e8 tun/netstack: bump deps and api
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 08:48:14 -07:00
Jason A. Donenfeld
593658d975 device: get rid of peers.empty boolean in timersActive
There's no way for len(peers)==0 when a current peer has
isRunning==false.

This requires some struct reshuffling so that the uint64 pointer is
aligned.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 08:44:38 -07:00
Jason A. Donenfeld
3c11c0308e conn: implement RIO for fast Windows UDP sockets
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-25 15:08:08 +01:00
Jason A. Donenfeld
f9dac7099e global: remove TODO name graffiti
Googlers have a habit of graffiting their name in TODO items that then
are never addressed, and other people won't go near those because
they're marked territory of another animal. I've been gradually cleaning
these up as I see them, but this commit just goes all the way and
removes the remaining stragglers.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
9a29ae267c device: test up/down using virtual conn
This prevents port clashing bugs.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
6603c05a4a device: cleanup unused test components
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
a4f8e83d5d conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
c69481f1b3 device: disable waitpool tests
This code is stable, and the test is finicky, especially on high core
count systems, so just disable it.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-22 15:26:47 +01:00
Brad Fitzpatrick
0f4809f366 tun: make NativeTun.Close well behaved, not crash on double close
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-02-22 15:26:29 +01:00
Brad Fitzpatrick
fecb8f482a README: bump document Go requirement to 1.16
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-02-22 15:26:29 +01:00
Jason A. Donenfeld
8bf4204d2e global: stop using ioutil
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-17 22:19:27 +01:00
Jason A. Donenfeld
4e439ea10e conn: bump to 1.16 and get rid of NetErrClosed hack
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-16 21:05:25 +01:00
Jason A. Donenfeld
7a0fb5bbb1 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-12 18:00:59 +01:00
Jason A. Donenfeld
c7b7998619 device: remove old version file
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-12 17:59:50 +01:00
Jason A. Donenfeld
ef8115f63b gitignore: remove old hacks
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-11 15:48:56 +01:00
Jason A. Donenfeld
75e6d810ed device: use container/list instead of open coding it
This linked list implementation is awful, but maybe Go 2 will help
eventually, and at least we're not open coding the hlist any more.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 18:19:11 +01:00
Jason A. Donenfeld
747f5440bc device: retry Up() in up/down test
We're loosing our ownership of the port when bringing the device down,
which means another test process could reclaim it. Avoid this by
retrying for 4 seconds.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 01:01:37 +01:00
Jason A. Donenfeld
aabc3770ba conn: close old fd before trying again
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 00:43:31 +01:00
Jason A. Donenfeld
484a9fd324 device: flush peer queues before starting device
In case some old packets snuck in there before, this flushes before
starting afresh.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 00:39:28 +01:00
Jason A. Donenfeld
5bf8d73127 device: create peer queues at peer creation time
Rather than racing with Start(), since we're never destroying these
queues, we just set the variables at creation time.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 00:21:12 +01:00
Jason A. Donenfeld
587a2b2a20 device: return error from Up() and Down()
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-10 00:12:23 +01:00
Jason A. Donenfeld
6f08a10041 rwcancel: add an explicit close call
This lets us collect FDs even if the GC doesn't do it for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 20:19:14 +01:00
Jason A. Donenfeld
a97ef39cd4 rwcancel: use errors.Is for unwrapping
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 19:54:00 +01:00
Jason A. Donenfeld
c040dea798 tun: use errors.Is for unwrapping
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 19:50:31 +01:00
Jason A. Donenfeld
5cdb862f15 conn: use errors.Is for unwrapping
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 19:46:57 +01:00
Jason A. Donenfeld
da32fe328b device: handshake routine writes into encryption queue
Since RoutineHandshake calls peer.SendKeepalive(), it potentially is a
writer into the encryption queue, so we need to bump the wg count.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 19:26:45 +01:00
Josh Bleecher Snyder
4eab21a7b7 device: make RoutineReadFromTUN keep encryption queue alive
RoutineReadFromTUN can trigger a call to SendStagedPackets.
SendStagedPackets attempts to protect against sending
on the encryption queue by checking peer.isRunning and device.isClosed.
However, those are subject to TOCTOU bugs.

If that happens, we get this:

goroutine 1254 [running]:
golang.zx2c4.com/wireguard/device.(*Peer).SendStagedPackets(0xc000798300)
        .../wireguard-go/device/send.go:321 +0x125
golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN(0xc000014780)
        .../wireguard-go/device/send.go:271 +0x21c
created by golang.zx2c4.com/wireguard/device.NewDevice
        .../wireguard-go/device/device.go:315 +0x298

Fix this with a simple, big hammer: Keep the encryption queue
alive as long as it might be written to.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 09:53:00 -08:00
Jason A. Donenfeld
30b96ba083 conn: try harder to have v4 and v6 ports agree
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 18:45:12 +01:00
Josh Bleecher Snyder
78ebce6932 device: only allocate peer queues once
This serves two purposes.

First, it makes repeatedly stopping then starting a peer cheaper.
Second, it prevents a data race observed accessing the queues.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 18:33:48 +01:00
Josh Bleecher Snyder
cae090d116 device: clarify device.state.state docs (again)
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 18:29:01 +01:00
Josh Bleecher Snyder
465261310b device: run fewer iterations in TestUpDown
The high iteration count was useful when TestUpDown
was the nexus of new bugs to investigate.

Now that it has stabilized, that's less valuable.
And it slows down running the tests and crowds out other tests.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 18:28:59 +01:00
Josh Bleecher Snyder
d117d42ae7 device: run fewer trials in TestWaitPool when race detector enabled
On a many-core machine with the race detector enabled,
this test can take several minutes to complete.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 18:28:58 +01:00
Josh Bleecher Snyder
ecceaadd16 device: remove nil elem check in finalizers
This is not necessary, and removing it speeds up detection of UAF bugs.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 18:28:55 +01:00
Jason A. Donenfeld
9e728c2eb0 device: rename unsafeRemovePeer to removePeerLocked
This matches the new naming scheme of upLocked and downLocked.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 16:11:33 +01:00
Jason A. Donenfeld
eaf664e4e9 device: remove deviceStateNew
It's never used and we won't have a use for it. Also, move to go-running
stringer, for those without GOPATHs.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:39:19 +01:00
Jason A. Donenfeld
a816e8511e device: fix comment typo and shorten state.mu.Lock to state.Lock
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
02138f1f81 device: fix typo in comment
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
d7bc7508e5 device: fix alignment on 32-bit machines and test for it
The test previously checked the offset within a substruct, not the
offset within the allocated struct, so this adds the two together.

It then fixes an alignment crash on 32-bit machines.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
d6e76fdbd6 device: do not log on idempotent device state change
Part of being actually idempotent is that we shouldn't penalize code
that takes advantage of this property with a log splat.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
6ac1240821 device: do not attach finalizer to non-returned object
Before, the code attached a finalizer to an object that wasn't returned,
resulting in immediate garbage collection. Instead return the actual
pointer.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
4b5d15ec2b device: lock elem in autodraining queue before freeing
Without this, we wind up freeing packets that the encryption/decryption
queues still have, resulting in a UaF.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
6548a682a9 device: remove listen port race in tests
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 15:37:04 +01:00
Jason A. Donenfeld
a60e6dab76 device: generate test keys on the fly
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-09 00:42:39 +01:00
Josh Bleecher Snyder
d8dd1f254f device: remove mutex from Peer send/receive
The immediate motivation for this change is an observed deadlock.

1. A goroutine calls peer.Stop. That calls peer.queue.Lock().
2. Another goroutine is in RoutineSequentialReceiver.
   It receives an elem from peer.queue.inbound.
3. The peer.Stop goroutine calls close(peer.queue.inbound),
   close(peer.queue.outbound), and peer.stopping.Wait().
   It blocks waiting for RoutineSequentialReceiver
   and RoutineSequentialSender to exit.
4. The RoutineSequentialReceiver goroutine calls peer.SendStagedPackets().
   SendStagedPackets attempts peer.queue.RLock().
   That blocks forever because the peer.Stop
   goroutine holds a write lock on that mutex.

A background motivation for this change is that it can be expensive
to have a mutex in the hot code path of RoutineSequential*.

The mutex was necessary to avoid attempting to send elems on a closed channel.
This commit removes that danger by never closing the channel.
Instead, we send a sentinel nil value on the channel to indicate
to the receiver that it should exit.

The only problem with this is that if the receiver exits,
we could write an elem into the channel which would never get received.
If it never gets received, it cannot get returned to the device pools.

To work around this, we use a finalizer. When the channel can be GC'd,
the finalizer drains any remaining elements from the channel and
restores them to the device pool.

After that change, peer.queue.RWMutex no longer makes sense where it is.
It is only used to prevent concurrent calls to Start and Stop.
Move it to a more sensible location and make it a plain sync.Mutex.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 13:02:52 -08:00
Josh Bleecher Snyder
57aadfcb14 device: create channels.go
We have a bunch of stupid channel tricks, and I'm about to add more.
Give them their own file. This commit is 100% code movement.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 12:38:19 -08:00
Josh Bleecher Snyder
af408eb940 device: print direction when ping transit fails
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 12:01:08 -08:00
Josh Bleecher Snyder
15810daa22 device: separate timersInit from timersStart
timersInit sets up the timers.
It need only be done once per peer.

timersStart does the work to prepare the timers
for a newly running peer. It needs to be done
every time a peer starts.

Separate the two and call them in the appropriate places.
This prevents data races on the peer's timers fields
when starting and stopping peers.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:32:07 -08:00
Josh Bleecher Snyder
d840445e9b device: don't track device interface state in RoutineTUNEventReader
We already track this state elsewhere. No need to duplicate.
The cost of calling changeState is negligible.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:32:07 -08:00
Josh Bleecher Snyder
675ff32e6c device: improve MTU change handling
The old code silently accepted negative MTUs.
It also set MTUs above the maximum.
It also had hard to follow deeply nested conditionals.

Add more paranoid handling,
and make the code more straight-line.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:32:07 -08:00
Josh Bleecher Snyder
3516ccc1e2 device: remove device.state.stopping from RoutineTUNEventReader
The TUN event reader does three things: Change MTU, device up, and device down.
Changing the MTU after the device is closed does no harm.
Device up and device down don't make sense after the device is closed,
but we can check that condition before proceeding with changeState.
There's thus no reason to block device.Close on RoutineTUNEventReader exiting.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:32:07 -08:00
Josh Bleecher Snyder
0bcb822e5b device: overhaul device state management
This commit simplifies device state management.
It creates a single unified state variable and documents its semantics.

It also makes state changes more atomic.
As an example of the sort of bug that occurred due to non-atomic state changes,
the following sequence of events used to occur approximately every 2.5 million test runs:

* RoutineTUNEventReader received an EventDown event.
* It called device.Down, which called device.setUpDown.
* That set device.state.changing, but did not yet attempt to lock device.state.Mutex.
* Test completion called device.Close.
* device.Close locked device.state.Mutex.
* device.Close blocked on a call to device.state.stopping.Wait.
* device.setUpDown then attempted to lock device.state.Mutex and blocked.

Deadlock results. setUpDown cannot progress because device.state.Mutex is locked.
Until setUpDown returns, RoutineTUNEventReader cannot call device.state.stopping.Done.
Until device.state.stopping.Done gets called, device.state.stopping.Wait is blocked.
As long as device.state.stopping.Wait is blocked, device.state.Mutex cannot be unlocked.
This commit fixes that deadlock by holding device.state.mu
when checking that the device is not closed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:32:07 -08:00
Josh Bleecher Snyder
da95677203 device: remove unnecessary zeroing in peer.SendKeepalive
elem.packet is always already nil.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 10:14:17 -08:00
Josh Bleecher Snyder
9c75f58f3d device: remove device.state.stopping from RoutineHandshake
It is no longer necessary.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 08:18:32 -08:00
Josh Bleecher Snyder
84a42aed63 device: remove device.state.stopping from RoutineDecryption
It is no longer necessary, as of 454de6f3e64abd2a7bf9201579cd92eea5280996
(device: use channel close to shut down and drain decryption channel).

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 08:18:32 -08:00
Jason A. Donenfeld
4192036acd main: add back version file
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-04 15:33:04 +01:00
Jason A. Donenfeld
9c7bd73be2 tai64n: add string representation for error messages
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-03 17:56:46 +01:00
Jason A. Donenfeld
01e176af3c device: take peer handshake when reinitializing last sent handshake
This papers over other unrelated races, unfortunately.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-03 17:52:31 +01:00
Josh Bleecher Snyder
91617b4c52 device: fix goroutine leak test
The leak test had rare flakes.
If a system goroutine started at just the wrong moment, you'd get a false positive.
Instead of looping until the goroutines look good and then checking,
exit completely as soon as the number of goroutines looks good.
Also, check more frequently, in an attempt to complete faster.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-03 17:45:22 +01:00
Jason A. Donenfeld
7258a8973d device: add up/down stress test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-03 17:43:41 +01:00
Jason A. Donenfeld
d9d547a3f3 device: pass cfg strings around in tests instead of reader
This makes it easier to tag things onto the end manually for quick hacks.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-03 17:29:01 +01:00
Jason A. Donenfeld
c3bde5f590 device: benchmark the waitpool to compare it to the prior channels
Here is the old implementation:

    type WaitPool struct {
        c chan interface{}
    }

    func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
        p := &WaitPool{c: make(chan interface{}, max)}
        for i := uint32(0); i < max; i++ {
            p.c <- new()
        }
        return p
    }

    func (p *WaitPool) Get() interface{} {
        return <- p.c
    }

    func (p *WaitPool) Put(x interface{}) {
        p.c <- x
    }

It performs worse than the new one:

    name         old time/op  new time/op  delta
    WaitPool-16  16.4µs ± 5%  15.1µs ± 3%  -7.86%  (p=0.008 n=5+5)

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-03 16:59:29 +01:00
Josh Bleecher Snyder
fd63a233c9 device: test that we do not leak goroutines
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-03 00:57:57 +01:00
Josh Bleecher Snyder
8a374a35a0 device: tie encryption queue lifetime to the peers that write to it
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-03 00:57:57 +01:00
Jason A. Donenfeld
4846070322 device: use a waiting sync.Pool instead of a channel
Channels are FIFO which means we have guaranteed cache misses.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-02 19:32:13 +01:00
Jason A. Donenfeld
a9f80d8c58 device: reduce number of append calls when padding
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-29 20:10:48 +01:00
Jason A. Donenfeld
de51129e33 device: use int64 instead of atomic.Value for time stamp
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-29 18:57:03 +01:00
Jason A. Donenfeld
beb25cc4fd device: use new model queues for handshakes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-29 18:24:45 +01:00
Jason A. Donenfeld
9263014ed3 device: simplify peer queue locking
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-29 16:21:53 +01:00
Jason A. Donenfeld
f0f27d7fd2 device: reduce nesting when staging packet
Suggested-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 18:56:58 +01:00
Jason A. Donenfeld
d4112d9096 global: bump copyright
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 17:52:15 +01:00
Jason A. Donenfeld
bf3bb88851 device: remove version string
This is what modules are for, and Go binaries can introspect.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 17:23:39 +01:00
Jason A. Donenfeld
6a128dde71 device: do not allow get to run while set runs
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 15:26:22 +01:00
Jason A. Donenfeld
34c047c762 device: avoid hex allocations in IpcGet
benchmark               old ns/op     new ns/op     delta
BenchmarkUAPIGet-16     2872          2157          -24.90%

benchmark               old allocs     new allocs     delta
BenchmarkUAPIGet-16     30             18             -40.00%

benchmark               old bytes     new bytes     delta
BenchmarkUAPIGet-16     737           256           -65.26%

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 15:22:34 +01:00
Jason A. Donenfeld
d4725bc456 device: the psk is not a chapoly key
It's a separate type of key that gets hashed into the chain.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-28 14:45:53 +01:00
Jason A. Donenfeld
1b092ce584 device: get rid of nonce routine
This moves to a simple queue with no routine processing it, to reduce
scheduler pressure.

This splits latency in half!

benchmark                  old ns/op     new ns/op     delta
BenchmarkThroughput-16     2394          2364          -1.25%
BenchmarkLatency-16        259652        120810        -53.47%

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-27 18:38:27 +01:00
Jason A. Donenfeld
a11dec5dc1 tun: use %w for errors on linux
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-27 16:02:42 +01:00
Jason A. Donenfeld
ace50a0529 device: avoid deadlock when changing private key and removing self peers
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-27 15:53:21 +01:00
Jason A. Donenfeld
8cc99631d0 device: use linked list for per-peer allowed-ip traversal
This makes the IpcGet method much faster.

We also refactor the traversal API to use a callback so that we don't
need to allocate at all. Avoiding allocations we do self-masking on
insertion, which in turn means that split intermediate nodes require a
copy of the bits.

benchmark               old ns/op     new ns/op     delta
BenchmarkUAPIGet-16     3243          2659          -18.01%

benchmark               old allocs     new allocs     delta
BenchmarkUAPIGet-16     35             30             -14.29%

benchmark               old bytes     new bytes     delta
BenchmarkUAPIGet-16     1218          737           -39.49%

This benchmark is good, though it's only for a pair of peers, each with
only one allowedips. As this grows, the delta expands considerably.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-27 01:48:58 +01:00
Jason A. Donenfeld
d669c78c43 device: combine debug and info log levels into 'verbose'
There are very few cases, if any, in which a user only wants one of
these levels, so combine it into a single level.

While we're at it, reduce indirection on the loggers by using an empty
function rather than a nil function pointer. It's not like we have
retpolines anyway, and we were always calling through a function with a
branch prior, so this seems like a net gain.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-26 23:05:48 +01:00
Josh Bleecher Snyder
7139279cd0 device: change logging interface to use functions
This commit overhauls wireguard-go's logging.

The primary, motivating change is to use a function instead
of a *log.Logger as the basic unit of logging.
Using functions provides a lot more flexibility for
people to bring their own logging system.

It also introduces logging helper methods on Device.
These reduce line noise at the call site.
They also allow for log functions to be nil;
when nil, instead of generating a log line and throwing it away,
we don't bother generating it at all.
This spares allocation and pointless work.

This is a breaking change, although the fix required
of clients is fairly straightforward.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-26 22:40:20 +01:00
Josh Bleecher Snyder
37efdcaccf device: fix shadowing of err in IpcHandle
The declaration of err in

	nextByte, err := buffered.ReadByte

shadows the declaration of err in

	op, err := buffered.ReadString('\n')

above. As a result, the assignments to err in

	err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %c", nextByte)

and in

	err = device.IpcGetOperation(buffered.Writer)

do not modify the correct err variable.

Found by staticcheck.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-26 22:40:10 +01:00
Josh Bleecher Snyder
d3a2b74df2 device: remove extra error arg
Caught by go vet.
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-26 22:36:10 +01:00
Brad Fitzpatrick
8114c9db5f device: reduce allocs in Device.IpcGetOperation
Plenty more to go, but a start:

name       old time/op    new time/op    delta
UAPIGet-4    6.37µs ± 2%    5.56µs ± 1%  -12.70%  (p=0.000 n=8+8)

name       old alloc/op   new alloc/op   delta
UAPIGet-4    1.98kB ± 0%    1.22kB ± 0%  -38.71%  (p=0.000 n=10+10)

name       old allocs/op  new allocs/op  delta
UAPIGet-4      42.0 ± 0%      35.0 ± 0%  -16.67%  (p=0.000 n=10+10)

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-01-26 11:51:52 -08:00
Josh Bleecher Snyder
e6ec3852a9 device: add benchmark for UAPI Device.IpcGetOperation
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-26 11:40:24 -08:00
Brad Fitzpatrick
23b2790aa0 conn: fix interface parameter name in Bind interface docs
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-01-26 15:20:22 +01:00
Jason A. Donenfeld
18e47795e5 device: allow pipelining UAPI requests
The original spec ends with \n\n especially for this reason.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-25 20:48:28 +01:00
Jason A. Donenfeld
a29767dda6 ipc: add missing Windows errno
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-25 20:48:28 +01:00
Josh Bleecher Snyder
cecb41515d device: serialize access to IpcSetOperation
Interleaves IpcSetOperations would spell trouble.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:38:09 -08:00
Josh Bleecher Snyder
a9ce4b762c device: simplify handling of IPC set endpoint
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:37:28 -08:00
Josh Bleecher Snyder
d8f2cc87ee device: remove close processing fwmark
Also, a behavior change: Stop treating a blank value as 0.
It's not in the spec.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:36:53 -08:00
Josh Bleecher Snyder
2b8665f5f9 device: remove unnecessary comment
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:36:41 -08:00
Josh Bleecher Snyder
674a4675a1 device: introduce new IPC error message for unknown error
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:36:17 -08:00
Josh Bleecher Snyder
87bdcb2ae4 device: correct IPC error number for I/O errors
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:35:48 -08:00
Josh Bleecher Snyder
37a239e736 device: simplify IpcHandle error handling
Unify the handling of unexpected UAPI errors.
The comment that says "should never happen" is incorrect;
this could happen due to I/O errors. Correct it.

Change error message capitalization for consistency.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:09:24 -08:00
Josh Bleecher Snyder
6252de0db9 device: split IpcSetOperation into parts
The goal of this change is to make the structure
of IpcSetOperation easier to follow.

IpcSetOperation contains a small state machine:
It starts by configuring the device,
then shifts to configuring one peer at a time.

Having the code all in one giant method obscured that structure.
Split out the parts into helper functions and encapsulate the peer state.

This makes the overall structure more apparent.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 09:09:24 -08:00
Josh Bleecher Snyder
a029b942ae device: expand IPCError
Expand IPCError to contain a wrapped error,
and add a helper to make constructing such errors easier.

Add a defer-based "log on returned error" to IpcSetOperation.
This lets us simplify all of the error return paths.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 08:47:48 -08:00
Josh Bleecher Snyder
db3fa1409c device: remove dead code
If device.NewPeer returns a nil error,
then the returned peer is always non-nil.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 08:47:48 -08:00
Josh Bleecher Snyder
675aae2423 device: return errors from ipc scanner
The code as written will drop any read errors on the floor.
Fix that.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-25 08:47:48 -08:00
Jason A. Donenfeld
fcc8ad05df netstack: further sequester with own go.mod and go.sum
In order to avoid even the flirtation with passing on these dependencies
to ordinary consumers of wireguard-go, this commit makes a new go.mod
that's entirely separate from the root one.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-21 00:25:02 +01:00
Jason A. Donenfeld
1d4eb2727a netstack: introduce new module for gvisor tcp tun adapter
The Go linker isn't smart enough to prevent gvisor from being pulled
into modules that use other parts of tun/, due to the types exposed. So,
we put this into its own standalone module.

We use this as an opportunity to introduce some example code as well.

I'm still not happy that this not only clutters this repo's go.sum, but
all the other projects that consume it, but it seems like making a new
module inside of this repo will lead to even greater confusion.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-21 00:16:59 +01:00
Jason A. Donenfeld
294d3bedf9 device: allow compiling with Go 1.15
Until we depend on Go 1.16 (which isn't released yet), alias our own
variable to the private member of the net package. This will allow an
easy find replace to make this go away when we eventually switch to
1.16.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-20 20:12:32 +01:00
Josh Bleecher Snyder
86a58b51c0 device: remove unused fields from DummyDatagram and DummyBind
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 20:03:40 +01:00
Josh Bleecher Snyder
6a2ecb581b device: remove unused trie test code
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 20:03:40 +01:00
Josh Bleecher Snyder
f07177c762 conn: remove _ method receiver
Minor style fix.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 20:03:40 +01:00
Josh Bleecher Snyder
b00b2c2951 tun: fix fmt.Errorf format strings
Type tcpip.Error is not an error.

I've filed https://github.com/google/gvisor/issues/5314
to fix this upstream.

Until that is fixed, use %v instead of %w,
to keep vet happy.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 20:03:40 +01:00
Josh Bleecher Snyder
7c5d1e355e device: remove unnecessary zeroing
Newly allocated objects are already zeroed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:07 +01:00
Josh Bleecher Snyder
a86492a567 device: remove QueueInboundElement.dropped
Now that we block when enqueueing to the decryption queue,
there is only one case in which we "drop" a inbound element,
when decryption fails.

We can use a simple, obvious, sync-free sentinel for that, elem.packet == nil.
Also, we can return the message buffer to the pool slightly later,
which further simplifies the code.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:06 +01:00
Josh Bleecher Snyder
7ee95e053c device: remove QueueOutboundElement.dropped
If we block when enqueuing encryption elements to the queue,
then we never drop them.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:05 +01:00
Josh Bleecher Snyder
291dbcf1f0 tun/wintun/memmod: gofmt
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:04 +01:00
Josh Bleecher Snyder
abc88c82b1 tun/wintun/memmod: fix format verb
Caught by 'go vet'.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:02 +01:00
Josh Bleecher Snyder
23642a13be device: check returned errors from NewPeer in TestNoiseHandshake
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:01 +01:00
Josh Bleecher Snyder
2fe19ce54d device: remove selects from encrypt/decrypt/inbound/outbound enqueuing
Block instead. Backpressure here is fine, probably preferable.
This reduces code complexity.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:57:00 +01:00
Josh Bleecher Snyder
0cc15e7c7c device: put handshake buffer in pool in FlushPacketQueues
This appears to have been an oversight.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:56:59 +01:00
Josh Bleecher Snyder
48c3b87eb8 device: use channel close to shut down and drain decryption channel
This is similar to commit e1fa1cc556,
but for the decryption channel.

It is an alternative fix to f9f655567930a4cd78d40fa4ba0d58503335ae6a.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-20 19:56:54 +01:00
Jason A. Donenfeld
675955de5d tun: add tcpip stack tunnel abstraction
This allows people to initiate connections over WireGuard without any
underlying operating system support.

I'm not crazy about the trash it adds to go.sum, but the code this
actually adds to the binaries seems contained to the gvisor repo.

For the TCP/IP implementation, it uses gvisor. And it borrows some
internals from the Go standard library's resolver in order to bring Dial
and DialContext to tun_net, along with the LookupHost helper function.
This allows for things like HTTP2-over-TLS to work quite well:

    package main

    import (
        "io"
        "log"
        "net"
        "net/http"

        "golang.zx2c4.com/wireguard/device"
        "golang.zx2c4.com/wireguard/tun"
    )

    func main() {
        tun, tnet, err := tun.CreateNetTUN([]net.IP{net.ParseIP("192.168.4.29")}, []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, 1420)
        if err != nil {
            log.Panic(err)
        }
        dev := device.NewDevice(tun, &device.Logger{log.Default(), log.Default(), log.Default()})
        dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
    public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
    endpoint=163.172.161.0:12912
    allowed_ip=0.0.0.0/0
    `)
        dev.Up()

        client := http.Client{
            Transport: &http.Transport{
                DialContext: tnet.DialContext,
            },
        }
        resp, err := client.Get("https://www.zx2c4.com/ip")
        if err != nil {
            log.Panic(err)
        }
        body, err := io.ReadAll(resp.Body)
        if err != nil {
            log.Panic(err)
        }
        log.Println(string(body))
    }

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-13 16:33:40 +01:00
Jason A. Donenfeld
ea6c1cd7e6 device: receive: do not exit immediately on transient UDP receive errors
Some users report seeing lines like:

> Routine: receive incoming IPv4 - stopped

Popping up unexpectedly. Let's sleep and try again before failing, and
also log the error, and perhaps we'll eventually understand this
situation better in future versions.

Because we have to distinguish between the socket being closed
explicitly and whatever error this is, we bump the module to require Go
1.16.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-08 14:30:04 +01:00
Jason A. Donenfeld
3b3de758ec conn: linux: do not allow ReceiveIPvX to race with Close
If Close is called after ReceiveIPvX, then ReceiveIPvX will block on an
invalid or potentially reused fd.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 17:08:58 +01:00
Jason A. Donenfeld
29b0477585 device: receive: drain decryption queue before exiting RoutineDecryption
It's possible for RoutineSequentialReceiver to try to lock an elem after
RoutineDecryption has exited. Before this meant we didn't then unlock
the elem, so the whole program deadlocked.

As well, it looks like the flush code (which is now potentially
unnecessary?) wasn't properly dropping the buffers for the
not-already-dropped case.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 17:08:41 +01:00
Josh Bleecher Snyder
85b4950579 device: add latency and throughput benchmarks
These obviously don't perfectly capture real world performance,
in which syscalls and network links have a significant impact.
Nevertheless, they capture some of the internal performance factors,
and they're easy and convenient to work with.

Hat tip to Avery Pennarun for help designing the throughput benchmark.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
8a30415555 device: use LogLevelError for benchmarking
This keeps the output minimal and focused on the benchmark results.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
cdaf4e9a76 device: make test infrastructure usable with benchmarks
Switch from *testing.T to testing.TB.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
3d83df9bf3 memmod: apply explicit build tags to _32 and _64 files
Since _32 and _64 aren't valid goarchs, they don't match _GOOS_GOARCH,
and so the existing tags wind up not being restricted to windows-only.
This fixes the problem by adding windows to the tags explicitly. We
could also fix it by calling the files _32_windows or _64_windows, but
that changes the convention with the other single-arch files.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
d664444928 tun: make customization of WintunPool and requested GUID more obvious
Persnickety consumers can now do:

    func init() {
        tun.WintunPool, _ = wintun.MakePool("Flurp")
        tun.WintunStaticRequestedGUID, _ = windows.GUIDFromString("{5ae2716f-0b3e-4dc4-a8b5-48eba11a6e16}")
    }

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
1481e72107 all: use ++ to increment
Make the code slightly more idiomatic. No functional changes.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
d0f8e9477c device: remove unnecessary zeroing
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
b42e32047d device: call wg.Add outside the goroutine
One of the first rules of WaitGroups is that you call wg.Add
outside of a goroutine, not inside it. Fix this embarrassing mistake.

This prevents an extremely rare race condition (2 per 100,000 runs)
which could occur when attempting to start a new peer
concurrently with shutting down a device.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
b5f966ac24 device: remove QueueInboundElement leak with stopped peers
This is particularly problematic on mobile,
where there is a fixed number of elements.
If most of them leak, it'll impact performance;
if all of them leak, the device will permanently deadlock.

I have a test that detects element leaks, which is how I found this one.
There are some remaining leaks that I have not yet tracked down,
but this is the most prominent by far.

I will commit the test when it passes reliably.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
a1c265b0c5 device: simplify UAPI helper methods
bufio is not required.

strings.Builder is cheaper than bytes.Buffer for constructing strings.

io.Writer is more flexible than io.StringWriter,
and just as cheap (when used with io.WriteString).

Run gofmt.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
25b01723dd device: fix alignment of peer stats member
This was shifted by 2 bytes when making persistent keepalive into a u32.
Fix it by placing it after the aligned region.

Fixes: e739ff7 ("device: fix persistent_keepalive_interval data races")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
40dfc85def device: add UAPI helper methods
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
890cc06ed5 conn: do not SO_REUSEADDR on linux
SO_REUSEADDR does not make sense for unicast UDP sockets.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
ad73ee78e9 device: add missing colon to error line
People are actually hitting this condition, so make it uniform. Also,
change a printf into a println, to match the other conventions.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Brad Fitzpatrick
e9edc16349 device: fix error shadowing before log print
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
f7bbdc31a0 device: fix data race in peer.timersActive
Found by the race detector and existing tests.

To avoid introducing a lock into this hot path,
calculate and cache whether any peers exist.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
70861686d3 device: fix races from changing private_key
Access keypair.sendNonce atomically.
Eliminate one unnecessary initialization to zero.

Mutate handshake.lastSentHandshake with the mutex held.

Co-authored-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
c8faa34cde device: always name *Queue*Element variables elem
They're called elem in most places.
Rename a few local variables to make it consistent.
This makes it easier to grep the code for things like elem.Drop.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
2832e96339 device: use channel close to shut down and drain outbound channel
This is a similar treatment to the handling of the encryption
channel found a few commits ago: Use the closing of the channel
to manage goroutine lifetime and shutdown.
It is considerably simpler because there is only a single writer.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
63066ce406 device: fix persistent_keepalive_interval data races
Co-authored-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
e1fa1cc556 device: use channel close to shut down and drain encryption channel
The new test introduced in this commit used to deadlock about 1% of the time.

I believe that the deadlock occurs as follows:

* The test completes, calling device.Close.
* device.Close closes device.signals.stop.
* RoutineEncryption stops.
* The deferred function in RoutineEncryption drains device.queue.encryption.
* RoutineEncryption exits.
* A peer's RoutineNonce processes an element queued in peer.queue.nonce.
* RoutineNonce puts that element into the outbound and encryption queues.
* RoutineSequentialSender reads that elements from the outbound queue.
* It waits for that element to get Unlocked by RoutineEncryption.
* RoutineEncryption has already exited, so RoutineSequentialSender blocks forever.
* device.RemoveAllPeers calls peer.Stop on all peers.
* peer.Stop waits for peer.routines.stopping, which blocks forever.

Rather than attempt to add even more ordering to the already complex
centralized shutdown orchestration, this commit moves towards a
data-flow-oriented shutdown.

The device.queue.encryption gets closed when there will be no more writes to it.
All device.queue.encryption readers always read until the channel is closed and then exit.
We thus guarantee that any element that enters the encryption queue also exits it.
This removes the need for central control of the lifetime of RoutineEncryption,
removes the need to drain the encryption queue on shutdown, and simplifies RoutineEncryption.

This commit also fixes a data race. When RoutineSequentialSender
drains its queue on shutdown, it needs to lock the elem before operating on it,
just as the main body does.

The new test in this commit passed 50k iterations with the race detector enabled
and 150k iterations with the race detector disabled, with no failures.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
41cd68416c device: simplify copying counter to nonce
Since we already have it packed into a uint64
in a known byte order, write it back out again
the same byte order instead of copying byte by byte.

This should also generate more efficient code,
because the compiler can do a single uint64 write,
instead of eight bounds checks and eight byte writes.

Due to a missed optimization, it actually generates a mishmash
of smaller writes: 1 byte, 4 bytes, 2 bytes, 1 byte.
This is https://golang.org/issue/41663.
The code is still better than before, and will get better yet
once that compiler bug gets fixed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
94b33ba705 device: add a helper to generate uapi configs
This makes it easier to work with configs in tests.
It'll see heavier use over upcoming commits;
this commit only adds the infrastructure.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
ea8fbb5927 device: use defer to simplify peer.NewTimer
This also makes the lifetime of modifyingLock more prominent.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
93a4313c3a device: accept any io.Reader in device.IpcSetOperation
Any io.Reader will do, and there are no performance concerns here.
This is technically backwards incompatible,
but it is very unlikely to break any existing code.
It is compatible with the existing uses in wireguard-{windows,android,apple}
and also will allow us to slightly simplify it if desired.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
db1edc7e91 device: increase timeout in tests
When running many concurrent test processing using
https://godoc.org/golang.org/x/tools/cmd/stress
the processing sometimes cannot complete a ping in under 300ms.
Increase the timeout to 5s to reduce the rate of false positives.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
fc0aabbae9 device: prevent spurious errors while closing a device
When closing a device, packets that are in flight
can make it to SendBuffer, which then returns an error.
Those errors add noise but no light;
they do not reflect an actual problem.

Adding the synchronization required to prevent
this from occurring is currently expensive and error-prone.
Instead, quietly drop such packets instead of
returning an error.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
c9e4a859ae device: remove starting waitgroups
In each case, the starting waitgroup did nothing but ensure
that the goroutine has launched.

Nothing downstream depends on the order in which goroutines launch,
and if the Go runtime scheduler is so broken that goroutines
don't get launched reasonably promptly, we have much deeper problems.

Given all that, simplify the code.

Passed a race-enabled stress test 25,000 times without failure.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
3591acba76 device: make test setup more robust
Picking two free ports to use for a test is difficult.
The free port we selected might no longer be free when we reach
for it a second time.

On my machine, this failure mode led to failures approximately
once per thousand test runs.

Since failures are rare, and threading through and checking for
all possible errors is complicated, fix this with a big hammer:
Retry if either device fails to come up.

Also, if you accidentally pick the same port twice, delightful confusion ensues.
The handshake failures manifest as crypto errors, which look scary.
Again, fix with retries.

To make these retries easier to implement, use testing.T.Cleanup
instead of defer to close devices. This requires Go 1.14.
Update go.mod accordingly. Go 1.13 is no longer supported anyway.

With these fixes, 'go test -race' ran 100,000 times without failure.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:44 +01:00
Jason A. Donenfeld
ca9edf1c63 wintun: do not load dll in init()
This prevents linking to wintun.dll until it's actually needed, which
should improve startup time.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-01-07 14:49:44 +01:00
Josh Bleecher Snyder
347ce76bbc tun/tuntest: make genICMPv4 allocate less
It doesn't really matter, because it is only used in tests,
but it does remove some noise from pprof profiles.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-01-07 14:49:37 +01:00
Josh Bleecher Snyder
c4895658e6 device: avoid copying lock in tests
This doesn't cause any practical problems as it is,
but vet (rightly) flags this code as copying a mutex.
It is easy to fix, so do so.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2020-12-08 14:25:10 -08:00
Josh Bleecher Snyder
d3ff2d6b62 device: clear pointers when returning elems to pools
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2020-12-08 14:25:02 -08:00
Josh Bleecher Snyder
01d3aaa7f4 device: use labeled for loop instead of goto
Minor code cleanup; no functional changes.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2020-12-08 14:24:20 -08:00
Jason A. Donenfeld
b6303091fc memmod: fix import loading function usage
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-27 13:13:45 +01:00
Simon Rozman
c9fabbd5bf wintun: log when reboot is suggested by Windows
Which really shouldn't happen. But it is a useful information for
troubleshooting.

Signed-off-by: Simon Rozman <simon@rozman.si>
2020-11-25 13:58:11 +01:00
Simon Rozman
4cc7a7a455 wintun: keep original error when Wintun session start fails
Signed-off-by: Simon Rozman <simon@rozman.si>
2020-11-25 13:57:05 +01:00
Jason A. Donenfeld
da19db415a version: bump snapshot 2020-11-18 14:24:17 +01:00
Jason A. Donenfeld
52c834c446 mod: bump
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-18 14:24:00 +01:00
Haichao Liu
913f68ce38 device: add write queue mutex for peer
fix panic: send on closed channel when remove peer

Signed-off-by: Haichao Liu <liuhaichao@bytedance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-18 14:22:15 +01:00
Jason A. Donenfeld
60b3766b89 wintun: load from filesystem by default
We let people loading this from resources opt in via:

    go build -tags load_wintun_from_rsrc

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-11 18:51:44 +01:00
Jason A. Donenfeld
82128c47d9 global: switch to using %w instead of %v for Errorf
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 21:56:32 +01:00
Jason A. Donenfeld
c192b2eeec mod: update deps
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:22:18 +01:00
Simon Rozman
a3b231b31e wintun: ring management moved to wintun.dll
Signed-off-by: Simon Rozman <simon@rozman.si>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:20:49 +01:00
Simon Rozman
65e03a9182 wintun: load wintun.dll from RCDATA resource
Signed-off-by: Simon Rozman <simon@rozman.si>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:20:49 +01:00
Simon Rozman
3e08b8aee0 wintun: migrate to wintun.dll API
Rather than having every application using Wintun driver reinvent the
wheel, the Wintun device/adapter/interface management has been moved
from wireguard-go to wintun.dll deployed with Wintun itself.

Signed-off-by: Simon Rozman <simon@rozman.si>
2020-11-07 12:46:35 +01:00
Jason A. Donenfeld
5ca1218a5c device: format a few things
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-06 18:01:27 +01:00
Tobias Klauser
3b490f30aa tun: use SockaddrCtl from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.Connect with unix.SockaddrCtl instead.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
e6b7c4eef3 tun: use Ioctl{Get,Set}IfreqMTU from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.Ioctl{Get,Set}IfreqMTU to get and set an interface's MTU.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
8ae09213a7 tun: use IoctlCtlInfo from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.IoctlCtlInfo to get the kernel control info.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
36dc8b6994 tun: use GetsockoptString in (*NativeTun).Name on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Instead, use the
existing unix.GetsockoptString wrapper to get the interface name.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
2057f19a61 go.mod: bump golang.org/x/sys to latest version
This adds the fixes for golang/go#41868 which are needed to build
wireguard without direct syscalls on macOS.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Brad Fitzpatrick
58a8f05f50 tun/wintun/registry: fix Go 1.15 race/checkptr failure
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
[Jason: ran go mod tidy.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-21 18:26:10 +02:00
Frank Werner
0b54907a73 Makefile: Add test target
Signed-off-by: Frank Werner <mail@hb9fxq.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-20 12:38:18 +02:00
Riobard Zhan
2c143dce0f replay: minor API changes to more idiomatic Go
Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:46:00 +02:00
Riobard Zhan
22af3890f6 replay: clean up internals and better documentation
Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:46:00 +02:00
Jason A. Donenfeld
c8fe925020 device: remove global for roaming escape hatch
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:45:31 +02:00
Jason A. Donenfeld
0cfa3314ee replay: divide by bits-per-byte
Bits / Bytes-per-Word misses the step of also dividing by Bits-per-Byte,
which we need in order for this to make sense.

Reported-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-09-07 18:51:49 +02:00
Sina Siadat
bc3f505efa device: get free port when testing
Signed-off-by: Sina Siadat <siadat@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-31 16:18:53 +02:00
David Crawshaw
507f148e1c device: remove bindsocketshim.go
Both wireguard-windows and wireguard-android access Bind
directly for these methods now.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-14 23:18:53 -06:00
Brad Fitzpatrick
31b574ef99 device: remove some unnecessary unsafe
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-07-15 06:59:44 +10:00
Tobias Klauser
3c41141fb4 device: use RTMGRP_IPV4_ROUTE to specify multicast groups mask
Use the RTMGRP_IPV4_ROUTE const from x/sys/unix instead of using the
corresponding RTNLGRP_IPV4_ROUTE const to create the multicast groups
mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-13 17:58:10 -06:00
Dmytro Shynkevych
4369db522b device: wait for routines to stop before removing peers
Peers are currently removed after Device's goroutines are signaled to stop,
but without waiting for them to actually do so, which is racy.

For example, RoutineHandshake may be in Peer.SendKeepalive
when the corresponding peer is removed, which closes its nonce channel.
This causes a send on a closed channel, as observed in tailscale/tailscale#487.

This patch seems to be the correct synchronizing action:
Peer's goroutines are receivers and handle channel closure gracefully,
so Device's goroutines are the ones that should be fully stopped first.

Signed-Off-By: Dmytro Shynkevych <dmytro@tailscale.com>
2020-07-04 20:29:31 +10:00
David Crawshaw
b84f1d4db2 device: export Bind and remove socketfd shims for android
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:42:28 +10:00
David Crawshaw
dfb28757f7 ipc: add comment about socketDirectory linker override on android
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:41:19 +10:00
David Crawshaw
00bcd865e6 conn: add comments saying what uses these interfaces
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:40:59 +10:00
Jason A. Donenfeld
f28a6d244b device: do not include sticky sockets on android
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:50:20 -06:00
Jason A. Donenfeld
c403da6a39 conn: unbreak boundif on android
Another thing never tested ever.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:48:28 -06:00
Jason A. Donenfeld
d6de6f3ce6 conn: remove useless comment
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:37:01 -06:00
Jason A. Donenfeld
59e556f24e conn: fix windows situation with boundif
This was evidently never tested before committing.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:26:25 -06:00
Jason A. Donenfeld
31faf4c159 replay: account for fqcodel reordering
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-19 17:46:35 -06:00
Jason A. Donenfeld
99eb7896be device: rework padding calculation and don't shadow paddedSize
Reported-by: Jayakumar S <jayakumar82.s@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-18 15:43:22 -06:00
Dmytro Shynkevych
f60b3919be tai64n: make the test deterministic
In the presence of preemption, the current test may fail transiently.
This uses static test data instead to ensure consistent behavior.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
2020-05-06 16:01:48 +10:00
Jason A. Donenfeld
da9d300cf8 main: now that we're upstreamed, relax Linux warning
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:20:47 -06:00
Jason A. Donenfeld
59c9929714 README: specify go 1.13
Due to the use of the new errors module, we now require at least 1.13
instead of 1.12.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:08:52 -06:00
Jason A. Donenfeld
db0aa39b76 global: update header comments and modules
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:08:26 -06:00
David Crawshaw
bc77de2aca ipc: deduplicate some unix-specific code
Cleans up and splits out UAPIOpen to its own file.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
[zx2c4: changed const to var for socketDirectory]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:05:41 -06:00
David Crawshaw
c8596328e7 ipc: remove unnecessary error check
os.MkdirAll never returns an os.IsExist error.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 02:02:09 -06:00
Jason A. Donenfeld
28c4d04304 device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.

Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 01:56:48 -06:00
Simon Rozman
fdba6c183a wintun: make remaining HWID comparisons case insensitive
c85e4a410f introduced preliminary HWID
checking to speed up Wintun adapter enumeration. However, all HWID are
case insensitive by Windows convention.

Furthermore, a device might have multiple HWIDs. When DevInfo's
DeviceRegistryProperty(SPDRP_HARDWAREID) method returns []string, all
strings returned should be checked against given hardware ID.

This issue was discovered when researching Wintun and wireguard-go on
Windows 10 ARM64. The Wintun adapter was created using devcon.exe
utility with "wintun" hardware ID, causing wireguard-go fail to
enumerate the adapter properly.

Signed-off-by: Simon Rozman <simon@rozman.si>
2020-05-02 01:50:47 -06:00
Simon Rozman
250b9795f3 setupapi: extend struct size constant definitions for arm(64)
Signed-off-by: Simon Rozman <simon@rozman.si>
2020-05-02 01:50:47 -06:00
Avery Pennarun
d60857e1a7 device: add debug logs describing handshake rejection
Useful in testing when bad network stacks repeat or
batch large numbers of packets.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:50:47 -06:00
Brad Fitzpatrick
2fb0a712f0 tun: return a better error message if /dev/net/tun doesn't exist
It was just returning "no such file or directory" (the String of the
syscall.Errno returned by CreateTUN).

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:50:47 -06:00
David Anderson
f2c6faad44 device: return generic error from Ipc{Get,Set}Operation.
This makes uapi.go's public API conform to Go style in terms
of error types.

Signed-off-by: David Anderson <danderson@tailscale.com>
2020-05-02 01:49:47 -06:00
Avery Pennarun
c76b818466 tun: NetlinkListener: don't send EventDown before sending EventUp
This works around a startup race condition when competing with
HackListener, which is trying to do the same job. If HackListener
detects that the tundev is running while there is still an event in the
netlink queue that says it isn't running, then the device receives a
string of events like
	EventUp (HackListener)
	EventDown (NetlinkListener)
	EventUp (NetlinkListener)
Unfortunately, after the first EventDown, the device stops itself,
thinking incorrectly that the administrator has downed its tundev.

The device is ignoring the initial EventDown anyway, so just don't emit
it.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
de374bfb44 device: give handshake state a type
And unexport handshake constants.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
1a1c3d0968 tuntest: split out testing package
This code is useful to other packages writing tests.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
Brad Fitzpatrick
85a45a9651 tun: fix data race on name field
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:46:42 -06:00
Brad Fitzpatrick
abd287159e tun: remove unused isUp method
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
203554620d conn: introduce new package that splits out the Bind and Endpoint types
The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
Avery Pennarun
6aefb61355 wintun: split error message for create vs open namespace.
Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:44:58 -06:00
David Anderson
3dce460c88 device: add test to ensure Peer fields are safe for atomic access on 32-bit
Adds a test that will fail consistently on 32-bit platforms if the
struct ever changes again to violate the rules. This is likely not
needed because unaligned access crashes reliably, but this will reliably
fail even if tests accidentally pass due to lucky alignment.

Signed-Off-By: David Anderson <danderson@tailscale.com>
2020-05-02 01:44:58 -06:00
David Crawshaw
224bc9e60c rwcancel: no-op builds for windows and darwin
This lets us include the package on those platforms in a
followup commit where we split out a conn package from device.
It also lets us run `go test ./...` when developing on macOS.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-30 18:41:39 +11:00
David Crawshaw
9cd8909df2 ratelimiter: use a fake clock in tests and style cleanups
The existing test would occasionally flake out with:

	--- FAIL: TestRatelimiter (0.12s)
	    ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
	FAIL
	FAIL    golang.zx2c4.com/wireguard/ratelimiter  0.171s

The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.

While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.

Passes `go test -race -count=10000 ./ratelimiter`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-30 18:38:36 +11:00
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
Jason A. Donenfeld
f2ea85e9f9 version: bump snapshot 2019-10-12 22:34:10 +02:00
Jason A. Donenfeld
222f0f8000 Makefile: remove v prefix 2019-10-08 16:48:18 +02:00
Jason A. Donenfeld
1f146a5e7a wintun: expose version 2019-10-08 09:58:58 +02:00
Jason A. Donenfeld
f2501aa6c8 uapi: allow preventing creation of new peers when updating
This enables race-free updates for wg-dynamic and similar tools.

Suggested-by: Thomas Gschwantner <tharre3@gmail.com>
2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
cb8d01f58a mod: bump versions 2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
01f8ef4e84 winpipe: use x/sys/windows instead of syscall 2019-09-16 23:39:16 -06:00
Jason A. Donenfeld
70f6c42556 wintun: use correct length for security attributes 2019-09-16 19:38:33 -06:00
Jason A. Donenfeld
bb0b2514c0 tun: windows: unify error message format 2019-09-08 13:52:44 -05:00
Jason A. Donenfeld
7c97fdb1e3 version: bump snapshot 2019-09-08 10:56:55 -05:00
Jason A. Donenfeld
84b5a4d83d main: simplify warnings 2019-09-08 10:56:00 -05:00
Jason A. Donenfeld
4cd06c0925 tun: openbsd: check for interface already being up
In some cases, we operate on an already-up interface, or the user brings
up the interface before we start monitoring. For those situations, we
should first check if the interface is already up.

This still technically races between the initial check and the start of
the route loop, but fixing that is a bit ugly and probably not worth it
at the moment.

Reported-by: Theo Buehler <tb@theobuehler.org>
2019-09-07 00:13:23 -05:00
Jason A. Donenfeld
d12eb91f9a namespaceapi: AddSIDToBoundaryDescriptor modifies the handle 2019-09-05 21:48:21 -06:00
Jason A. Donenfeld
73d3bd9cd5 wintun: take mutex first always
This prevents an ABA deadlock with setupapi's internal locks.
2019-09-01 21:32:28 -06:00
Jason A. Donenfeld
f3dba4c194 wintun: consider abandoned mutexes as released 2019-09-01 21:25:47 -06:00
Jason A. Donenfeld
7937840f96 ipc: windows: use protected prefix 2019-08-31 07:48:42 -06:00
Jason A. Donenfeld
e4b957183c winpipe: enforce ownership of client connection 2019-08-30 13:21:47 -06:00
Jason A. Donenfeld
950ca2ba8c wintun: put mutex into private namespace 2019-08-30 11:03:21 -06:00
Jason A. Donenfeld
df2bf34373 namespaceapi: fix mistake 2019-08-30 09:59:36 -06:00
Simon Rozman
a12b765784 namespaceapi: initial version
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
14df9c3e75 wintun: take mutex so that deletion uses the right name 2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
353f0956bc wintun: move ring constants into module 2019-08-29 13:22:17 -06:00
Jason A. Donenfeld
fa7763c268 wintun: delete all interfaces is not used anymore 2019-08-29 12:22:15 -06:00
Jason A. Donenfeld
d94bae8348 wintun: Wintun->Interface 2019-08-29 12:20:40 -06:00
Jason A. Donenfeld
7689d09336 wintun: keep reference to pool in wintun object 2019-08-29 12:13:16 -06:00
Simon Rozman
69c26dc258 wintun: introduce adapter pools
This makes wintun package reusable for non-WireGuard applications.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-29 18:00:44 +02:00
Jason A. Donenfeld
e862131d3c wintun: simplify rename logic 2019-08-28 19:31:20 -06:00
Jason A. Donenfeld
da28a3e9f3 wintun: give better errors when ndis interface listing fails 2019-08-28 08:39:26 -06:00
Jason A. Donenfeld
3bf3322b2c wintun: also check for numbered suffix and friendly name 2019-08-28 08:08:07 -06:00
Simon Rozman
7305b4ce93 wintun: upgrade deleting all interfaces and make it reusable
DeleteAllInterfaces() didn't check if SPDRP_DEVICEDESC == "WireGuard
Tunnel". It deleted _all_ Wintun adapters, not just WireGuard's.

Furthermore, the DeleteAllInterfaces() was upgraded into a new function
called DeleteMatchingInterfaces() for selectively deletion. This will
be used by WireGuard to clean stale Wintun adapters.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-28 11:39:01 +02:00
Jason A. Donenfeld
26fb615b11 wintun: cleanup earlier 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
7fbb24afaa wintun: rename duplicate adapters instead of ourselves 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
d9008ac35c wintun: match suffix numbers 2019-08-26 14:46:43 -06:00
Jason A. Donenfeld
f8198c0428 device: getsockname on linux to determine port
It turns out Go isn't passing the pointer properly so we wound up with a
zero port every time.
2019-08-25 12:45:13 -06:00
Jason A. Donenfeld
0c540ad60e wintun: make description consistent across fields 2019-08-24 12:29:17 +02:00
Jason A. Donenfeld
3cedc22d7b wintun: try multiple names until one isn't a duplicate 2019-08-22 08:52:59 +02:00
Jason A. Donenfeld
68fea631d8 wintun: use nci.dll directly instead of buggy netshell 2019-08-21 09:16:12 +02:00
Jason A. Donenfeld
ef23100a4f wintun: set friendly a bit better
This is still wrong, but NETSETUPPKEY_Driver_FriendlyName seems a bit
tricky to use.
2019-08-20 16:06:55 +02:00
Jason A. Donenfeld
eb786cd7c1 wintun: also set friendly name after setting interface name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
333de75370 wintun: defer requires unique variable 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
d20459dc69 wintun: set adapter description name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
01786286c1 tun: windows: don't spin unless we really need it 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
b16dba47a7 version: bump snapshot 2019-08-05 19:29:12 +02:00
Jason A. Donenfeld
4be9630ddc device: drop lock before expiring keys 2019-08-05 17:46:34 +02:00
Jason A. Donenfeld
4e3018a967 uapi: skip peers with invalid keys 2019-08-05 16:57:41 +02:00
Jason A. Donenfeld
b4010123f7 tun: windows: spin for only a millisecond/80
Performance stays the same as before.
2019-08-03 19:11:21 +02:00
Simon Rozman
1ff37e2b07 wintun: merge opening device registry key
This also introduces waiting for key to appear on initial access.

See if this resolves the issue caused by HDD power-up delay resulting in
failure to create the adapter.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 16:08:49 +02:00
Simon Rozman
f5e54932e6 wintun: simplify checking reboot requirement
We never checked checkReboot() reported error anyway.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 16:08:49 +02:00
Simon Rozman
73698066d1 wintun: refactor err == nil error checking
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 15:18:58 +02:00
Jason A. Donenfeld
05ece4d167 wintun: handle error for deadgwdetect 2019-08-02 14:37:09 +02:00
Jason A. Donenfeld
6d78f89557 tun: darwin: do not attempt to close tun.event twice
Previously it was possible for this to race. It turns out we really
don't need to set anything to -1 anyway.
2019-08-02 12:24:17 +02:00
Jason A. Donenfeld
a2249449d6 wintun: get interface path properly with cfgmgr 2019-07-23 14:58:46 +02:00
Jason A. Donenfeld
eeeac287ef tun: windows: style 2019-07-23 11:45:48 +02:00
Jason A. Donenfeld
b5a7cbf069 wintun: simplify resolution of dev node 2019-07-23 11:45:13 +02:00
Jason A. Donenfeld
50cd522cb0 wintun: enable sharing of pnp node 2019-07-22 17:01:27 +02:00
Jason A. Donenfeld
5ba866a5c8 tun: windows: close event handle on shutdown 2019-07-22 09:37:20 +02:00
Jason A. Donenfeld
2f101fedec ipc: windows: match SDDL of WDK and make monkeyable 2019-07-19 15:34:26 +02:00
Jason A. Donenfeld
3341e2d444 tun: windows: get rid of retry logic
Things work fine on Windows 8.
2019-07-19 14:01:34 +02:00
Jason A. Donenfeld
1b550f6583 tun: windows: use specific IOCTL code 2019-07-19 08:30:19 +02:00
Jason A. Donenfeld
7bc0e11831 device: do not crash on nil'd bind in windows binding 2019-07-18 19:34:45 +02:00
Jason A. Donenfeld
31ff9c02fe tun: windows: open file at startup time 2019-07-18 19:27:27 +02:00
Jason A. Donenfeld
1e39c33ab1 tun: windows: silently drop packet when ring is full 2019-07-18 15:48:34 +02:00
Jason A. Donenfeld
6c50fedd8e tun: windows: switch to NDIS device object 2019-07-18 12:26:57 +02:00
Jason A. Donenfeld
298d759f3e wintun: calculate path of NDIS device object symbolic link 2019-07-18 10:25:20 +02:00
Michael Zeltner
4d5819183e tun: openbsd: don't change MTU when it's already the expected size
Allows for running wireguard-go as non-root user.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2019-07-18 10:25:20 +02:00
Jason A. Donenfeld
9ea9a92117 tun: windows: spin for a bit before falling back to event object 2019-07-18 10:25:20 +02:00
Simon Rozman
2e24e7dcae tun: windows: implement ring buffers
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-07-17 14:32:13 +02:00
Jason A. Donenfeld
a961aacc9f device: immediately rekey all peers after changing device private key
Reported-by: Derrick Pallas <derrick@pallas.us>
2019-07-11 17:37:35 +02:00
Jason A. Donenfeld
b0cf53b078 README: update windows info 2019-07-08 14:52:49 +02:00
Jason A. Donenfeld
5c3d333f10 tun: windows: registration of write buffer no longer required 2019-07-05 14:17:48 +02:00
Jason A. Donenfeld
d8448f8a02 tun: windows: decrease alignment to 4 2019-07-05 07:53:19 +02:00
Jason A. Donenfeld
13abbdf14b tun: windows: delay initial write
Otherwise we provoke Wintun 0.3.
2019-07-04 22:41:42 +02:00
Jason A. Donenfeld
f361e59001 device: receive: uniform message for source address check 2019-07-01 15:24:50 +02:00
Jason A. Donenfeld
b844f1b3cc tun: windows: packetNum is unused 2019-07-01 15:23:44 +02:00
Jason A. Donenfeld
dd8817f50e device: receive: simplify flush loop 2019-07-01 15:23:24 +02:00
Jason A. Donenfeld
5e6eff81b6 tun: windows: inform wintun of maximum buffer length for writes 2019-06-26 13:27:48 +02:00
Jason A. Donenfeld
c69d026649 tun: windows: never retry open on Windows 10 2019-06-18 17:51:29 +02:00
Matt Layher
1f48971a80 tun: remove TUN prefix from types to reduce stutter elsewhere
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-14 18:35:57 +02:00
Jason A. Donenfeld
3371f8dac6 device: update transfer counters correctly
The rule is to always update them to the full packet size minus UDP/IP
encapsulation for all authenticated packet types.
2019-06-11 18:13:52 +02:00
Jason A. Donenfeld
41fdbf0971 wintun: increase registry timeout 2019-06-11 00:33:07 +02:00
Jason A. Donenfeld
03eee4a778 wintun: add helper for cleaning up 2019-06-10 11:34:59 +02:00
Jason A. Donenfeld
700860f8e6 wintun: simplify error matching and remove dumb comments 2019-06-10 11:10:49 +02:00
Jason A. Donenfeld
a304f69e0d wintun: fix comments and remove hwnd param
This now looks more idiomatic.
2019-06-10 11:03:36 +02:00
Simon Rozman
baafe92888 setupapi: add SetDeviceRegistryPropertyString description
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-10 10:43:04 +02:00
Simon Rozman
a1a97d1e41 setupapi: unify ERROR_INSUFFICIENT_BUFFER handling
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-10 10:43:03 +02:00
Jason A. Donenfeld
e924280baa wintun: allow controlling GUID 2019-06-10 10:43:02 +02:00
Jason A. Donenfeld
bb3f1932fa setupapi: add DeviceInstanceID() 2019-06-10 10:43:01 +02:00
Jason A. Donenfeld
eaf17becfa global: fixup TODO comment spacing 2019-06-06 23:00:15 +02:00
Jason A. Donenfeld
6d8b68c8f3 wintun: guid functions are upstream 2019-06-06 22:39:20 +02:00
Simon Rozman
c2ed133df8 wintun: simplify DeleteInterface method signature
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-06 08:58:26 +02:00
Jason A. Donenfeld
108c37a056 wintun: don't run HrRenameConnection in separate thread
It's very slow, but unfortunately we haven't a choice. NLA needs this to
have completed.
2019-06-05 13:09:20 +02:00
Simon Rozman
e4b0ef29a1 tun: windows: obsolete 256 packets per exchange buffer limitation
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-05 11:55:28 +02:00
Simon Rozman
625e445b22 setupapi, wintun: replace syscall with golang.org/x/sys/windows
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:54:56 +02:00
Simon Rozman
85b85e62e5 wintun: set DI_QUIETINSTALL flag for GUI-less device management
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:45:23 +02:00
Simon Rozman
014f736480 setupapi: define PropChangeParams struct
This structure is required for calling DIF_PROPERTYCHANGE installer
class.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:45:23 +02:00
Matt Layher
43a4589043 device: remove redundant return statements
More staticcheck fixes:

$ staticcheck ./... | grep S1023
device/noise-helpers.go:45:2: redundant return statement (S1023)
device/noise-helpers.go:54:2: redundant return statement (S1023)
device/noise-helpers.go:64:2: redundant return statement (S1023)

Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-04 13:01:52 +02:00
Matt Layher
8d76ac8cc4 device: use bytes.Equal for equality check, simplify assertEqual
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-04 13:01:52 +02:00
Matt Layher
18b6627f33 device, ratelimiter: replace uses of time.Now().Sub() with time.Since()
Simplification found by staticcheck:

$ staticcheck ./... | grep S1012
device/cookie.go:90:5: should use time.Since instead of time.Now().Sub (S1012)
device/cookie.go:127:5: should use time.Since instead of time.Now().Sub (S1012)
device/cookie.go:242:5: should use time.Since instead of time.Now().Sub (S1012)
device/noise-protocol.go:304:13: should use time.Since instead of time.Now().Sub (S1012)
device/receive.go:82:46: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:132:5: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:139:5: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:235:59: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:393:9: should use time.Since instead of time.Now().Sub (S1012)
ratelimiter/ratelimiter.go:79:10: should use time.Since instead of time.Now().Sub (S1012)
ratelimiter/ratelimiter.go:87:10: should use time.Since instead of time.Now().Sub (S1012)

Change applied using:

$ find . -type f -name "*.go" -exec sed -i "s/Now().Sub(/Since(/g" {} \;

Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-03 22:15:41 +02:00
Matt Layher
80ef2a42e6 ipc/winpipe: go fmt
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-03 22:15:36 +02:00
Jason A. Donenfeld
da61947ec3 tun: windows: mitigate infinite loop in Flush()
It's possible that for whatever reason, we keep returning EOF, resulting
in repeated close/open/write operations, except with empty packets.
2019-05-31 16:55:03 +02:00
Jason A. Donenfeld
d9f995209c device: add SendKeepalivesToPeersWithCurrentKeypair for handover 2019-05-30 15:16:16 +02:00
Jason A. Donenfeld
d0ab883ada tai64n: account for whitening in test 2019-05-29 18:44:53 +02:00
Matt Layher
32912dc778 device, tun: rearrange code and fix device tests
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-05-29 18:34:55 +02:00
Jason A. Donenfeld
d4034e5f8a wintun: remove extra / 2019-05-26 02:20:01 +02:00
Jason A. Donenfeld
fbcd995ec1 device: darwin actually doesn't need bound interfaces 2019-05-25 18:10:52 +02:00
Jason A. Donenfeld
e7e286ba6c device: make initiations per second match kernel implementation 2019-05-25 02:07:18 +02:00
Jason A. Donenfeld
f70546bc2e device: timers: add jitter on ack failure reinitiation 2019-05-24 13:48:25 +02:00
Simon Rozman
6a0a3a5406 wintun: revise GetInterface()
- Make foreign interface found error numeric to ease condition
  detection.
- Update GetInterface() documentation.
- Make tun.CreateTUN() quit when foreign interface found before
  attempting to create a Wintun interface with a duplicate name.
  Creation is futile.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-24 09:29:57 +02:00
Jason A. Donenfeld
8fdcf5ee30 wintun: never return nil, nil 2019-05-23 15:25:53 +02:00
Jason A. Donenfeld
a74a29bc93 ipc: use simplified fork of winio 2019-05-23 15:16:02 +02:00
Simon Rozman
dc9bbec9db setupapi: trim "Get" from getters
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-22 19:31:52 +02:00
Jason A. Donenfeld
a6dbe4f475 wintun: don't try to flush interface, but rather delete 2019-05-17 16:06:02 +02:00
Jason A. Donenfeld
c718f3940d device: fail to give bind if it doesn't exist 2019-05-17 15:35:20 +02:00
Jason A. Donenfeld
95c70b8032 wintun: make certain methods private 2019-05-17 15:01:08 +02:00
Jason A. Donenfeld
583ebe99f1 version: bump snapshot 2019-05-17 10:28:04 +02:00
Jason A. Donenfeld
a6dd282600 makefile: do not show warning on non-linux 2019-05-17 10:27:51 +02:00
Simon Rozman
7d5f5bcc0d wintun: change acronyms to uppercase
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-17 10:22:34 +02:00
Jason A. Donenfeld
3bf41b06ae global: regroup all imports 2019-05-14 09:09:52 +02:00
Jason A. Donenfeld
3147f00089 wintun: registry: fix nits 2019-05-11 17:25:48 +02:00
Simon Rozman
6c1b66802f wintun: registry: revise value reading
- Make getStringValueRetry() reusable for reading any value type. This
  merges code from GetIntegerValueWait().
- expandString() >> toString() and extend to support REG_MULTI_SZ
  (to return first value of REG_MULTI_SZ). Furthermore, doing our own
  UTF-16 to UTF-8 conversion works around a bug in windows/registry's
  GetStringValue() non-zero terminated string handling.
- Provide toInteger() analogous to toString()
- GetStringValueWait() tolerates and reads REG_MULTI_SZ too now. It
  returns REG_MULTI_SZ[0], making GetFirstStringValueWait() redundant.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-11 17:14:37 +02:00
Jason A. Donenfeld
5669ed326f wintun: call HrRenameConnection in another thread 2019-05-10 21:31:37 +02:00
Jason A. Donenfeld
2d847a38a2 wintun: add LUID accessor 2019-05-10 21:30:23 +02:00
Jason A. Donenfeld
7a8553aef0 wintun: enumerate faster by using COMPATDRIVER instead of CLASSDRIVER 2019-05-10 20:30:59 +02:00
Jason A. Donenfeld
a6045ac042 wintun: destroy devinfolist after usage 2019-05-10 20:19:11 +02:00
Simon Rozman
1c92b48415 wintun: registry: replace REG_NOTIFY with NOTIFY
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 18:09:20 +02:00
Jason A. Donenfeld
c267965bf8 wintun: IpConfig is a MULTI_SZ, and fix errors 2019-05-10 18:06:49 +02:00
Jason A. Donenfeld
1bf1dadf15 wintun: poll for device key
It's actually pretty hard to guess where it is.
2019-05-10 17:34:03 +02:00
Jason A. Donenfeld
f9dcfccbb7 wintun: fix scope of error object 2019-05-10 16:59:24 +02:00
Simon Rozman
7e962a9932 wintun: wait for interface registry key on device creation
By using RegNotifyChangeKeyValue(). Also disable dead gateway detection.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 16:43:58 +02:00
Jason A. Donenfeld
586112b5d7 conn: remove scope when sanity checking IP address format 2019-05-09 15:42:35 +02:00
Simon Rozman
dcb8f1aa6b wintun: fix GUID leading zero padding
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-09 12:16:21 +02:00
Jason A. Donenfeld
b16b0e4cf7 mod: update deps 2019-05-03 09:37:29 +02:00
Jason A. Donenfeld
81ca08f1b3 setupapi: safer aliasing of slice types 2019-05-03 09:34:00 +02:00
Jason A. Donenfeld
2e988467c2 wintun: work around GetInterface staleness bug 2019-05-03 00:42:36 +02:00
Jason A. Donenfeld
46dbf54040 wintun: don't retry when not creating
The only time we're trying to counteract the race condition is when
we're creating a driver. When we're simply looking up all drivers, it
doesn't make sense to retry.
2019-05-02 23:53:15 +02:00
Jason A. Donenfeld
247e14693a wintun: try harder to open registry key
This sucks. Can we please find a deterministic way of doing this
instead?
2019-04-29 14:00:49 +02:00
Jason A. Donenfeld
3945a299ff go.mod: use vendored winio 2019-04-29 08:09:38 +02:00
Jason A. Donenfeld
bb42ec7d18 tun: freebsd: work around numerous kernel panics on shutdown
There are numerous race conditions. But even this will crash it:

while true; do ifconfig tun0 create; ifconfig tun0 destroy; done

It seems like LLv6 is related, which we're not using anyway, so
explicitly disable it on the interface.
2019-04-23 18:00:23 +09:00
Simon Rozman
f1dc167901 setupapi: Fix struct size mismatches
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-19 10:08:11 +02:00
Jason A. Donenfeld
c7a26dfef3 setupapi: actually fix padding by rounding up to sizeof(void*) 2019-04-19 10:19:00 +09:00
Jason A. Donenfeld
d024393335 tun: darwin: write routeSocket variable in helper
Otherwise the race detector "complains".
2019-04-19 07:53:19 +09:00
Jason A. Donenfeld
d9078fe772 main: revise warnings 2019-04-19 07:48:09 +09:00
Jason A. Donenfeld
d3dd991e4e device: send: check packet length before freeing element 2019-04-18 23:23:03 +09:00
Simon Rozman
5811447b38 setupapi: Revise DrvInfoDetailData struct size calculation
Go adds trailing padding to DrvInfoDetailData struct in GOARCH=386 which
confuses SetupAPI expecting exactly sizeof(SP_DRVINFO_DETAIL_DATA).

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-18 10:39:22 +02:00
Jason A. Donenfeld
e0a8c22aa6 windows: use proper constants from updated x/sys 2019-04-13 02:02:02 +02:00
Jason A. Donenfeld
0b77bf78cd conn: linux: RTA_MARK has moved to x/sys 2019-04-13 02:01:20 +02:00
Simon Rozman
ef5f3ad80a tun: windows: Adopt new error codes returned by Wintun
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:38:11 +02:00
Simon Rozman
a291fdd746 tun: windows: do not sleep after OPERATION_ABORTED on write
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:37:04 +02:00
Jason A. Donenfeld
d50e390904 main_windows: use proper version constant 2019-04-09 10:45:40 +02:00
Jason A. Donenfeld
18fa270472 version: put version in right place 2019-04-09 10:39:48 +02:00
Jason A. Donenfeld
f156a53ff4 version: bump snapshot 2019-04-09 07:37:22 +02:00
Jason A. Donenfeld
e680008700 tun: windows: do not sleep after OPERATION_ABORTED 2019-04-09 07:36:03 +02:00
Simon Rozman
767c86f8cb tun: windows: Retry R/W on ERROR_OPERATION_ABORTED
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-04 09:20:18 +02:00
Simon Rozman
421c1f9143 tun: windows: Attempt to reopen handle on all errors
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-03 05:41:38 +02:00
Jason A. Donenfeld
ac25702eaf wintun: rename device using undocumented API that netsh.exe uses 2019-04-01 12:04:44 +02:00
Jason A. Donenfeld
92f8474832 wintun: add more retry loops 2019-04-01 09:07:43 +02:00
Jason A. Donenfeld
2e0ed4614a tun: windows: cancel ongoing reads on closing and delete after close
This reverts commit 52ec440d79 and adds
some spice.
2019-03-26 16:14:32 +01:00
Jason A. Donenfeld
2fa80c0cb7 wintun: query for NetCfgInstanceId several times 2019-03-22 16:48:40 -06:00
Jason A. Donenfeld
52ec440d79 tun: windows: delete interface before deleting file handles 2019-03-22 16:45:58 -06:00
Simon Rozman
2faf2dcf90 tun: windows: Make adapter rename asynchronous
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-22 16:36:30 +01:00
Simon Rozman
41c30a7279 tun: windows: Adapter devices renamed to WINTUN<LUID Index>
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-22 15:29:14 +01:00
Simon Rozman
4b1db1d39b tun: windows: Increase unavailable adapter timeout to 30sec
5 seconds was too short when debugging.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-22 13:52:51 +01:00
Simon Rozman
a80db5e65e tun: windows: Make writing persistent too
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-22 13:52:51 +01:00
Simon Rozman
9748a52073 tun: windows: Fix paused adapter test
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-22 13:52:51 +01:00
Jason A. Donenfeld
317d716d66 tun: windows: just open two file handles 2019-03-21 15:20:09 -06:00
Jason A. Donenfeld
6440f010ee receive: implement flush semantics 2019-03-21 14:45:41 -06:00
Jason A. Donenfeld
49ea0c9b1a tun: windows: add dummy overlapped events back
These seem basically wrong to me, but we get crashes without them.
2019-03-21 02:29:09 -06:00
Jason A. Donenfeld
ca59b60aa7 tun: windows: use new constants in sys 2019-03-20 23:42:30 -06:00
Jason A. Donenfeld
c050c6e60f uapi: remove unhelpful log messages 2019-03-20 23:40:20 -06:00
Simon Rozman
91b4e909bb wintun: Use native Win32 API for I/O
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-21 00:56:45 +01:00
Jason A. Donenfeld
2c51d6af48 uapi: report endpoint error 2019-03-19 00:34:04 -06:00
Jason A. Donenfeld
03f2e2614a tun: windows: wintun does iocp 2019-03-18 02:42:45 -06:00
Jason A. Donenfeld
b0e0ab308d tun: windows: temporary hack for forcing MTU 2019-03-13 02:52:32 -06:00
Jason A. Donenfeld
66fb5caf02 wintun: Poll more often 2019-03-10 03:47:54 +01:00
Jason A. Donenfeld
3dd9a0535f uapi: make ipcerror conform to interface 2019-03-10 02:49:44 +01:00
Simon Rozman
c2a2b8d739 wintun: Make errors more descriptive
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-08 10:03:57 +01:00
Simon Rozman
70449f1a97 wintun: Return correct reboot-req flag on CreateInterface() error too
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-08 10:03:57 +01:00
Simon Rozman
33c3528430 wintun: Fix double-quoted strings escaping on output
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-08 10:03:57 +01:00
Simon Rozman
30ab07e354 wintun: Introduce SetupAPI enumerator and machineName consts
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-08 10:03:57 +01:00
Odd Stranne
a6d5ef82f4 Windows: Apply strict security descriptor on pipe server
Signed-off-by: Odd Stranne <odd@mullvad.net>
2019-03-08 10:03:56 +01:00
Jason A. Donenfeld
5c7cc256e3 uapi: windows: work out pipe semantics
Pipes can be arranged like this, so that's fine. We also apply a strict
SDDL that can't be inherited and only gives access to local system.

Developed-with: Odd Stranne <odd@mullvad.net>
2019-03-08 01:40:54 +01:00
Simon Rozman
368dea72fe wintun: Cleanup
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-07 21:12:20 +01:00
Simon Rozman
9b22255cad wintun: Refactor network registry key name generation
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-07 21:12:20 +01:00
Simon Rozman
11f5780250 wintun: Revise interface creation wait
DIF_INSTALLDEVICE returns almost immediately, while the device
installation continues in the background. It might take a while, before
all registry keys and values are populated.

Previously, wireguard-go waited for HKLM\SYSTEM\CurrentControlSet\
Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\<id> registry key
only.

Followed by a SetInterfaceName() method of Wintun struct which tried to
access HKLM\SYSTEM\CurrentControlSet\Control\Network\
{4D36E972-E325-11CE-BFC1-08002BE10318}\<id>\Connection registry key
might not be available yet.

This commit loops until both registry keys are available before
returning from CreateInterface() function.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-07 21:12:20 +01:00
Jason A. Donenfeld
26af6c4651 receive: squelch tear down error 2019-03-07 02:03:48 +01:00
Jason A. Donenfeld
92f72f5aa6 tun: linux: work out netpoll trick 2019-03-07 01:51:41 +01:00
Simon Rozman
1fdf7b19a3 wintun: Resolve some of golint warnings
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-04 16:37:11 +01:00
Simon Rozman
a1aabb21ae Elaborate the failing step when forwarding errors on return
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-04 16:37:11 +01:00
Simon Rozman
9041d38e2d Simplify reading NetCfgInstanceId from registry
As querying non-existing registry value and reading non-existing
registry string value both return ERROR_FILE_NOT_FOUND, we can
use later only.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-04 16:37:11 +01:00
Simon Rozman
cddfd9a0d8 Unify interface-specific network registry key open
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-03-04 16:37:11 +01:00
Jason A. Donenfeld
68f0721c6a tun: import mobile particularities 2019-03-04 16:37:11 +01:00
Jason A. Donenfeld
b8e85267cf boundif: introduce API for socket binding 2019-03-04 16:37:11 +01:00
Jason A. Donenfeld
69f0fe67b6 global: begin modularization 2019-03-03 05:00:40 +01:00
Jason A. Donenfeld
d435be35ca tun: windows: expose GUID 2019-03-01 00:11:12 +01:00
Jason A. Donenfeld
967d1a0f3d tun: allow special methods in NativeTun 2019-03-01 00:05:57 +01:00
Jason A. Donenfeld
88ff67fb6f tun: linux: netpoll is broken for tun's epoll
So this mostly reverts the switch to Sysconn for Linux.

Issue: https://github.com/golang/go/issues/30426
2019-02-27 04:38:26 +01:00
Jason A. Donenfeld
971be13e77 tun: linux: netlink sock needs cleaning up but file will be gc'd 2019-02-27 04:11:41 +01:00
Jason A. Donenfeld
366cbd11a4 tun: use netpoll instead of rwcancel
The new sysconn function of Go 1.12 makes this possible:

package main

import "log"
import "os"
import "unsafe"
import "time"
import "syscall"
import "sync"
import "golang.org/x/sys/unix"

func main() {
	fd, err := os.OpenFile("/dev/net/tun", os.O_RDWR, 0)
	if err != nil {
		log.Fatal(err)
	}

	var ifr [unix.IFNAMSIZ + 64]byte
	copy(ifr[:], []byte("cheese"))
	*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = unix.IFF_TUN

	var errno syscall.Errno
	s, _ := fd.SyscallConn()
	s.Control(func(fd uintptr) {
		_, _, errno = unix.Syscall(
			unix.SYS_IOCTL,
			fd,
			uintptr(unix.TUNSETIFF),
			uintptr(unsafe.Pointer(&ifr[0])),
		)
	})
	if errno != 0 {
		log.Fatal(errno)
	}

	b := [4]byte{}
	wait := sync.WaitGroup{}
	wait.Add(1)
	go func() {
		_, err := fd.Read(b[:])
		log.Print("Read errored: ", err)
		wait.Done()
	}()
	time.Sleep(time.Second)
	log.Print("Closing")
	err = fd.Close()
	if err != nil {
		log.Print("Close errored: " , err)
	}
	wait.Wait()
	log.Print("Exiting")
}
2019-02-27 01:52:55 +01:00
Jason A. Donenfeld
ab0f442daf tun: use sysconn instead of .Fd with Go 1.12 2019-02-27 01:34:11 +01:00
Jason A. Donenfeld
66524c1f7e Rearrange imports 2019-02-22 20:59:43 +01:00
Jason A. Donenfeld
6e4460ae65 device: send persistent keepalive when bringing up device
Reported-by: Marcelo Bello
2019-02-22 19:33:28 +01:00
Simon Rozman
d002eff155 wintun: Read/write packet size from/to exchange buffer directly
Driver <-> user-space communication is local and using native endian.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-22 16:16:14 +01:00
Simon Rozman
e06a8f8f9f wintun: Make two-step slicing a one step
Stop relying to Go compiler optimizations and calculate the end offset
directly.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-22 16:11:33 +01:00
Simon Rozman
ac4944a708 wintun: Write exchange buffer increased back to 1MiB
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 20:13:33 +01:00
Simon Rozman
2491f9d454 wintun: Migrate from unsafe buffer handling to encoding/binary
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 20:10:24 +01:00
Simon Rozman
8091c6474a wintun: Adopt new packet data alignment
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 19:56:10 +01:00
Simon Rozman
040da43889 wintun: Cleanup
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 18:38:18 +01:00
Simon Rozman
b7025b5627 wintun: Add TUN device locking
In case reading from TUN device detected TUN device was closed, it
closed the file handle and set tunFile to nil. The tunFile is
automatically reopened on retry, but... If another packet comes in the
WireGuard calls Write() method. With tunFile set to nil, this will
cause access violation.

Therefore, locking was introduced.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 13:12:08 +01:00
Simon Rozman
6581cfb885 wintun: Move exchange buffer in separate struct on heap
This allows buffer alignment and keeps it together with its meta-data.

Furthermore, the write buffer has been reduced - as long as we flush
after _every_ write, we don't need a 1MiB write buffer.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-20 11:41:37 +01:00
Simon Rozman
4863089120 wintun: Switch to dynamic packet sizes
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-19 18:50:42 +01:00
Jason A. Donenfeld
42c6d0e261 Change package path 2019-02-18 05:11:39 +01:00
Jason A. Donenfeld
f7170e5de2 Bump dependencies for ARM ChaCha20 2019-02-14 10:59:54 +01:00
Simon Rozman
b719a09a26 wintun: Auto-calculate TUN exchange buffer size
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-08 15:21:24 +01:00
Simon Rozman
f05f52637f wintun: Simplify Read method()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-08 14:31:05 +01:00
Simon Rozman
713477cfb1 wintun: Make constants private and adopt Go recommended case
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-08 08:55:23 +01:00
Simon Rozman
5981d5cacf wintun: Check for user close in read loop regardless the load
Do the WaitForSingleObject() always to provide high-load responsiveness.

Reorder events so TUN_SIGNAL_CLOSE has priority over
TUN_SIGNAL_DATA_AVAIL, to provide high-load responsiveness at all.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-08 08:48:35 +01:00
Simon Rozman
b13739ada2 wintun: Adjust tunRWQueue.left member to match Wintun driver
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-08 07:32:12 +01:00
Simon Rozman
c4988999ac setupapi: Merge _SP_DRVINFO_DETAIL_DATA and DrvInfoDetailData
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:50:43 +01:00
Simon Rozman
b662896cf4 setupapi: Merge SP_DRVINFO_DATA and DrvInfoData
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:50:43 +01:00
Simon Rozman
0525f6b112 setupapi: Rename SP_REMOVEDEVICE_PARAMS to RemoveDeviceParams
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:50:43 +01:00
Simon Rozman
9d830826c5 setupapi: Rename SP_CLASSINSTALL_HEADER to ClassInstallHeader
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:50:43 +01:00
Simon Rozman
bd963497da setupapi: Merge _SP_DEVINSTALL_PARAMS and DevInstallParams
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:50:30 +01:00
Simon Rozman
05d25fd1b7 setupapi: Merge _SP_DEVINFO_LIST_DETAIL_DATA and DevInfoListDetailData
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 23:49:50 +01:00
Simon Rozman
6d2729dccc setupapi: Rename SP_DEVINFO_DATA to DevInfoData
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 22:43:02 +01:00
Simon Rozman
d87cbeeb2f wintun: Detect if a foreign interface with the same name exists
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 22:02:51 +01:00
Simon Rozman
043b7e8013 wintun: Clean excessive setupapi.DevInfo.GetDeviceInfoListDetail() call
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 20:49:41 +01:00
Simon Rozman
ef48d4fa95 wintun: Explain rationale behind case-insensitive interface names
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 19:42:59 +01:00
Simon Rozman
f7276ed522 wintun: Implement TODO in TestSetupDiGetDeviceRegistryProperty()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-07 18:59:34 +01:00
Jason A. Donenfeld
c4b43e35a7 wintun: add FlushInterface stub 2019-02-07 18:24:28 +01:00
Jason A. Donenfeld
2efafecab5 main_windows: Get iface name from argument 2019-02-07 15:44:07 +01:00
Jason A. Donenfeld
fac1fbcd72 wintun: Compare values of GUID, not pointers, when removing 2019-02-07 04:49:15 +01:00
Jason A. Donenfeld
52aa00f3ba main_windows: Catch more exit events 2019-02-07 04:42:35 +01:00
Jason A. Donenfeld
ea59177f1c wintun: Introduce new package for obscuring Windows bits 2019-02-07 04:39:59 +01:00
Jason A. Donenfeld
306d08e692 tun_windows: Style 2019-02-07 04:08:05 +01:00
Jason A. Donenfeld
3b7a4fa3ef setupapi: Lower case params 2019-02-07 03:46:31 +01:00
Jason A. Donenfeld
223685875f setupapi: Do not export the toGo/toWindows functions 2019-02-07 02:56:31 +01:00
Jason A. Donenfeld
652158ec3c setupapi: Pass pointers instead of values 2019-02-07 02:37:19 +01:00
Simon Rozman
cb2bc4b34c tun_windows: Introduce preliminary TUN interface creation
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-06 22:30:14 +01:00
Simon Rozman
46279ad0f9 tun_windows: Stop checking minimum size of received TUN packets
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-06 20:22:04 +01:00
Simon Rozman
73df1c0871 setupapi: Add DrvInfoDetailData.IsCompatible() to simplify HID detection
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-06 20:18:44 +01:00
Simon Rozman
069016bbc4 setupapi: Add SP_DRVINFO_DATA.IsNewer() method to simplify comparison
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-06 20:17:47 +01:00
Simon Rozman
3c29434a79 setupapi: Make toUTF16() public and add UTF16ToBuf() counterpart
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-06 20:15:40 +01:00
Jason A. Donenfeld
c599bf9497 Fix up errors and paths 2019-02-05 22:06:25 +09:00
Jason A. Donenfeld
f7f63765d1 conn: close ipv4 socket when ipv6 socket fails 2019-02-05 21:55:33 +09:00
Simon Rozman
3e8f2e3fa5 setupapi: Add support for driver info lists
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 16:29:17 +01:00
Simon Rozman
7b636380e5 setupapi: Move Go<>Windows struct marshaling to types_windows.go
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 14:03:28 +01:00
Simon Rozman
99a3b628e9 setupapi: Add support for SetupDi(Get|Set)DeviceRegistryProperty()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
e7ffce0d21 setupapi: Introduce DevInfo methods for cleaner code
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
35f72239ac Add support for setupapi.SetupDi(Get|Set)SelectedDevice()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
c15cbefc12 Reorder data-types and functions to match SetupAPI.h
Adding functions with non-consistent order made setupapi package a mess.
While we could reorder data-types and functions by alphabet - it would
make searching easier - it would put ...Get... and ...Set... functions
quite apart.

Therefore, the SetupAPI.h order was adopted.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
dd998ca86a Add support for setupapi.SetupDiCreateDeviceInfo()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
024a4916c2 Add support for setupapi.setupDiCreateDeviceInfoListEx()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
963be8e993 Stop accessing SetupDiGetDeviceInfoListDetail() output on error
The data returned by SetupDiGetDeviceInfoListDetail() is nil on error
which will cause the test to crash should the function fail.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
e821cdabd2 Unify certain variable names
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
38c7acd70f Simplify SetupDiEnumDeviceInfo() synopsis
The SetupDiEnumDeviceInfo() now returns a SP_DEVINFO_DATA rather than
taking it on input to fill it on return.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
20f1512b7c Change generic local variable names with meaningful replacements
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
348b4e9f7c Add support for setupapi.SetupDiClassGuidsFromNameEx()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
f81882ee8b Clean an unused constant
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
3e0e61dd26 Replace SetupDiClassNameFromGuid() with SetupDiClassNameFromGuidEx()
The former is only a subset of the later. To minimize future
maintenance, we'll provide support for extended version only.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
9635a0b3a6 Add support for setupapi.SetupDiClassNameFromGuid()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
90b6938ca0 Stop checking for valid handle in DevInfo.Close()
User should not have called or deferred the Close() method should
SetupDiGetClassDevsEx() return an error (and invalid handle). And even
if user does that, a SetupDiDestroyDeviceInfoList(INVALID_HANDLE_VALUE)
is harmless. It just returns ERROR_INVALID_HANDLE - we have a unit test
for this in TestSetupDiDestroyDeviceInfoList().

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
269944002f Add support for setupapi.SetupDiCallClassInstaller()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
a5a1ece32f Add support for setupapi.SetupDi(Get|Set)ClassInstallParams()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
f1d5db6547 Add support for setupapi.SetupDi(Get|Set)DeviceInstallParams()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
dce5192d86 Add support for setupapi.SetupDiOpenDevRegKey()
Furthermore setupapi.DevInfoData has been obsoleted.
SetupDiEnumDeviceInfo() fills existing SP_DEVINFO_DATA structure now.
As other functions of SetupAPI use SP_DEVINFO_DATA, converting it to
DevInfoData and back would hurt performance.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
955d8dfe04 Add support for setupapi.SetupDiEnumDeviceInfo()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
25e18d01e6 Update exported types and functions annotations
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
45959c116a Add support for setupapi.SetupDiGetDeviceInfoListDetail()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
d41bc015cc Finish support for setupapi.SetupDiGetClassDevsEx()
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Simon Rozman
31949136df Introduce SetupAPI - Windows device and driver management API
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
6f76edd045 Import windows scafolding 2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
3af9aa88a3 noise: store clamped key instead of raw key 2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
a5ca02d79a tai64n: whiten nano seconds
Avoid being too precise of a time oracle.
2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
2b7562abbb uapi: Simpler function signature 2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
89d2c5ed7a Extend structs rather than embed, when possible 2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
dff424baf8 Update copyright 2019-02-05 12:59:42 +01:00
Jason A. Donenfeld
6e61c369e8 Properly bubble up setsockopt error from closure 2018-12-25 22:56:36 +01:00
Jason A. Donenfeld
8fde8334dc version: bump snapshot 2018-12-22 17:34:23 +01:00
Jason A. Donenfeld
a8326ae753 Make error messages consistent 2018-12-19 00:35:53 +01:00
Jason A. Donenfeld
05cc0c8298 Freebsd is finally normal in sys/unix 2018-12-11 18:33:13 +01:00
Jason A. Donenfeld
c967f15e44 Separate out mark setting for Windows 2018-12-11 18:29:46 +01:00
Jason A. Donenfeld
5ace0fdfe2 Use upstream's xchacha20poly1305 2018-12-10 04:23:17 +01:00
Jason A. Donenfeld
849fa400e9 Update go x/ libraries
Android 9's Bionic disallows inotify_init with seccomp, so we want the
latest unix change, and while we're at it, we update the others too.

Reported-by: Berk D. Demir <bdd@mindcast.org>
Go CL: https://go-review.googlesource.com/c/sys/+/153318
Fixes: https://lists.zx2c4.com/pipermail/wireguard/2018-December/003642.html
2018-12-10 04:04:19 +01:00
Jason A. Donenfeld
651744561e tun: remove nonblock hack for linux
This is no longer necessary and actually breaks things

Reported-by: Chris Branch <cbranch@cloudflare.com>
2018-12-06 17:17:51 +01:00
Jason A. Donenfeld
4fd55daafe tai64n: use proper nanoseconds offset
The code before was obviously wrong.

Reported-by: Vlad Krasnov <vlad@cloudflare.com>
2018-11-08 03:58:01 +01:00
Jason A. Donenfeld
276bf973e8 Use darwin tun on ios 2018-11-06 16:24:35 +01:00
Jason A. Donenfeld
c37c4ece9e uapi: typo 2018-11-05 05:46:27 +01:00
Jason A. Donenfeld
b803276061 receive: make started status uniform 2018-11-01 19:54:25 +01:00
Jason A. Donenfeld
8be1fc9c00 send: do not unlock already freed object 2018-10-18 18:15:24 +02:00
136 changed files with 14977 additions and 6787 deletions

41
.github/workflows/build-if-tag.yml vendored Normal file
View file

@ -0,0 +1,41 @@
name: build-if-tag
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
env:
APP: amneziawg-go
jobs:
build:
runs-on: ubuntu-latest
name: build
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup metadata
uses: docker/metadata-action@v5
id: metadata
with:
images: amneziavpn/${{ env.APP }}
tags: type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.metadata.outputs.tags }}

5
.gitignore vendored
View file

@ -1,4 +1 @@
wireguard-go amneziawg-go
vendor
.gopath
ireallywantobuildon_linux.go

338
COPYING
View file

@ -1,338 +0,0 @@
GNU GENERAL PUBLIC LICENSE
Version 2, June 1991
Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The licenses for most software are designed to take away your
freedom to share and change it. By contrast, the GNU General Public
License is intended to guarantee your freedom to share and change free
software--to make sure the software is free for all its users. This
General Public License applies to most of the Free Software
Foundation's software and to any other program whose authors commit to
using it. (Some other Free Software Foundation software is covered by
the GNU Lesser General Public License instead.) You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
this service if you wish), that you receive source code or can get it
if you want it, that you can change the software or use pieces of it
in new free programs; and that you know you can do these things.
To protect your rights, we need to make restrictions that forbid
anyone to deny you these rights or to ask you to surrender the rights.
These restrictions translate to certain responsibilities for you if you
distribute copies of the software, or if you modify it.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must give the recipients all the rights that
you have. You must make sure that they, too, receive or can get the
source code. And you must show them these terms so they know their
rights.
We protect your rights with two steps: (1) copyright the software, and
(2) offer you this license which gives you legal permission to copy,
distribute and/or modify the software.
Also, for each author's protection and ours, we want to make certain
that everyone understands that there is no warranty for this free
software. If the software is modified by someone else and passed on, we
want its recipients to know that what they have is not the original, so
that any problems introduced by others will not reflect on the original
authors' reputations.
Finally, any free program is threatened constantly by software
patents. We wish to avoid the danger that redistributors of a free
program will individually obtain patent licenses, in effect making the
program proprietary. To prevent this, we have made it clear that any
patent must be licensed for everyone's free use or not licensed at all.
The precise terms and conditions for copying, distribution and
modification follow.
GNU GENERAL PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. This License applies to any program or other work which contains
a notice placed by the copyright holder saying it may be distributed
under the terms of this General Public License. The "Program", below,
refers to any such program or work, and a "work based on the Program"
means either the Program or any derivative work under copyright law:
that is to say, a work containing the Program or a portion of it,
either verbatim or with modifications and/or translated into another
language. (Hereinafter, translation is included without limitation in
the term "modification".) Each licensee is addressed as "you".
Activities other than copying, distribution and modification are not
covered by this License; they are outside its scope. The act of
running the Program is not restricted, and the output from the Program
is covered only if its contents constitute a work based on the
Program (independent of having been made by running the Program).
Whether that is true depends on what the Program does.
1. You may copy and distribute verbatim copies of the Program's
source code as you receive it, in any medium, provided that you
conspicuously and appropriately publish on each copy an appropriate
copyright notice and disclaimer of warranty; keep intact all the
notices that refer to this License and to the absence of any warranty;
and give any other recipients of the Program a copy of this License
along with the Program.
You may charge a fee for the physical act of transferring a copy, and
you may at your option offer warranty protection in exchange for a fee.
2. You may modify your copy or copies of the Program or any portion
of it, thus forming a work based on the Program, and copy and
distribute such modifications or work under the terms of Section 1
above, provided that you also meet all of these conditions:
a) You must cause the modified files to carry prominent notices
stating that you changed the files and the date of any change.
b) You must cause any work that you distribute or publish, that in
whole or in part contains or is derived from the Program or any
part thereof, to be licensed as a whole at no charge to all third
parties under the terms of this License.
c) If the modified program normally reads commands interactively
when run, you must cause it, when started running for such
interactive use in the most ordinary way, to print or display an
announcement including an appropriate copyright notice and a
notice that there is no warranty (or else, saying that you provide
a warranty) and that users may redistribute the program under
these conditions, and telling the user how to view a copy of this
License. (Exception: if the Program itself is interactive but
does not normally print such an announcement, your work based on
the Program is not required to print an announcement.)
These requirements apply to the modified work as a whole. If
identifiable sections of that work are not derived from the Program,
and can be reasonably considered independent and separate works in
themselves, then this License, and its terms, do not apply to those
sections when you distribute them as separate works. But when you
distribute the same sections as part of a whole which is a work based
on the Program, the distribution of the whole must be on the terms of
this License, whose permissions for other licensees extend to the
entire whole, and thus to each and every part regardless of who wrote it.
Thus, it is not the intent of this section to claim rights or contest
your rights to work written entirely by you; rather, the intent is to
exercise the right to control the distribution of derivative or
collective works based on the Program.
In addition, mere aggregation of another work not based on the Program
with the Program (or with a work based on the Program) on a volume of
a storage or distribution medium does not bring the other work under
the scope of this License.
3. You may copy and distribute the Program (or a work based on it,
under Section 2) in object code or executable form under the terms of
Sections 1 and 2 above provided that you also do one of the following:
a) Accompany it with the complete corresponding machine-readable
source code, which must be distributed under the terms of Sections
1 and 2 above on a medium customarily used for software interchange; or,
b) Accompany it with a written offer, valid for at least three
years, to give any third party, for a charge no more than your
cost of physically performing source distribution, a complete
machine-readable copy of the corresponding source code, to be
distributed under the terms of Sections 1 and 2 above on a medium
customarily used for software interchange; or,
c) Accompany it with the information you received as to the offer
to distribute corresponding source code. (This alternative is
allowed only for noncommercial distribution and only if you
received the program in object code or executable form with such
an offer, in accord with Subsection b above.)
The source code for a work means the preferred form of the work for
making modifications to it. For an executable work, complete source
code means all the source code for all modules it contains, plus any
associated interface definition files, plus the scripts used to
control compilation and installation of the executable. However, as a
special exception, the source code distributed need not include
anything that is normally distributed (in either source or binary
form) with the major components (compiler, kernel, and so on) of the
operating system on which the executable runs, unless that component
itself accompanies the executable.
If distribution of executable or object code is made by offering
access to copy from a designated place, then offering equivalent
access to copy the source code from the same place counts as
distribution of the source code, even though third parties are not
compelled to copy the source along with the object code.
4. You may not copy, modify, sublicense, or distribute the Program
except as expressly provided under this License. Any attempt
otherwise to copy, modify, sublicense or distribute the Program is
void, and will automatically terminate your rights under this License.
However, parties who have received copies, or rights, from you under
this License will not have their licenses terminated so long as such
parties remain in full compliance.
5. You are not required to accept this License, since you have not
signed it. However, nothing else grants you permission to modify or
distribute the Program or its derivative works. These actions are
prohibited by law if you do not accept this License. Therefore, by
modifying or distributing the Program (or any work based on the
Program), you indicate your acceptance of this License to do so, and
all its terms and conditions for copying, distributing or modifying
the Program or works based on it.
6. Each time you redistribute the Program (or any work based on the
Program), the recipient automatically receives a license from the
original licensor to copy, distribute or modify the Program subject to
these terms and conditions. You may not impose any further
restrictions on the recipients' exercise of the rights granted herein.
You are not responsible for enforcing compliance by third parties to
this License.
7. If, as a consequence of a court judgment or allegation of patent
infringement or for any other reason (not limited to patent issues),
conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot
distribute so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you
may not distribute the Program at all. For example, if a patent
license would not permit royalty-free redistribution of the Program by
all those who receive copies directly or indirectly through you, then
the only way you could satisfy both it and this License would be to
refrain entirely from distribution of the Program.
If any portion of this section is held invalid or unenforceable under
any particular circumstance, the balance of the section is intended to
apply and the section as a whole is intended to apply in other
circumstances.
It is not the purpose of this section to induce you to infringe any
patents or other property right claims or to contest validity of any
such claims; this section has the sole purpose of protecting the
integrity of the free software distribution system, which is
implemented by public license practices. Many people have made
generous contributions to the wide range of software distributed
through that system in reliance on consistent application of that
system; it is up to the author/donor to decide if he or she is willing
to distribute software through any other system and a licensee cannot
impose that choice.
This section is intended to make thoroughly clear what is believed to
be a consequence of the rest of this License.
8. If the distribution and/or use of the Program is restricted in
certain countries either by patents or by copyrighted interfaces, the
original copyright holder who places the Program under this License
may add an explicit geographical distribution limitation excluding
those countries, so that distribution is permitted only in or among
countries not thus excluded. In such case, this License incorporates
the limitation as if written in the body of this License.
9. The Free Software Foundation may publish revised and/or new versions
of the General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the Program
specifies a version number of this License which applies to it and "any
later version", you have the option of following the terms and conditions
either of that version or of any later version published by the Free
Software Foundation. If the Program does not specify a version number of
this License, you may choose any version ever published by the Free Software
Foundation.
10. If you wish to incorporate parts of the Program into other free
programs whose distribution conditions are different, write to the author
to ask for permission. For software which is copyrighted by the Free
Software Foundation, write to the Free Software Foundation; we sometimes
make exceptions for this. Our decision will be guided by the two goals
of preserving the free status of all derivatives of our free software and
of promoting the sharing and reuse of software generally.
NO WARRANTY
11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
convey the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License version 2
as published by the Free Software Foundation.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
Also add information on how to contact you by electronic and paper mail.
If the program is interactive, make it output a short notice like this
when it starts in an interactive mode:
Gnomovision version 69, Copyright (C) year name of author
Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, the commands you use may
be called something other than `show w' and `show c'; they could even be
mouse-clicks or menu items--whatever suits your program.
You should also get your employer (if you work as a programmer) or your
school, if any, to sign a "copyright disclaimer" for the program, if
necessary. Here is a sample; alter the names:
Yoyodyne, Inc., hereby disclaims all copyright interest in the program
`Gnomovision' (which makes passes at compilers) written by James Hacker.
<signature of Ty Coon>, 1 April 1989
Ty Coon, President of Vice
This General Public License does not permit incorporating your program into
proprietary programs. If your program is a subroutine library, you may
consider it more useful to permit linking proprietary applications with the
library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License.

17
Dockerfile Normal file
View file

@ -0,0 +1,17 @@
FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
go mod verify && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
unzip -j alpine-3.19-amneziawg-tools.zip && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

17
LICENSE Normal file
View file

@ -0,0 +1,17 @@
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,42 +1,31 @@
PREFIX ?= /usr PREFIX ?= /usr
DESTDIR ?= DESTDIR ?=
BINDIR ?= $(PREFIX)/bin BINDIR ?= $(PREFIX)/bin
export GOPATH ?= $(CURDIR)/.gopath
export GO111MODULE := on export GO111MODULE := on
all: generate-version-and-build all: generate-version-and-build
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else
ireallywantobuildon_linux.go:
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"
clean-ireallywantobuildon_linux.go:
@rm -f ireallywantobuildon_linux.go
.PHONY: clean-ireallywantobuildon_linux.go
clean: clean-ireallywantobuildon_linux.go
wireguard-go: ireallywantobuildon_linux.go
endif
MAKEFLAGS += --no-print-directory MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --tags --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \ echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true git update-index --assume-unchanged version.go || true
@$(MAKE) wireguard-go @$(MAKE) amneziawg-go
wireguard-go: $(wildcard *.go) $(wildcard */*.go) amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
go build -v -o "$@" go build -v -o "$@"
install: wireguard-go install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
test:
go test ./...
clean: clean:
rm -f wireguard-go rm -f amneziawg-go
.PHONY: all clean install generate-version-and-build .PHONY: all clean test install generate-version-and-build

View file

@ -1,26 +1,27 @@
# Go Implementation of [WireGuard](https://www.wireguard.com/) # Go Implementation of AmneziaWG
This is an implementation of WireGuard in Go. AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
***WARNING:*** This is a work in progress and not ready for prime time, with no official "releases" yet. It is extremely rough around the edges and leaves much to be desired. There are bugs and we are not yet in a position to make claims about its security. Beware. The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
## Usage ## Usage
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run: Simply run:
``` ```
$ wireguard-go wg0 $ amneziawg-go wg0
``` ```
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down. This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
To run wireguard-go without forking to the background, pass `-f` or `--foreground`: To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
``` ```
$ wireguard-go -f wg0 $ amneziawg-go -f wg0
``` ```
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@ -28,57 +29,24 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux ### Linux
This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions. This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
### macOS ### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
### Windows ### Windows
It is currently a work in progress to strip out the beginnings of an experiment done with the OpenVPN tuntap driver and instead port to the new UWP APIs for tunnels. In other words, this does not *yet* work on Windows. This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
### FreeBSD
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
### OpenBSD
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
## Building ## Building
This requires an installation of [go](https://golang.org) ≥ 1.11. This requires an installation of the latest version of [Go](https://go.dev/).
``` ```
$ git clone https://git.zx2c4.com/wireguard-go $ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd wireguard-go $ cd amneziawg-go
$ make $ make
``` ```
## License
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License version 2 as
published by the Free Software Foundation.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
---------------------------------------------------------------------------
Additional Permissions For Submission to Apple App Store: Provided that you
are otherwise in compliance with the GPLv2 for each covered work you convey
(including without limitation making the Corresponding Source available in
compliance with Section 3 of the GPLv2), you are granted the additional
permission to convey through the Apple App Store non-source executable
versions of the Program as incorporated into each applicable covered work
as Executable Versions only under the Mozilla Public License version 2.0
(https://www.mozilla.org/en-US/MPL/2.0/).

View file

@ -1,251 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"errors"
"math/bits"
"net"
"sync"
"unsafe"
)
type trieEntry struct {
cidr uint
child [2]*trieEntry
bits net.IP
peer *Peer
// index of "branching" bit
bit_at_byte uint
bit_at_shift uint
}
func isLittleEndian() bool {
one := uint32(1)
return *(*byte)(unsafe.Pointer(&one)) != 0
}
func swapU32(i uint32) uint32 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes32(i)
}
func swapU64(i uint64) uint64 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
return uint(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x)))
}
a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = *a ^ *b
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
} else {
panic("Wrong size bit string")
}
}
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// at leaf
if node == nil {
return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
}
// traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.peer = peer
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
cidr = min(cidr, common)
// check for shorter prefix
if newNode.cidr == cidr {
bit := newNode.choose(node.bits)
newNode.child[bit] = node
return newNode
}
// create new parent for node & newNode
parent := &trieEntry{
bits: ip,
peer: nil,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
bit := parent.choose(ip)
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bit_at_byte == size {
break
}
bit := node.choose(ip)
node = node.child[bit]
}
return found
}
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
results = append(results, net.IPNet{
Mask: mask,
IP: node.bits.Mask(mask),
})
}
results = node.child[0].entriesForPeer(p, results)
results = node.child[1].entriesForPeer(p, results)
return results
}
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 0, 10)
allowed = table.IPv4.entriesForPeer(peer, allowed)
allowed = table.IPv6.entriesForPeer(peer, allowed)
return allowed
}
func (table *AllowedIPs) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = nil
table.IPv6 = nil
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer)
table.IPv6 = table.IPv6.removeByPeer(peer)
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
switch len(ip) {
case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
default:
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv4.lookup(address)
}
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv6.lookup(address)
}

View file

@ -1,131 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"math/rand"
"sort"
"testing"
)
const (
NumberOfPeers = 100
NumberOfAddresses = 250
NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
cidr uint
bits []byte
}
type SlowRouter []*SlowNode
func (r SlowRouter) Len() int {
return len(r)
}
func (r SlowRouter) Less(i, j int) bool {
return r[i].cidr > r[j].cidr
}
func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
t.bits = addr
return r
}
}
r = append(r, &SlowNode{
cidr: cidr,
bits: addr,
peer: peer,
})
sort.Sort(r)
return r
}
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
common := commonBits(t.bits, addr)
if common >= t.cidr {
return t.peer
}
}
return nil
}
func TestTrieRandomIPv4(t *testing.T) {
var trie *trieEntry
var slow SlowRouter
var peers []*Peer
rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
}
}
func TestTrieRandomIPv6(t *testing.T) {
var trie *trieEntry
var slow SlowRouter
var peers []*Peer
rand.Seed(1)
const AddressLength = 16
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
}
}

180
conn.go
View file

@ -1,180 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
)
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.mutex.Lock()
defer device.net.mutex.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.mutex.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.mutex.Lock()
defer device.net.mutex.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port, device)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.mutex.RUnlock()
// start receiving routines
device.net.starting.Add(ConnRoutineNumber)
device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.mutex.Lock()
err := unsafeCloseBind(device)
device.net.mutex.Unlock()
return err
}

544
conn/bind_std.go Normal file
View file

@ -0,0 +1,544 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
}
func NewStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
}
func (e *StdNetEndpoint) ClearSrc() {
if e.src != nil {
// Truncate src, no need to reallocate.
e.src = e.src[:0]
}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
br = s.ipv6PC
is6 = true
offload = s.ipv6TxOffload
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
return err
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}

250
conn/bind_std_test.go Normal file
View file

@ -0,0 +1,250 @@
package conn
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
fns, _, err := bind.Open(0)
if err != nil {
t.Fatal(err)
}
bind.Close()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, 1)
sizes := make([]int, 1)
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(bufs, sizes, eps)
}
}
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func Test_coalesceMessages(t *testing.T) {
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1").To4(),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func mockGetGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_splitCoalescedMessages(t *testing.T) {
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1<<16-1)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}

601
conn/bind_windows.go Normal file
View file

@ -0,0 +1,601 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"encoding/binary"
"io"
"net"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr WinRingEndpoint
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
func (rb *ringBuffer) Push() *ringPacket {
for rb.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true
}
return ret
}
func (rb *ringBuffer) Return(count uint32) {
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
return
}
rb.head += count
rb.isFull = false
}
type afWinRingBind struct {
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
mu sync.Mutex
blackhole bool
}
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen atomic.Uint32 // 0, 1, or 2
}
func NewDefaultBind() Bind { return NewWinRingBind() }
func NewWinRingBind() Bind {
if !winrio.Initialize() {
return NewStdNetBind()
}
return new(WinRingBind)
}
type WinRingEndpoint struct {
family uint16
data [30]byte
}
var (
_ Bind = (*WinRingBind)(nil)
_ Endpoint = (*WinRingEndpoint)(nil)
)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
host16, err := windows.UTF16PtrFromString(host)
if err != nil {
return nil, err
}
port16, err := windows.UTF16PtrFromString(port)
if err != nil {
return nil, err
}
hints := windows.AddrinfoW{
Flags: windows.AI_NUMERICHOST,
Family: windows.AF_UNSPEC,
Socktype: windows.SOCK_DGRAM,
Protocol: windows.IPPROTO_UDP,
}
var addrinfo *windows.AddrinfoW
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
if err != nil {
return nil, err
}
defer windows.FreeAddrInfoW(addrinfo)
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS
}
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}
func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return netip.Addr{}
}
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
switch e.family {
case windows.AF_INET:
b := make([]byte, 0, 6)
b = append(b, e.data[2:6]...)
b = append(b, e.data[1], e.data[0])
return b
case windows.AF_INET6:
b := make([]byte, 0, 18)
b = append(b, e.data[6:22]...)
b = append(b, e.data[1], e.data[0])
return b
}
return nil
}
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
func (e *WinRingEndpoint) SrcToString() string {
return ""
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (bind *afWinRingBind) CloseAndZero() {
bind.rx.CloseAndZero()
bind.tx.CloseAndZero()
if bind.sock != 0 {
windows.CloseHandle(bind.sock)
bind.sock = 0
}
bind.blackhole = false
}
func (bind *WinRingBind) closeAndZero() {
bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
var err error
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return nil, err
}
err = bind.rx.Open()
if err != nil {
return nil, err
}
err = bind.tx.Open()
if err != nil {
return nil, err
}
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
if err != nil {
return nil, err
}
err = windows.Bind(bind.sock, sa)
if err != nil {
return nil, err
}
sa, err = windows.Getsockname(bind.sock)
if err != nil {
return nil, err
}
return sa, nil
}
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock()
defer bind.mu.Unlock()
defer func() {
if err != nil {
bind.closeAndZero()
}
}()
if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil {
return nil, 0, err
}
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil {
return nil, 0, err
}
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
err = bind.v6.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
}
bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
bind.mu.RUnlock()
bind.mu.Lock()
defer bind.mu.Unlock()
bind.closeAndZero()
return nil
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
func (bind *afWinRingBind) InsertReceiveRequest() error {
packet := bind.rx.Push()
dataBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
defer bind.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(bind.rx.cq)
if err != nil {
return 0, nil, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, nil, err
}
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 {
return 0, nil, io.ErrNoProgress
}
}
bind.rx.Return(1)
err = bind.InsertReceiveRequest()
if err != nil {
return 0, nil, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
bind.tx.mu.Lock()
defer bind.tx.mu.Unlock()
var results [packetsPerRing]winrio.Result
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 && bind.tx.isFull {
err := winrio.Notify(bind.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
bind.tx.Return(count)
}
packet := bind.tx.Push()
packet.addr = *nend
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
for _, buf := range bufs {
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
case windows.AF_INET6:
if bind.v6.blackhole {
continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
}
}
return nil
}
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole4 = blackhole
return nil
}
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole6 = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
if err != nil {
return err
}
bind.v4.blackhole = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
if err != nil {
return err
}
bind.v6.blackhole = blackhole
return nil
}
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
const IP_UNICAST_IF = 31
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
if err != nil {
return err
}
return nil
}
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}

136
conn/bindtest/bindtest.go Normal file
View file

@ -0,0 +1,136 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type ChannelBind struct {
rx4, tx4 *chan []byte
rx6, tx6 *chan []byte
closeSignal chan bool
source4, source6 ChannelEndpoint
target4, target6 ChannelEndpoint
}
type ChannelEndpoint uint16
var (
_ conn.Bind = (*ChannelBind)(nil)
_ conn.Endpoint = (*ChannelEndpoint)(nil)
)
func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192)
brx4 := make(chan []byte, 8192)
arx6 := make(chan []byte, 8192)
brx6 := make(chan []byte, 8192)
var binds [2]ChannelBind
binds[0].rx4 = &arx4
binds[0].tx4 = &brx4
binds[1].rx4 = &brx4
binds[1].tx4 = &arx4
binds[0].rx6 = &arx6
binds[0].tx6 = &brx6
binds[1].rx6 = &brx6
binds[1].tx6 = &arx6
binds[0].target4 = ChannelEndpoint(1)
binds[1].target4 = ChannelEndpoint(2)
binds[0].target6 = ChannelEndpoint(3)
binds[1].target6 = ChannelEndpoint(4)
binds[0].source4 = binds[1].target4
binds[0].source6 = binds[1].target6
binds[1].source4 = binds[0].target4
binds[1].source6 = binds[0].target6
return [2]conn.Bind{&binds[0], &binds[1]}
}
func (c ChannelEndpoint) ClearSrc() {}
func (c ChannelEndpoint) SrcToString() string { return "" }
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
fns = append(fns, c.makeReceiveFunc(*c.rx4))
fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 {
return fns, uint16(c.source4), nil
} else {
return fns, uint16(c.source6), nil
}
}
func (c *ChannelBind) Close() error {
if c.closeSignal != nil {
select {
case <-c.closeSignal:
default:
close(c.closeSignal)
}
}
return nil
}
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
return 0, net.ErrClosed
case rx := <-ch:
copied := copy(bufs[0], rx)
sizes[0] = copied
eps[0] = c.target6
return 1, nil
}
}
}
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
for _, b := range bufs {
select {
case <-c.closeSignal:
return net.ErrClosed
default:
bc := make([]byte, len(b))
copy(bc, b)
if ep.(ChannelEndpoint) == c.target4 {
*c.tx4 <- bc
} else if ep.(ChannelEndpoint) == c.target6 {
*c.tx6 <- bc
} else {
return os.ErrInvalid
}
}
}
return nil
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return ChannelEndpoint(addr.Port()), nil
}

34
conn/boundif_android.go Normal file
View file

@ -0,0 +1,34 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}

133
conn/conn.go Normal file
View file

@ -0,0 +1,133 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
package conn
import (
"errors"
"fmt"
"net/netip"
"reflect"
"runtime"
"strings"
)
const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection.
// fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
PeekLookAtSocketFd4() (fd int, err error)
PeekLookAtSocketFd6() (fd int, err error)
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst: the remote address of a peer ("endpoint" in uapi terminology)
// src: the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() netip.Addr
SrcIP() netip.Addr
}
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}

24
conn/conn_test.go Normal file
View file

@ -0,0 +1,24 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"testing"
)
func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
)
const want = "TestPrettyName"
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
if got := recvFunc.PrettyName(); got != want {
t.Errorf("PrettyName() = %v, want %v", got, want)
}
})
}

43
conn/controlfns.go Normal file
View file

@ -0,0 +1,43 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"syscall"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

61
conn/controlfns_linux.go Normal file
View file

@ -0,0 +1,61 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
// fail silently - the result of failure is lower performance on very fast
// links or high latency links.
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// Set up to *mem_max
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
})
},
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
if runtime.GOOS != "android" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
}
case "udp6":
c.Control(func(fd uintptr) {
if runtime.GOOS != "android" {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

35
conn/controlfns_unix.go Normal file
View file

@ -0,0 +1,35 @@
//go:build !windows && !linux && !wasm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
})
},
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View file

@ -0,0 +1,23 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/windows"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
})
},
)
}

10
conn/default.go Normal file
View file

@ -0,0 +1,10 @@
//go:build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }

12
conn/errors_default.go Normal file
View file

@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

28
conn/errors_linux.go Normal file
View file

@ -0,0 +1,28 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
// It occurs when the peer mtu + wg headers is greater than path mtu.
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
}
return false
}

15
conn/features_default.go Normal file
View file

@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

31
conn/features_linux.go Normal file
View file

@ -0,0 +1,31 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
// use setsockopt workaround
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
rxOffload = errSyscall == nil
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}

21
conn/gso_default.go Normal file
View file

@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

65
conn/gso_linux.go Normal file
View file

@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)

12
conn/mark_default.go Normal file
View file

@ -0,0 +1,12 @@
//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}

65
conn/mark_unix.go Normal file
View file

@ -0,0 +1,65 @@
//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"runtime"
"golang.org/x/sys/unix"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

42
conn/sticky_default.go Normal file
View file

@ -0,0 +1,42 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

112
conn/sticky_linux.go Normal file
View file

@ -0,0 +1,112 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. in order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
if cap(*control) < len(ep.src) {
return
}
*control = (*control)[:0]
*control = append(*control, ep.src...)
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

266
conn/sticky_linux_test.go Normal file
View file

@ -0,0 +1,266 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
var buf []byte
if addr.Is4() {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet4Pktinfo{
Ifindex: ifidx,
Spec_dst: addr.As4(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
} else {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet6Pktinfo{
Ifindex: uint32(ifidx),
Addr: addr.As16(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
}
ep.src = buf
}
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
var control []byte
ep := &StdNetEndpoint{}
setSrc(ep, netip.MustParseAddr("::1"), 5)
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 0 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("Multiple", func(t *testing.T) {
zeroControl := make([]byte, unix.CmsgSpace(0))
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
zeroHdr.SetLen(unix.CmsgLen(0))
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
combined := make([]byte, 0)
combined = append(combined, zeroControl...)
combined = append(combined, control...)
ep := &StdNetEndpoint{}
getSrcFromControl(combined, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

254
conn/winrio/rio_windows.go Normal file
View file

@ -0,0 +1,254 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio
import (
"log"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
const (
MsgDontNotify = 1
MsgDefer = 2
MsgWaitAll = 4
MsgCommitOnly = 8
MaxCqSize = 0x8000000
invalidBufferId = 0xFFFFFFFF
invalidCq = 0
invalidRq = 0
corruptCq = 0xFFFFFFFF
)
var extensionFunctionTable struct {
cbSize uint32
rioReceive uintptr
rioReceiveEx uintptr
rioSend uintptr
rioSendEx uintptr
rioCloseCompletionQueue uintptr
rioCreateCompletionQueue uintptr
rioCreateRequestQueue uintptr
rioDequeueCompletion uintptr
rioDeregisterBuffer uintptr
rioNotify uintptr
rioRegisterBuffer uintptr
rioResizeCompletionQueue uintptr
rioResizeRequestQueue uintptr
}
type Cq uintptr
type Rq uintptr
type BufferId uintptr
type Buffer struct {
Id BufferId
Offset uint32
Length uint32
}
type Result struct {
Status int32
BytesTransferred uint32
SocketContext uint64
RequestContext uint64
}
type notificationCompletionType uint32
const (
eventCompletion notificationCompletionType = 1
iocpCompletion notificationCompletionType = 2
)
type eventNotificationCompletion struct {
completionType notificationCompletionType
event windows.Handle
notifyReset uint32
}
type iocpNotificationCompletion struct {
completionType notificationCompletionType
iocp windows.Handle
key uintptr
overlapped *windows.Overlapped
}
var (
initialized sync.Once
available bool
)
func Initialize() bool {
initialized.Do(func() {
var (
err error
socket windows.Handle
cq Cq
)
defer func() {
if err == nil {
return
}
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
return
}
log.Printf("Registered I/O is unavailable: %v", err)
}()
socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return
}
defer windows.CloseHandle(socket)
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
ob := uint32(0)
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
&ob, nil, 0)
if err != nil {
return
}
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
var iocp windows.Handle
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return
}
defer windows.CloseHandle(iocp)
var overlapped windows.Overlapped
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
if err != nil {
return
}
defer CloseCompletionQueue(cq)
_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
if err != nil {
return
}
available = true
})
return available
}
func Socket(af, typ, proto int32) (windows.Handle, error) {
return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
}
func CloseCompletionQueue(cq Cq) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
}
func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
notificationCompletion := &eventNotificationCompletion{
completionType: eventCompletion,
event: event,
}
if notifyReset {
notificationCompletion.notifyReset = 1
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
notificationCompletion := &iocpNotificationCompletion{
completionType: iocpCompletion,
iocp: iocp,
key: key,
overlapped: overlapped,
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
if ret == invalidRq {
return 0, err
}
return Rq(ret), nil
}
func DequeueCompletion(cq Cq, results []Result) uint32 {
var array uintptr
if len(results) > 0 {
array = uintptr(unsafe.Pointer(&results[0]))
}
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
if ret == corruptCq {
panic("cq is corrupt")
}
return uint32(ret)
}
func DeregisterBuffer(id BufferId) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
}
func RegisterBuffer(buffer []byte) (BufferId, error) {
var buf unsafe.Pointer
if len(buffer) > 0 {
buf = unsafe.Pointer(&buffer[0])
}
return RegisterPointer(buf, uint32(len(buffer)))
}
func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
if ret == invalidBufferId {
return 0, err
}
return BufferId(ret), nil
}
func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func Notify(cq Cq) error {
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
if ret != 0 {
return windows.Errno(ret)
}
return nil
}

View file

@ -1,217 +0,0 @@
// +build !linux android
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"golang.org/x/sys/unix"
"net"
"os"
"runtime"
"syscall"
)
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type NativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*NativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (_ *NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func extractErrno(err error) error {
opErr, ok := err.(*net.OpError)
if !ok {
return nil
}
syscallErr, ok := opErr.Err.(*os.SyscallError)
if !ok {
return nil
}
return syscallErr.Err
}
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
var err error
var bind NativeBind
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err
bind.ipv4.Close()
bind.ipv4 = nil
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *NativeBind) Close() error {
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
}
if err1 != nil {
return err1
}
return err2
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (bind *NativeBind) SetMark(mark uint32) error {
if fwmarkIoctl == 0 {
return nil
}
if bind.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err != nil {
return err
}
}
if bind.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err != nil {
return err
}
}
return nil
}

View file

@ -1,746 +0,0 @@
// +build !android
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
package main
import (
"errors"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/sys/unix"
"net"
"strconv"
"sync"
"syscall"
"unsafe"
)
const (
FD_ERR = -1
)
type IPv4Source struct {
src [4]byte
ifindex int32
}
type IPv6Source struct {
src [16]byte
//ifindex belongs in dst.ZoneId
}
type NativeEndpoint struct {
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
}
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) src6() *IPv6Source {
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
type NativeBind struct {
sock4 int
sock6 int
netlinkSock int
netlinkCancel *rwcancel.RWCancel
lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = (*NativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
return nil, errors.New("Invalid IP address")
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
var err error
var bind NativeBind
var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil {
return nil, 0, err
}
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful
bind.sock6, newPort, err = create6(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
}
// attempt ipv4 bind, update port if succesful
bind.sock4, newPort, err = create4(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
} else {
port = newPort
}
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
return nil, 0, errors.New("ipv4 and ipv6 not supported")
}
return &bind, port, nil
}
func (bind *NativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
if bind.sock4 != -1 {
err := unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
bind.lastMark = value
return nil
}
func closeUnblock(fd int) error {
// shutdown to unblock readers and writers
unix.Shutdown(fd, unix.SHUT_RDWR)
return unix.Close(fd)
}
func (bind *NativeBind) Close() error {
var err1, err2, err3 error
if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
err3 = bind.netlinkCancel.Cancel()
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive4(
bind.sock4,
buff,
&end,
)
return n, &end, err
}
func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
if !nend.isV6 {
if bind.sock4 == -1 {
return syscall.EAFNOSUPPORT
}
return send4(bind.sock4, nend, buff)
} else {
if bind.sock6 == -1 {
return syscall.EAFNOSUPPORT
}
return send6(bind.sock6, nend, buff)
}
}
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
)
} else {
return end.src6().src[:]
}
}
func (end *NativeEndpoint) DstIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else {
return end.dst6().Addr[:]
}
}
func (end *NativeEndpoint) DstToBytes() []byte {
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
}
func (end *NativeEndpoint) SrcToString() string {
return end.SrcIP().String()
}
func (end *NativeEndpoint) DstToString() string {
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
if !end.isV6 {
udpAddr.Port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
}
return udpAddr.String()
}
func (end *NativeEndpoint) ClearDst() {
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *NativeEndpoint) ClearSrc() {
for i := range end.src {
end.src[i] = 0
}
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
}
if intr, err := net.InterfaceByName(zone); err == nil {
return uint32(intr.Index), nil
}
n, err := strconv.ParseUint(zone, 10, 32)
return uint32(n), err
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
return fd, uint16(addr.Port), err
}
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
return fd, uint16(addr.Port), err
}
func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().src,
Ifindex: end.src4().ifindex,
},
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
return err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
}
return err
}
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = false
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = true
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
return size, nil
}
func (bind *NativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct {
peer *Peer
endpoint *Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.mutex.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.mutex.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.mutex.Unlock()
break
}
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.mutex.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.mutex.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.mutex.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.mutex.RUnlock()
continue
}
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.mutex.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
peer.endpoint.(*NativeEndpoint).dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{
Len: 8,
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.mutex.RUnlock()
i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.mutex.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}

393
device.go
View file

@ -1,393 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"git.zx2c4.com/wireguard-go/ratelimiter"
"git.zx2c4.com/wireguard-go/tun"
"runtime"
"sync"
"sync/atomic"
"time"
)
const (
DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
)
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger
// synchronized resources (locks acquired in order)
state struct {
starting sync.WaitGroup
stopping sync.WaitGroup
mutex sync.Mutex
changing AtomicBool
current bool
}
net struct {
starting sync.WaitGroup
stopping sync.WaitGroup
mutex sync.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
mutex sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
}
peers struct {
mutex sync.RWMutex
keyMap map[NoisePublicKey]*Peer
}
// unprotected / "self-synchronising resources"
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
rate struct {
underLoadUntil atomic.Value
limiter ratelimiter.Ratelimiter
}
pool struct {
messageBufferPool *sync.Pool
messageBufferReuseChan chan *[MaxMessageSize]byte
inboundElementPool *sync.Pool
inboundElementReuseChan chan *QueueInboundElement
outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
}
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signals struct {
stop chan struct{}
}
tun struct {
device tun.TUNDevice
mtu int32
}
}
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
* Must hold device.peers.mutex.
*/
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
delete(device.peers.keyMap, key)
}
func deviceUpdateState(device *Device) {
// check if state already being updated (guard)
if device.state.changing.Swap(true) {
return
}
// compare to current state of device
device.state.mutex.Lock()
newIsUp := device.isUp.Get()
if newIsUp == device.state.current {
device.state.changing.Set(false)
device.state.mutex.Unlock()
return
}
// change state of device
switch newIsUp {
case true:
if err := device.BindUpdate(); err != nil {
device.isUp.Set(false)
break
}
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
}
device.peers.mutex.RUnlock()
case false:
device.BindClose()
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap {
peer.Stop()
}
device.peers.mutex.RUnlock()
}
// update state variables
device.state.current = newIsUp
device.state.changing.Set(false)
device.state.mutex.Unlock()
// check for state change in the mean time
deviceUpdateState(device)
}
func (device *Device) Up() {
// closed device cannot be brought up
if device.isClosed.Get() {
return
}
device.isUp.Set(true)
deviceUpdateState(device)
}
func (device *Device) Down() {
device.isUp.Set(false)
deviceUpdateState(device)
}
func (device *Device) IsUnderLoad() bool {
// check if currently under load
now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
return true
}
// check if recently under load
until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now)
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources
device.staticIdentity.mutex.Lock()
defer device.staticIdentity.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock()
defer peer.handshake.mutex.RUnlock()
}
// remove peers with matching public keys
publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
unsafeRemovePeer(device, peer, key)
}
}
// update key material
device.staticIdentity.privateKey = sk
device.staticIdentity.publicKey = publicKey
device.cookieChecker.Init(publicKey)
// do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
for key, peer := range device.peers.keyMap {
handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
}
}
return nil
}
func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
device := new(Device)
device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
logger.Error.Println("Trouble determining MTU, assuming default:", err)
mtu = DefaultMTU
}
device.tun.mtu = int32(mtu)
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
device.indexTable.Init()
device.allowedips.Reset()
device.PopulatePools()
// create queues
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signals.stop = make(chan struct{})
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers
cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 {
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
}
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
device.state.starting.Wait()
return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.mutex.RLock()
defer device.peers.mutex.RUnlock()
return device.peers.keyMap[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// stop peer and remove from routing
peer, ok := device.peers.keyMap[key]
if ok {
unsafeRemovePeer(device, peer, key)
}
}
func (device *Device) RemoveAllPeers() {
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for key, peer := range device.peers.keyMap {
unsafeRemovePeer(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) FlushPacketQueues() {
for {
select {
case elem, ok := <-device.queue.decryption:
if ok {
elem.Drop()
}
case elem, ok := <-device.queue.encryption:
if ok {
elem.Drop()
}
case <-device.queue.handshake:
default:
return
}
}
}
func (device *Device) Close() {
if device.isClosed.Swap(true) {
return
}
device.state.starting.Wait()
device.log.Info.Println("Device closing")
device.state.changing.Set(true)
device.state.mutex.Lock()
defer device.state.mutex.Unlock()
device.tun.device.Close()
device.BindClose()
device.isUp.Set(false)
close(device.signals.stop)
device.RemoveAllPeers()
device.state.stopping.Wait()
device.FlushPacketQueues()
device.rate.limiter.Close()
device.state.changing.Set(false)
device.log.Info.Println("Interface closed")
}
func (device *Device) Wait() chan struct{} {
return device.signals.stop
}

294
device/allowedips.go Normal file
View file

@ -0,0 +1,294 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
"encoding/binary"
"errors"
"math/bits"
"net"
"net/netip"
"sync"
"unsafe"
)
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct {
peer *Peer
child [2]*trieEntry
parent parentIndirection
cidr uint8
bitAtByte uint8
bitAtShift uint8
bits []byte
perPeerElem *list.Element
}
func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := binary.BigEndian.Uint32(ip1)
b := binary.BigEndian.Uint32(ip2)
x := a ^ b
return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
a := binary.BigEndian.Uint64(ip1)
b := binary.BigEndian.Uint64(ip2)
x := a ^ b
if x != 0 {
return uint8(bits.LeadingZeros64(x))
}
a = binary.BigEndian.Uint64(ip1[8:])
b = binary.BigEndian.Uint64(ip2[8:])
x = a ^ b
return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
}
func (node *trieEntry) addToPeerEntries() {
node.perPeerElem = node.peer.trieEntries.PushBack(node)
}
func (node *trieEntry) removeFromPeerEntries() {
if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil
}
}
func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
func (node *trieEntry) maskSelf() {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i]
}
}
func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
parent: trie,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
*trie.parentBit = node
return
}
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
if exact {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return
}
newNode := &trieEntry{
peer: peer,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
newNode.maskSelf()
newNode.addToPeerEntries()
var down *trieEntry
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr {
bit := newNode.choose(down.bits)
down.parent = parentIndirection{&newNode.child[bit], bit}
newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
}
node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
bit := node.choose(down.bits)
down.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = down
bit = node.choose(newNode.bits)
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
}
func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bitAtByte == size {
break
}
bit := node.choose(ip)
node = node.child[bit]
}
return found
}
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
}
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(ip)
case net.IPv4len:
return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
}

View file

@ -0,0 +1,141 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net"
"net/netip"
"sort"
"testing"
)
const (
NumberOfPeers = 100
NumberOfPeerRemovals = 4
NumberOfAddresses = 250
NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
cidr uint8
bits []byte
}
type SlowRouter []*SlowNode
func (r SlowRouter) Len() int {
return len(r)
}
func (r SlowRouter) Less(i, j int) bool {
return r[i].cidr > r[j].cidr
}
func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
t.bits = addr
return r
}
}
r = append(r, &SlowNode{
cidr: cidr,
bits: addr,
peer: peer,
})
sort.Sort(r)
return r
}
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
common := commonBits(t.bits, addr)
if common >= t.cidr {
return t.peer
}
}
return nil
}
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
n := 0
for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
var p int
for p = 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}
var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
}
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
}
}

View file

@ -1,47 +1,24 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"math/rand" "math/rand"
"net" "net"
"net/netip"
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint match uint8
}
type testPairTrieInsert struct {
key []byte
cidr uint
peer *Peer
}
type testPairTrieLookup struct {
key []byte
peer *Peer
}
func printTrie(t *testing.T, p *trieEntry) {
if p == nil {
return
}
t.Log(p)
printTrie(t, p.child[0])
printTrie(t, p.child[1])
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{ tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
const AddressLength = 4 const AddressLength = 4
for n := 0; n < peerNumber; n += 1 { for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < addressNumber; n += 1 { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) root.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n += 1 { for n := 0; n < b.N; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
trie.lookup(addr[:]) trie.lookup(addr[:])
@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
trie = nil allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := trie.lookup(addr) p := allowedIPs.Lookup(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }

View file

@ -1,23 +1,24 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import "errors" import (
"errors"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type DummyDatagram struct { type DummyDatagram struct {
msg []byte msg []byte
endpoint Endpoint endpoint conn.Endpoint
world bool // better type
} }
type DummyBind struct { type DummyBind struct {
in6 chan DummyDatagram in6 chan DummyDatagram
ou6 chan DummyDatagram
in4 chan DummyDatagram in4 chan DummyDatagram
ou4 chan DummyDatagram
closed bool closed bool
} }
@ -25,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil return nil
} }
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6 datagram, ok := <-b.in6
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buff, datagram.msg) copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4 datagram, ok := <-b.in4
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buff, datagram.msg) copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
@ -50,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil return nil
} }
func (b *DummyBind) Send(buff []byte, end Endpoint) error { func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil return nil
} }

137
device/channels.go Normal file
View file

@ -0,0 +1,137 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"runtime"
"sync"
)
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
c chan QueueHandshakeElement
wg sync.WaitGroup
}
func newHandshakeQueue() *handshakeQueue {
q := &handshakeQueue{
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
type autodrainingInboundQueue struct {
c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which is it hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
}
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
default:
return
}
}
}
type autodrainingOutboundQueue struct {
c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which is it hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil values.
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
}
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
}
}

View file

@ -1,9 +1,9 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"time" "time"
@ -12,8 +12,8 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 13) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 RekeyTimeout = time.Second * 5
@ -22,7 +22,7 @@ const (
RejectAfterTime = time.Second * 180 RejectAfterTime = time.Second * 180
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120 CookieRefreshTime = time.Second * 120
HandshakeInitationRate = time.Second / 20 HandshakeInitationRate = time.Second / 50
PaddingMultiple = 16 PaddingMultiple = 16
) )
@ -35,7 +35,6 @@ const (
/* Implementation constants */ /* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers MaxPeers = 1 << 16 // maximum number of configured peers
) )

View file

@ -1,23 +1,23 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"git.zx2c4.com/wireguard-go/xchacha20poly1305"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
) )
type CookieChecker struct { type CookieChecker struct {
mutex sync.RWMutex sync.RWMutex
mac1 struct { mac1 struct {
key [blake2s.Size]byte key [blake2s.Size]byte
} }
mac2 struct { mac2 struct {
@ -28,8 +28,8 @@ type CookieChecker struct {
} }
type CookieGenerator struct { type CookieGenerator struct {
mutex sync.RWMutex sync.RWMutex
mac1 struct { mac1 struct {
key [blake2s.Size]byte key [blake2s.Size]byte
} }
mac2 struct { mac2 struct {
@ -42,8 +42,8 @@ type CookieGenerator struct {
} }
func (st *CookieChecker) Init(pk NoisePublicKey) { func (st *CookieChecker) Init(pk NoisePublicKey) {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
// mac1 state // mac1 state
@ -67,8 +67,8 @@ func (st *CookieChecker) Init(pk NoisePublicKey) {
} }
func (st *CookieChecker) CheckMAC1(msg []byte) bool { func (st *CookieChecker) CheckMAC1(msg []byte) bool {
st.mutex.RLock() st.RLock()
defer st.mutex.RUnlock() defer st.RUnlock()
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128
@ -83,11 +83,11 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.mutex.RLock() st.RLock()
defer st.mutex.RUnlock() defer st.RUnlock()
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Since(st.mac2.secretSet) > CookieRefreshTime {
return false return false
} }
@ -119,22 +119,21 @@ func (st *CookieChecker) CreateReply(
recv uint32, recv uint32,
src []byte, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.RLock()
st.mutex.RLock()
// refresh cookie secret // refresh cookie secret
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Since(st.mac2.secretSet) > CookieRefreshTime {
st.mutex.RUnlock() st.RUnlock()
st.mutex.Lock() st.Lock()
_, err := rand.Read(st.mac2.secret[:]) _, err := rand.Read(st.mac2.secret[:])
if err != nil { if err != nil {
st.mutex.Unlock() st.Unlock()
return nil, err return nil, err
} }
st.mac2.secretSet = time.Now() st.mac2.secretSet = time.Now()
st.mutex.Unlock() st.Unlock()
st.mutex.RLock() st.RLock()
} }
// derive cookie // derive cookie
@ -159,26 +158,21 @@ func (st *CookieChecker) CreateReply(
_, err := rand.Read(reply.Nonce[:]) _, err := rand.Read(reply.Nonce[:])
if err != nil { if err != nil {
st.mutex.RUnlock() st.RUnlock()
return nil, err return nil, err
} }
xchacha20poly1305.Encrypt( xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
reply.Cookie[:0], xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
&reply.Nonce,
cookie[:],
msg[smac1:smac2],
&st.mac2.encryptionKey,
)
st.mutex.RUnlock() st.RUnlock()
return reply, nil return reply, nil
} }
func (st *CookieGenerator) Init(pk NoisePublicKey) { func (st *CookieGenerator) Init(pk NoisePublicKey) {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
func() { func() {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
@ -198,8 +192,8 @@ func (st *CookieGenerator) Init(pk NoisePublicKey) {
} }
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
if !st.mac2.hasLastMAC1 { if !st.mac2.hasLastMAC1 {
return false return false
@ -207,14 +201,8 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
_, err := xchacha20poly1305.Decrypt( xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
cookie[:0], _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
&msg.Nonce,
msg.Cookie[:],
st.mac2.lastMAC1[:],
&st.mac2.encryptionKey,
)
if err != nil { if err != nil {
return false return false
} }
@ -225,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
} }
func (st *CookieGenerator) AddMacs(msg []byte) { func (st *CookieGenerator) AddMacs(msg []byte) {
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128
@ -234,8 +221,8 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
mac1 := msg[smac1:smac2] mac1 := msg[smac1:smac2]
mac2 := msg[smac2:] mac2 := msg[smac2:]
st.mutex.Lock() st.Lock()
defer st.mutex.Unlock() defer st.Unlock()
// set mac1 // set mac1
@ -249,7 +236,7 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
// set mac2 // set mac2
if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
return return
} }

View file

@ -1,16 +1,15 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"testing" "testing"
) )
func TestCookieMAC1(t *testing.T) { func TestCookieMAC1(t *testing.T) {
// setup generator / checker // setup generator / checker
var ( var (
@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1 := []byte{192, 168, 13, 37, 40, 01} srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2 := []byte{192, 168, 13, 38, 40, 01} srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

807
device/device.go Normal file
View file

@ -0,0 +1,807 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
)
type Device struct {
state struct {
// state holds the device's state. It is accessed atomically.
// Use the device.deviceState method to read it.
// device.deviceState does not acquire the mutex, so it captures only a snapshot.
// During state transitions, the state variable is updated before the device itself.
// The state is thus either the current state of the device or
// the intended future state of the device.
// For example, while executing a call to Up, state will be deviceStateUp.
// There is no guarantee that that intended future state of the device
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
sync.Mutex
}
net struct {
stopping sync.WaitGroup
sync.RWMutex
bind conn.Bind // bind interface
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
brokenRoaming bool
}
staticIdentity struct {
sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
}
peers struct {
sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer
}
rate struct {
underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
pool struct {
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
}
queue struct {
encryption *outboundQueue
decryption *inboundQueue
handshake *handshakeQueue
}
tun struct {
device tun.Device
mtu atomic.Int32
}
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
}
type aSecCfgType struct {
isSet bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
// down -----+
// ↑↓ ↓
// up -> closed
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
deviceStateDown deviceState = iota
deviceStateUp
deviceStateClosed
)
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
return deviceState(device.state.state.Load())
}
// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
func (device *Device) isClosed() bool {
return device.deviceState() == deviceStateClosed
}
// isUp reports whether the device is up (or is attempting to come up).
// See device.state.state comments for how to interpret this value.
func (device *Device) isUp() bool {
return device.deviceState() == deviceStateUp
}
// Must hold device.peers.Lock()
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
delete(device.peers.keyMap, key)
}
// changeState attempts to change the device state to match want.
func (device *Device) changeState(want deviceState) (err error) {
device.state.Lock()
defer device.state.Unlock()
old := device.deviceState()
if old == deviceStateClosed {
// once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want)
return nil
}
switch want {
case old:
return nil
case deviceStateUp:
device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
}
}
device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return
}
// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
if err := device.BindUpdate(); err != nil {
device.log.Errorf("Unable to update bind: %v", err)
return err
}
// The IPC set operation waits for peers to be created before calling Start() on them,
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
device.peers.RUnlock()
return nil
}
// downLocked attempts to bring the device down.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) downLocked() error {
err := device.BindClose()
if err != nil {
device.log.Errorf("Bind close failed: %v", err)
}
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Stop()
}
device.peers.RUnlock()
return err
}
func (device *Device) Up() error {
return device.changeState(deviceStateUp)
}
func (device *Device) Down() error {
return device.changeState(deviceStateDown)
}
func (device *Device) IsUnderLoad() bool {
// check if currently under load
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
return device.rate.underLoadUntil.Load() > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources
device.staticIdentity.Lock()
defer device.staticIdentity.Unlock()
if sk.Equals(device.staticIdentity.privateKey) {
return nil
}
device.peers.Lock()
defer device.peers.Unlock()
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock()
lockedPeers = append(lockedPeers, peer)
}
// remove peers with matching public keys
publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
peer.handshake.mutex.RUnlock()
removePeerLocked(device, peer, key)
peer.handshake.mutex.RLock()
}
}
// update key material
device.staticIdentity.privateKey = sk
device.staticIdentity.publicKey = publicKey
device.cookieChecker.Init(publicKey)
// do static-static DH pre-computations
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
for _, peer := range lockedPeers {
peer.handshake.mutex.RUnlock()
}
for _, peer := range expiredPeers {
peer.ExpireCurrentKeypairs()
}
return nil
}
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
device.PopulatePools()
// create queues
device.queue.handshake = newHandshakeQueue()
device.queue.encryption = newOutboundQueue()
device.queue.decryption = newInboundQueue()
// start workers
cpus := runtime.NumCPU()
device.state.stopping.Wait()
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
for i := 0; i < cpus; i++ {
go device.RoutineEncryption(i + 1)
go device.RoutineDecryption(i + 1)
go device.RoutineHandshake(i + 1)
}
device.state.stopping.Add(1) // RoutineReadFromTUN
device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
return device
}
// BatchSize returns the BatchSize for the device as a whole which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
size := device.net.bind.BatchSize()
dSize := device.tun.device.BatchSize()
if size < dSize {
size = dSize
}
return size
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock()
defer device.peers.Unlock()
// stop peer and remove from routing
peer, ok := device.peers.keyMap[key]
if ok {
removePeerLocked(device, peer, key)
}
}
func (device *Device) RemoveAllPeers() {
device.peers.Lock()
defer device.peers.Unlock()
for key, peer := range device.peers.keyMap {
removePeerLocked(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
device.tun.device.Close()
device.downLocked()
// Remove peers before closing queues,
// because peers assume that queues are active.
device.RemoveAllPeers()
// We kept a reference to the encryption and decryption queues,
// in case we started any new peers that might write to them.
// No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done()
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.state.stopping.Wait()
device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed")
close(device.closed)
}
func (device *Device) Wait() chan struct{} {
return device.closed
}
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
if !device.isUp() {
return
}
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.keypairs.RLock()
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
peer.keypairs.RUnlock()
if sendKeepalive {
peer.SendKeepalive()
}
}
device.peers.RUnlock()
}
// closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
netc.netlinkCancel.Cancel()
}
if netc.bind != nil {
err = netc.bind.Close()
}
netc.stopping.Wait()
return err
}
func (device *Device) Bind() conn.Bind {
device.net.Lock()
defer device.net.Unlock()
return device.net.bind
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := closeBindLocked(device); err != nil {
return err
}
// open new sockets
if !device.isUp() {
return nil
}
// bind to new port
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
// start receiving routines
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := closeBindLocked(device)
device.net.Unlock()
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
}
isASecOn := false
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
)
}
} else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = 1
}
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = 2
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = 3
}
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = 4
}
isSameMap := map[uint32]bool{}
isSameMap[MessageInitiationType] = true
isSameMap[MessageResponseType] = true
isSameMap[MessageCookieReplyType] = true
isSameMap[MessageTransportType] = true
// size will be different if same values
if len(isSameMap) != 4 {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
)
}
}
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ; %w`,
newInitSize,
newResponseSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
)
}
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err
}

572
device/device_test.go Normal file
View file

@ -0,0 +1,572 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
// cfg is a series of alternating key/value strings.
// uapiCfg exists because editors and humans like to insert
// whitespace into configs, which can cause failures, some of which are silent.
// For example, a leading blank newline causes the remainder
// of the config to be silently ignored.
func uapiCfg(cfg ...string) string {
if len(cfg)%2 != 0 {
panic("odd number of args to uapiReader")
}
buf := new(bytes.Buffer)
for i, s := range cfg {
buf.WriteString(s)
sep := byte('\n')
if i%2 == 0 {
sep = '='
}
buf.WriteByte(sep)
}
return buf.String()
}
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
// A testPair is a pair of testPeers.
type testPair [2]testPeer
// A testPeer is a peer used for testing.
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
ip netip.Addr
}
type SendDirection bool
const (
Ping SendDirection = true
Pong SendDirection = false
)
func (d SendDirection) String() string {
if d == Ping {
return "ping"
}
return "pong"
}
func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
tb.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
// pong is the new ping
p0, p1 = p1, p0
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
var err error
select {
case msgRecv := <-p0.tun.Inbound:
if !bytes.Equal(msg, msgRecv) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
case <-timer.C:
err = fmt.Errorf("%s did not transit", ping)
case <-done:
}
if err != nil {
// The error may have occurred because the test is done.
select {
case <-done:
return
default:
}
// Real error.
tb.Error(err)
}
}
// genTestPair creates a testPair.
func genTestPair(
tb testing.TB,
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
} else {
binds = bindtest.NewChannelBinds()
}
// Bring up a ChannelTun for each config.
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
continue
}
if err := p.dev.Up(); err != nil {
tb.Errorf("failed to bring up device %d: %v", i, err)
p.dev.Close()
continue
}
endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
}
for i := range pair {
p := &pair[i]
if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
tb.Errorf("failed to configure device endpoint %d: %v", i, err)
p.dev.Close()
continue
}
// The device is ready. Close it when the test completes.
tb.Cleanup(p.dev.Close)
}
return
}
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestUpDown(t *testing.T) {
goroutineLeakCheck(t)
const itrials = 50
const otrials = 10
for n := 0; n < otrials; n++ {
pair := genTestPair(t, false, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
}
}
var wg sync.WaitGroup
wg.Add(len(pair))
for i := range pair {
go func(d *Device) {
defer wg.Done()
for i := 0; i < itrials; i++ {
if err := d.Up(); err != nil {
t.Errorf("failed up bring up device: %v", err)
}
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
if err := d.Down(); err != nil {
t.Errorf("failed to bring down device: %v", err)
}
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
}
}(pair[i].dev)
}
wg.Wait()
for i := range pair {
pair[i].dev.Up()
pair[i].dev.Close()
}
}
}
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true, false)
done := make(chan struct{})
const warmupIters = 10
var warmup sync.WaitGroup
warmup.Add(warmupIters)
go func() {
// Send data continuously back and forth until we're done.
// Note that we may continue to attempt to send data
// even after done is closed.
i := warmupIters
for ping := Ping; ; ping = !ping {
pair.Send(t, ping, done)
select {
case <-done:
return
default:
}
if i > 0 {
warmup.Done()
i--
}
}
}()
warmup.Wait()
applyCfg := func(cfg string) {
err := pair[0].dev.IpcSet(cfg)
if err != nil {
t.Fatal(err)
}
}
// Change persistent_keepalive_interval concurrently with tunnel use.
t.Run("persistentKeepaliveInterval", func(t *testing.T) {
var pub NoisePublicKey
for key := range pair[0].dev.peers.keyMap {
pub = key
break
}
cfg := uapiCfg(
"public_key", hex.EncodeToString(pub[:]),
"persistent_keepalive_interval", "1",
)
for i := 0; i < 1000; i++ {
applyCfg(cfg)
}
})
// Change private keys concurrently with tunnel use.
t.Run("privateKey", func(t *testing.T) {
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
// Set iters to a large number like 1000 to flush out data races quickly.
// Don't leave it large. That can cause logical races
// in which the handshake is interleaved with key changes
// such that the private key appears to be unchanging but
// other state gets reset, which can cause handshake failures like
// "Received packet with invalid mac1".
const iters = 1
for i := 0; i < iters; i++ {
applyCfg(bad)
applyCfg(good)
}
})
// Perform bind updates and keepalive sends concurrently with tunnel use.
t.Run("bindUpdate and keepalive", func(t *testing.T) {
const iters = 10
for i := 0; i < iters; i++ {
for _, peer := range pair {
peer.dev.BindUpdate()
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
}
})
close(done)
}
func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
}
}
func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var start time.Time
for {
<-pair[0].tun.Inbound
new := recv.Add(1)
if new == 1 {
start = time.Now()
}
// Careful! Don't change this to else if; b.N can be equal to 1.
if new == uint64(b.N) {
elapsed = time.Since(start)
return
}
}
}()
// Send packets as fast as we can until we've received enough.
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
wg.Wait()
b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}
func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
pair[0].dev.IpcGetOperation(io.Discard)
}
}
func goroutineLeakCheck(t *testing.T) {
goroutines := func() (int, []byte) {
p := pprof.Lookup("goroutine")
b := new(bytes.Buffer)
p.WriteTo(b, 1)
return p.Count(), b.Bytes()
}
startGoroutines, startStacks := goroutines()
t.Cleanup(func() {
if t.Failed() {
return
}
// Give goroutines time to exit, if they need it.
for i := 0; i < 10000; i++ {
if runtime.NumGoroutine() <= startGoroutines {
return
}
time.Sleep(1 * time.Millisecond)
}
endGoroutines, endStacks := goroutines()
t.Logf("starting stacks:\n%s\n", startStacks)
t.Logf("ending stacks:\n%s\n", endStacks)
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
type fakeBindSized struct {
size int
}
func (b *fakeBindSized) Open(
port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 1, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
}

View file

@ -0,0 +1,16 @@
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
package device
import "strconv"
const _deviceState_name = "DownUpClosed"
var _deviceState_index = [...]uint8{0, 4, 6, 12}
func (i deviceState) String() string {
if i >= deviceState(len(_deviceState_index)-1) {
return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}

49
device/endpoint_test.go Normal file
View file

@ -0,0 +1,49 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net/netip"
)
type DummyEndpoint struct {
src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
var src, dst [16]byte
if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
_, err := rand.Read(dst[:])
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
func (e *DummyEndpoint) DstToBytes() []byte {
out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
}
func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst
}
func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src
}

View file

@ -1,14 +1,14 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"sync" "sync"
"unsafe"
) )
type IndexTableEntry struct { type IndexTableEntry struct {
@ -18,31 +18,32 @@ type IndexTableEntry struct {
} }
type IndexTable struct { type IndexTable struct {
mutex sync.RWMutex sync.RWMutex
table map[uint32]IndexTableEntry table map[uint32]IndexTableEntry
} }
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
var integer [4]byte var integer [4]byte
_, err := rand.Read(integer[:]) _, err := rand.Read(integer[:])
return *(*uint32)(unsafe.Pointer(&integer[0])), err // Arbitrary endianness; both are intrinsified by the Go compiler.
return binary.LittleEndian.Uint32(integer[:]), err
} }
func (table *IndexTable) Init() { func (table *IndexTable) Init() {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
table.table = make(map[uint32]IndexTableEntry) table.table = make(map[uint32]IndexTableEntry)
} }
func (table *IndexTable) Delete(index uint32) { func (table *IndexTable) Delete(index uint32) {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
delete(table.table, index) delete(table.table, index)
} }
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
table.mutex.Lock() table.Lock()
defer table.mutex.Unlock() defer table.Unlock()
entry, ok := table.table[index] entry, ok := table.table[index]
if !ok { if !ok {
return return
@ -65,19 +66,19 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
// check if index used // check if index used
table.mutex.RLock() table.RLock()
_, ok := table.table[index] _, ok := table.table[index]
table.mutex.RUnlock() table.RUnlock()
if ok { if ok {
continue continue
} }
// check again while locked // check again while locked
table.mutex.Lock() table.Lock()
_, found := table.table[index] _, found := table.table[index]
if found { if found {
table.mutex.Unlock() table.Unlock()
continue continue
} }
table.table[index] = IndexTableEntry{ table.table[index] = IndexTableEntry{
@ -85,13 +86,13 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
handshake: handshake, handshake: handshake,
keypair: nil, keypair: nil,
} }
table.mutex.Unlock() table.Unlock()
return index, nil return index, nil
} }
} }
func (table *IndexTable) Lookup(id uint32) IndexTableEntry { func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
table.mutex.RLock() table.RLock()
defer table.mutex.RUnlock() defer table.RUnlock()
return table.table[id] return table.table[id]
} }

View file

@ -1,9 +1,9 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"net" "net"

69
device/junk_creator.go Normal file
View file

@ -0,0 +1,69 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

124
device/junk_creator_test.go Normal file
View file

@ -0,0 +1,124 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets()
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -1,14 +1,15 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"encoding/hex" "encoding/hex"
"golang.org/x/crypto/blake2s"
"testing" "testing"
"golang.org/x/crypto/blake2s"
) )
type KDFTest struct { type KDFTest struct {
@ -19,7 +20,7 @@ type KDFTest struct {
t2 string t2 string
} }
func assertEquals(t *testing.T, a string, b string) { func assertEquals(t *testing.T, a, b string) {
if a != b { if a != b {
t.Fatal("expected", a, "=", b) t.Fatal("expected", a, "=", b)
} }

View file

@ -1,15 +1,17 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"crypto/cipher" "crypto/cipher"
"git.zx2c4.com/wireguard-go/replay"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/replay"
) )
/* Due to limitations in Go and /x/crypto there is currently /* Due to limitations in Go and /x/crypto there is currently
@ -20,10 +22,10 @@ import (
*/ */
type Keypair struct { type Keypair struct {
sendNonce uint64 sendNonce atomic.Uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter replay.ReplayFilter replayFilter replay.Filter
isInitiator bool isInitiator bool
created time.Time created time.Time
localIndex uint32 localIndex uint32
@ -31,15 +33,15 @@ type Keypair struct {
} }
type Keypairs struct { type Keypairs struct {
mutex sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next *Keypair next atomic.Pointer[Keypair]
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {
kp.mutex.RLock() kp.RLock()
defer kp.mutex.RUnlock() defer kp.RUnlock()
return kp.current return kp.current
} }

48
device/logger.go Normal file
View file

@ -0,0 +1,48 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"log"
"os"
)
// A Logger provides logging for a Device.
// The functions are Printf-style functions.
// They must be safe for concurrent use.
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
Verbosef func(format string, args ...any)
Errorf func(format string, args ...any)
}
// Log levels for use with NewLogger.
const (
LogLevelSilent = iota
LogLevelError
LogLevelVerbose
)
// Function for use in Logger for discarding logged lines.
func DiscardLogf(format string, args ...any) {}
// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf}
logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
}
if level >= LogLevelVerbose {
logger.Verbosef = logf("DEBUG")
}
if level >= LogLevelError {
logger.Errorf = logf("ERROR")
}
return logger
}

19
device/mobilequirks.go Normal file
View file

@ -0,0 +1,19 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to deal with it, and race maybe, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.endpoint.Unlock()
}
device.peers.RUnlock()
}

View file

@ -1,17 +1,19 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"errors"
"hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"hash"
) )
/* KDF related functions. /* KDF related functions.
@ -41,7 +43,6 @@ func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, key, input) HMAC1(t0, key, input)
HMAC1(t0, t0[:], []byte{0x1}) HMAC1(t0, t0[:], []byte{0x1})
return
} }
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
@ -50,7 +51,6 @@ func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, prk[:], []byte{0x1}) HMAC1(t0, prk[:], []byte{0x1})
HMAC2(t1, prk[:], t0[:], []byte{0x2}) HMAC2(t1, prk[:], t0[:], []byte{0x2})
setZero(prk[:]) setZero(prk[:])
return
} }
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
@ -60,7 +60,6 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
HMAC2(t1, prk[:], t0[:], []byte{0x2}) HMAC2(t1, prk[:], t0[:], []byte{0x2})
HMAC2(t2, prk[:], t1[:], []byte{0x3}) HMAC2(t2, prk[:], t1[:], []byte{0x3})
setZero(prk[:]) setZero(prk[:])
return
} }
func isZero(val []byte) bool { func isZero(val []byte) bool {
@ -78,12 +77,14 @@ func setZero(arr []byte) {
} }
} }
func newPrivateKey() (sk NoisePrivateKey, err error) { func (sk *NoisePrivateKey) clamp() {
// clamping: https://cr.yp.to/ecdh.html
_, err = rand.Read(sk[:])
sk[0] &= 248 sk[0] &= 248
sk[31] &= 127 sk[31] = (sk[31] & 127) | 64
sk[31] |= 64 }
func newPrivateKey() (sk NoisePrivateKey, err error) {
_, err = rand.Read(sk[:])
sk.clamp()
return return
} }
@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return return
} }
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { var errInvalidPublicKey = errors.New("invalid public key")
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk) curve25519.ScalarMult(&ss, ask, apk)
return ss if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
} }

View file

@ -1,28 +1,50 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"errors" "errors"
"git.zx2c4.com/wireguard-go/tai64n" "fmt"
"sync"
"time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"sync"
"time" "github.com/amnezia-vpn/amneziawg-go/tai64n"
) )
type handshakeState int
const ( const (
HandshakeZeroed = iota handshakeZeroed = handshakeState(iota)
HandshakeInitiationCreated handshakeInitiationCreated
HandshakeInitiationConsumed handshakeInitiationConsumed
HandshakeResponseCreated handshakeResponseCreated
HandshakeResponseConsumed handshakeResponseConsumed
) )
func (hs handshakeState) String() string {
switch hs {
case handshakeZeroed:
return "handshakeZeroed"
case handshakeInitiationCreated:
return "handshakeInitiationCreated"
case handshakeInitiationConsumed:
return "handshakeInitiationConsumed"
case handshakeResponseCreated:
return "handshakeResponseCreated"
case handshakeResponseConsumed:
return "handshakeResponseConsumed"
default:
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
}
const ( const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@ -30,21 +52,21 @@ const (
WGLabelCookie = "cookie--" WGLabelCookie = "cookie--"
) )
const ( var (
MessageInitiationType = 1 MessageInitiationType uint32 = 1
MessageResponseType = 2 MessageResponseType uint32 = 2
MessageCookieReplyType = 3 MessageCookieReplyType uint32 = 3
MessageTransportType = 4 MessageTransportType uint32 = 4
) )
const ( const (
MessageInitiationSize = 148 // size of handshake initation message MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
) )
const ( const (
@ -53,6 +75,10 @@ const (
MessageTransportOffsetContent = 16 MessageTransportOffsetContent = 16
) )
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
/* Type is an 8-bit field, followed by 3 nul bytes, /* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder * by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now) * we can treat these as a 32-bit unsigned int (for now)
@ -89,16 +115,16 @@ type MessageTransport struct {
type MessageCookieReply struct { type MessageCookieReply struct {
Type uint32 Type uint32
Receiver uint32 Receiver uint32
Nonce [24]byte Nonce [chacha20poly1305.NonceSizeX]byte
Cookie [blake2s.Size128 + poly1305.TagSize]byte Cookie [blake2s.Size128 + poly1305.TagSize]byte
} }
type Handshake struct { type Handshake struct {
state int state handshakeState
mutex sync.RWMutex mutex sync.RWMutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk presharedKey NoisePresharedKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending remoteIndex uint32 // index for sending
@ -116,11 +142,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data) KDF1(dst, c[:], data)
} }
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hash.Write(h[:]) hash.Write(h[:])
hash.Write(data) hash.Write(data)
@ -134,7 +160,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:]) setZero(h.chainKey[:])
setZero(h.hash[:]) setZero(h.hash[:])
h.localIndex = 0 h.localIndex = 0
h.state = HandshakeZeroed h.state = handshakeZeroed
} }
func (h *Handshake) mixHash(data []byte) { func (h *Handshake) mixHash(data []byte) {
@ -153,20 +179,14 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.staticIdentity.RLock()
device.staticIdentity.mutex.RLock() defer device.staticIdentity.RUnlock()
defer device.staticIdentity.mutex.RUnlock()
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
@ -175,59 +195,58 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
} }
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
func() { if err != nil {
var key [chacha20poly1305.KeySize]byte return nil, err
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
KDF2( var key [chacha20poly1305.KeySize]byte
&handshake.chainKey, KDF2(
&key, &handshake.chainKey,
handshake.chainKey[:], &key,
ss[:], handshake.chainKey[:],
) ss[:],
aead, _ := chacha20poly1305.New(key[:]) )
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) aead, _ := chacha20poly1305.New(key[:])
}() aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
timestamp := tai64n.Now() timestamp := tai64n.Now()
func() { aead, _ = chacha20poly1305.New(key[:])
var key [chacha20poly1305.KeySize]byte aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
KDF2(
&handshake.chainKey, // assign index
&key, device.indexTable.Delete(handshake.localIndex)
handshake.chainKey[:], msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
handshake.precomputedStaticStatic[:], if err != nil {
) return nil, err
aead, _ := chacha20poly1305.New(key[:]) }
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) handshake.localIndex = msg.Sender
}()
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = handshakeInitiationCreated
return &msg, nil return &msg, nil
} }
@ -237,28 +256,30 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.aSecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:]) mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() { var key [chacha20poly1305.KeySize]byte
var key [chacha20poly1305.KeySize]byte ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil {
KDF2(&chainKey, &key, chainKey[:], ss[:]) return nil
aead, _ := chacha20poly1305.New(key[:]) }
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
}() aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil { if err != nil {
return nil return nil
} }
@ -267,28 +288,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer // lookup peer
peer := device.LookupPeer(peerPK) peer := device.LookupPeer(peerPK)
if peer == nil { if peer == nil || !peer.isRunning.Load() {
return nil return nil
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
var timestamp tai64n.Timestamp var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2( KDF2(
&chainKey, &chainKey,
&key, &key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil { if err != nil {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
@ -298,11 +320,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood // protect against replay & flood
var ok bool replay := !timestamp.After(handshake.lastTimestamp)
ok = timestamp.After(handshake.lastTimestamp) flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
if !ok { if replay {
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
return nil
}
if flood {
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil return nil
} }
@ -314,9 +340,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp if timestamp.After(handshake.lastTimestamp) {
handshake.lastInitiationConsumption = time.Now() handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitiationConsumed }
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -331,7 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed { if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first") return nil, errors.New("handshake initiation must be consumed first")
} }
@ -345,7 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.aSecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -359,12 +392,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) if err != nil {
handshake.mixKey(ss[:]) return nil, err
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
// add preshared key // add preshared key
@ -381,21 +418,22 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
func() { aead, _ := chacha20poly1305.New(key[:])
aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.mixHash(msg.Empty[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = HandshakeResponseCreated handshake.state = handshakeResponseCreated
return &msg, nil return &msg, nil
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver
@ -411,37 +449,38 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
) )
ok := func() bool { ok := func() bool {
// lock handshake state // lock handshake state
handshake.mutex.RLock() handshake.mutex.RLock()
defer handshake.mutex.RUnlock() defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated { if handshake.state != handshakeInitiationCreated {
return false return false
} }
// lock private key for reading // lock private key for reading
device.staticIdentity.mutex.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.mutex.RUnlock() defer device.staticIdentity.RUnlock()
// finish 3-way DH // finish 3-way DH
mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() { ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
// add preshared key (psk) // add preshared key (psk)
@ -459,7 +498,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript // authenticate transcript
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return false return false
} }
@ -478,7 +517,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash handshake.hash = hash
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -503,7 +542,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed { if handshake.state == handshakeResponseConsumed {
KDF2( KDF2(
&sendKey, &sendKey,
&recvKey, &recvKey,
@ -511,7 +550,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil, nil,
) )
isInitiator = true isInitiator = true
} else if handshake.state == HandshakeResponseCreated { } else if handshake.state == handshakeResponseCreated {
KDF2( KDF2(
&recvKey, &recvKey,
&sendKey, &sendKey,
@ -520,7 +559,7 @@ func (peer *Peer) BeginSymmetricSession() error {
) )
isInitiator = false isInitiator = false
} else { } else {
return errors.New("invalid state for keypair derivation") return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
} }
// zero handshake // zero handshake
@ -528,7 +567,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:]) setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:]) setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed peer.handshake.state = handshakeZeroed
// create AEAD instances // create AEAD instances
@ -540,8 +579,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:]) setZero(recvKey[:])
keypair.created = time.Now() keypair.created = time.Now()
keypair.sendNonce = 0 keypair.replayFilter.Reset()
keypair.replayFilter.Init()
keypair.isInitiator = isInitiator keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex keypair.remoteIndex = peer.handshake.remoteIndex
@ -554,16 +592,16 @@ func (peer *Peer) BeginSymmetricSession() error {
// rotate key pairs // rotate key pairs
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.mutex.Lock() keypairs.Lock()
defer keypairs.mutex.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.next next := keypairs.next.Load()
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.next = nil keypairs.next.Store(nil)
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@ -572,7 +610,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.next = keypair keypairs.next.Store(keypair)
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@ -583,18 +621,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
if keypairs.next.Load() != receivedKeypair {
return false return false
} }
keypairs.mutex.Lock() keypairs.Lock()
defer keypairs.mutex.Unlock() defer keypairs.Unlock()
if keypairs.next != receivedKeypair { if keypairs.next.Load() != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next keypairs.current = keypairs.next.Load()
keypairs.next = nil keypairs.next.Store(nil)
return true return true
} }

View file

@ -1,26 +1,26 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305"
) )
const ( const (
NoisePublicKeySize = 32 NoisePublicKeySize = 32
NoisePrivateKeySize = 32 NoisePrivateKeySize = 32
NoisePresharedKeySize = 32
) )
type ( type (
NoisePublicKey [NoisePublicKeySize]byte NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [chacha20poly1305.KeySize]byte NoisePresharedKey [NoisePresharedKeySize]byte
NoiseNonce uint64 // padded to 12-bytes NoiseNonce uint64 // padded to 12-bytes
) )
@ -45,22 +45,25 @@ func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
} }
func (key *NoisePrivateKey) FromHex(src string) error { func (key *NoisePrivateKey) FromHex(src string) (err error) {
return loadExactHex(key[:], src) err = loadExactHex(key[:], src)
key.clamp()
return
} }
func (key NoisePrivateKey) ToHex() string { func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
return hex.EncodeToString(key[:]) err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
} }
func (key *NoisePublicKey) FromHex(src string) error { func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }
func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key NoisePublicKey) IsZero() bool { func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey var zero NoisePublicKey
return key.Equals(zero) return key.Equals(zero)
@ -70,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
} }
func (key *NoiseSymmetricKey) FromHex(src string) error { func (key *NoisePresharedKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }
func (key NoiseSymmetricKey) ToHex() string {
return hex.EncodeToString(key[:])
}

View file

@ -1,14 +1,17 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"testing" "testing"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
func TestCurveWrappers(t *testing.T) { func TestCurveWrappers(t *testing.T) {
@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey() pk1 := sk1.publicKey()
pk2 := sk2.publicKey() pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2) ss1, err1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1) ss2, err2 := sk2.sharedSecret(pk1)
if ss1 != ss2 { if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet") t.Fatal("Failed to compute shared secet")
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := tuntest.NewChannelTUN()
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
device.SetPrivateKey(sk)
return device
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
}
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close() defer dev1.Close()
defer dev2.Close() defer dev2.Close()
peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) if err != nil {
t.Fatal(err)
}
peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
if err != nil {
t.Fatal(err)
}
peer1.Start()
peer2.Start()
assertEqual( assertEqual(
t, t,
@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.next key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test

296
device/peer.go Normal file
View file

@ -0,0 +1,296 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type Peer struct {
isRunning atomic.Bool
keypairs Keypairs
handshake Handshake
device *Device
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
endpoint struct {
sync.Mutex
val conn.Endpoint
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
disableRoaming bool
}
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
handshakeAttempts atomic.Uint32
needAnotherKeepalive atomic.Bool
sentLastMinuteHandshake atomic.Bool
}
state struct {
sync.Mutex // protects against concurrent Start/Stop
}
queue struct {
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
trieEntries list.List
persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed() {
return nil, errors.New("device closed")
}
// lock resources
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
device.peers.Lock()
defer device.peers.Unlock()
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
peer := new(Peer)
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers
peer.timersInit()
// add
device.peers.keyMap[pk] = peer
return peer, nil
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
if peer.device.isClosed() {
return nil
}
peer.endpoint.Lock()
endpoint := peer.endpoint.val
if endpoint == nil {
peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
}
return err
}
func (peer *Peer) String() string {
// The awful goo that follows is identical to:
//
// base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
// abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
//
// except that it is considerably more efficient.
src := peer.handshake.remoteStatic
b64 := func(input byte) byte {
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
}
b := []byte("peer(____…____)")
const first = len("peer(")
const second = len("peer(____…")
b[first+0] = b64((src[0] >> 2) & 63)
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
b[first+3] = b64(src[2] & 63)
b[second+0] = b64(src[29] & 63)
b[second+1] = b64((src[30] >> 2) & 63)
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
b[second+3] = b64((src[31] << 2) & 63)
return string(b)
}
func (peer *Peer) Start() {
// should never start a peer on a closed device
if peer.device.isClosed() {
return
}
// prevent simultaneous start/stop operations
peer.state.Lock()
defer peer.state.Unlock()
if peer.isRunning.Load() {
return
}
device := peer.device
device.log.Verbosef("%v - Starting", peer)
// reset routine state
peer.stopping.Wait()
peer.stopping.Add(2)
peer.handshake.mutex.Lock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
peer.handshake.mutex.Unlock()
peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
peer.timersStart()
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
// Use the device batch size, not the bind batch size, as the device size is
// the size of the batch pools.
batchSize := peer.device.BatchSize()
go peer.RoutineSequentialSender(batchSize)
go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
device := peer.device
// clear key pairs
keypairs := &peer.keypairs
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
handshake := &peer.handshake
handshake.mutex.Lock()
device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
handshake.mutex.Unlock()
peer.FlushStagedPackets()
}
func (peer *Peer) ExpireCurrentKeypairs() {
handshake := &peer.handshake
handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
handshake.mutex.Unlock()
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
keypairs.current.sendNonce.Store(RejectAfterMessages)
}
if next := keypairs.next.Load(); next != nil {
next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
func (peer *Peer) Stop() {
peer.state.Lock()
defer peer.state.Unlock()
if !peer.isRunning.Swap(false) {
return
}
peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop()
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
peer.queue.inbound.c <- nil
peer.queue.outbound.c <- nil
peer.stopping.Wait()
peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
peer.ZeroAndFlushAll()
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return
}
peer.endpoint.clearSrcOnTx = false
peer.endpoint.val = endpoint
}
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
}

120
device/pools.go Normal file
View file

@ -0,0 +1,120 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync"
"sync/atomic"
)
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count atomic.Uint32
max uint32
}
func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count.Load() >= p.max {
p.cond.Wait()
}
p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
}
func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
p.count.Add(^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement)
})
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement)
})
}
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}
func (device *Device) GetInboundElement() *QueueInboundElement {
return device.pool.inboundElements.Get().(*QueueInboundElement)
}
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
elem.clearPointers()
device.pool.inboundElements.Put(elem)
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
elem.clearPointers()
device.pool.outboundElements.Put(elem)
}

139
device/pools_test.go Normal file
View file

@ -0,0 +1,139 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
var trials atomic.Int32
startTrials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
startTrials /= 10
}
trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
var max atomic.Uint32
updateMax := func() {
count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
old := max.Load()
if count <= old {
break
}
if max.CompareAndSwap(old, count) {
break
}
}
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
updateMax()
x := p.Get()
updateMax()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
updateMax()
p.Put(x)
updateMax()
}
}()
}
wg.Wait()
if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkWaitPoolEmpty(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkSyncPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := sync.Pool{New: func() any { return make([]byte, 16) }}
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}

View file

@ -0,0 +1,19 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */
const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)

View file

@ -1,13 +1,16 @@
/* SPDX-License-Identifier: GPL-2.0 //go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package device
/* Implementation specific constants */ import "github.com/amnezia-vpn/amneziawg-go/conn"
const ( const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024

View file

@ -0,0 +1,21 @@
//go:build ios
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
PreallocatedBuffersPerPool uint32 = 1024
)
const MaxSegmentSize = 1700

View file

@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)

View file

@ -0,0 +1,10 @@
//go:build !race
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const raceEnabled = false

View file

@ -0,0 +1,10 @@
//go:build race
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const raceEnabled = true

577
device/receive.go Normal file
View file

@ -0,0 +1,577 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"errors"
"net"
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
endpoint conn.Endpoint
buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keypair *Keypair
endpoint conn.Endpoint
}
type QueueInboundElementsContainer struct {
sync.Mutex
elems []*QueueInboundElement
}
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueInboundElement) clearPointers() {
elem.buffer = nil
elem.packet = nil
elem.keypair = nil
elem.endpoint = nil
}
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(
maxBatchSize int,
recv conn.ReceiveFunc,
) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed
var (
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
bufs = make([][]byte, maxBatchSize)
err error
sizes = make([]int, maxBatchSize)
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if bufsArrs[i] != nil {
device.PutMessageBuffer(bufsArrs[i])
}
}
}()
for {
count, err = recv(bufs, sizes, endpoints)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return
}
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
continue
}
return
}
deathSpiral = 0
device.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
continue
}
// check size of packet
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("Transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType {
device.log.Verbosef("ASec: Received message with unknown type")
continue
}
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
}
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = bufsArrs[i]
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}
default:
device.log.Verbosef("Received message with unknown type")
continue
}
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
buffer: bufsArrs[i],
packet: packet,
endpoint: endpoints[i],
}:
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
default:
}
}
device.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
device.queue.decryption.c <- elemsContainer
} else {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
}
}
func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
for elemsContainer := range device.queue.decryption.c {
for _, elem := range elemsContainer.elems {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
}
}
elemsContainer.Unlock()
}
}
/* Handles incoming packets related to handshake
*/
func (device *Device) RoutineHandshake(id int) {
defer func() {
device.log.Verbosef("Routine: handshake worker %d - stopped", id)
device.queue.encryption.wg.Done()
}()
device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
// handle cookie fields and ratelimiting
switch elem.msgType {
case MessageCookieReplyType:
// unmarshal packet
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
device.log.Verbosef("Failed to decode cookie reply")
goto skip
}
// lookup peer from index
entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
goto skip
}
// consume reply
if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef(
"Receiving cookie response from %s",
elem.endpoint.DstToString(),
)
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef(
"Could not decrypt invalid cookie response",
)
}
}
goto skip
case MessageInitiationType, MessageResponseType:
// check mac fields and maybe ratelimit
if !device.cookieChecker.CheckMAC1(elem.packet) {
device.log.Verbosef("Received packet with invalid mac1")
goto skip
}
// endpoints destination address is the source of the datagram
if device.IsUnderLoad() {
// verify MAC2 field
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
goto skip
}
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
goto skip
}
}
default:
device.log.Errorf("Invalid packet ended up in the handshake queue")
goto skip
}
// handle handshake initiation/response content
switch elem.msgType {
case MessageInitiationType:
// unmarshal
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode initiation message")
goto skip
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
goto skip
}
// update timers
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
case MessageResponseType:
// unmarshal
var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode response message")
goto skip
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
goto skip
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// derive keypair
err = peer.BeginSymmetricSession()
if err != nil {
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
goto skip
}
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
}
skip:
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
peer.stopping.Done()
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
bufs := make([][]byte, 0, maxBatchSize)
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
return
}
elemsContainer.Lock()
validTailPacket := -1
dataPacketReceived := false
rxBytesLen := uint64(0)
for i, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
}
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
validTailPacket = i
if peer.ReceivedWithKeypair(elem.keypair) {
peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
}
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
}
default:
device.log.Verbosef(
"Packet with invalid IP version from %v",
peer,
)
continue
}
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
}
peer.rxBytes.Add(rxBytesLen)
if validTailPacket >= 0 {
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
}
if dataPacketReceived {
peer.timersDataReceived()
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
device.PutInboundElementsContainer(elemsContainer)
}
}

608
device/send.go Normal file
View file

@ -0,0 +1,608 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"errors"
"net"
"os"
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
/* Outbound flow
*
* 1. TUN queue
* 2. Routing (sequential)
* 3. Nonce assignment (sequential)
* 4. Encryption (parallel)
* 5. Transmission (sequential)
*
* The functions in this file occur (roughly) in the order in
* which the packets are processed.
*
* Locking, Producers and Consumers
*
* The order of packets (per peer) must be maintained,
* but encryption of packets happen out-of-order:
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work (encryption) on the packet.
*
* If the element is inserted into the "encryption queue",
* the content is preceded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place)
*/
type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
keypair *Keypair // keypair for encryption
peer *Peer // related peer
}
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
}
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueOutboundElement) clearPointers() {
elem.buffer = nil
elem.packet = nil
elem.keypair = nil
elem.peer = nil
}
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
elemsContainer := peer.device.GetOutboundElementsContainer()
elemsContainer.elems = append(elemsContainer.elems, elem)
select {
case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
}
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.RUnlock()
return nil
}
peer.handshake.mutex.RUnlock()
peer.handshake.mutex.Lock()
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.Unlock()
return nil
}
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
peer.timersHandshakeInitiated()
return err
}
func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.mutex.Lock()
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Verbosef("%v - Sending handshake response", peer)
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession()
if err != nil {
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
return err
}
peer.timersSessionDerived()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
)
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
func (peer *Peer) keepKeyFreshSending() {
keypair := peer.keypairs.Current()
if keypair == nil {
return
}
nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
device.state.stopping.Done()
device.queue.encryption.wg.Done()
}()
device.log.Verbosef("Routine: TUN reader - started")
var (
batchSize = device.BatchSize()
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
)
for i := range elems {
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
defer func() {
for _, elem := range elems {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
}
}()
for {
// read packets
count, readErr = device.tun.device.Read(bufs, sizes, offset)
for i := 0; i < count; i++ {
if sizes[i] < 1 {
continue
}
elem := elems[i]
elem.packet = bufs[i][offset : offset+sizes[i]]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
for peer, elemsForPeer := range elemsByPeer {
if peer.isRunning.Load() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
if readErr != nil {
if errors.Is(readErr, tun.ErrTooManySegments) {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
continue
}
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
}
go device.Close()
}
return
}
}
}
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
}
func (peer *Peer) SendStagedPackets() {
top:
if len(peer.queue.staged) == 0 || !peer.device.isUp() {
return
}
keypair := peer.keypairs.Current()
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
for {
var elemsContainerOOO *QueueOutboundElementsContainer
select {
case elemsContainer := <-peer.queue.staged:
i := 0
for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
}
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(elemsContainer.elems) == 0 {
peer.device.PutOutboundElementsContainer(elemsContainer)
goto top
}
// add to parallel and sequential queue
if peer.isRunning.Load() {
peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elemsContainer
} else {
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsContainerOOO != nil {
goto top
}
default:
return
}
}
}
func (peer *Peer) FlushStagedPackets() {
for {
select {
case elemsContainer := <-peer.queue.staged:
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
}
}
func calculatePaddingSize(packetSize, mtu int) int {
lastUnit := packetSize
if mtu == 0 {
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
}
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
if paddedSize > mtu {
paddedSize = mtu
}
return paddedSize - lastUnit
}
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
* Obs. One instance per core
*/
func (device *Device) RoutineEncryption(id int) {
var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
}
elemsContainer.Unlock()
}
}
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
peer.stopping.Done()
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
bufs := make([][]byte, 0, maxBatchSize)
for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
if elemsContainer == nil {
return
}
if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
// The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
continue
}
dataSent := false
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
bufs = append(bufs, elem.packet)
}
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendBuffers(bufs)
if dataSent {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
if err != nil {
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
}
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
peer.keepKeyFreshSending()
}
}

12
device/sticky_default.go Normal file
View file

@ -0,0 +1,12 @@
//go:build !linux
package device
import (
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}

224
device/sticky_linux.go Normal file
View file

@ -0,0 +1,224 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
package device
import (
"sync"
"unsafe"
"golang.org/x/sys/unix"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if !conn.StdNetSupportsStickySockets {
return nil, nil
}
if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
}
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
if err != nil {
unix.Close(netlinkSock)
return nil, err
}
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer netlinkCancel.Close()
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.endpoint.Unlock()
break
}
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.endpoint.Unlock()
break
}
pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
if peer.endpoint.val == nil {
peer.endpoint.Unlock()
continue
}
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.endpoint.Unlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.endpoint.Unlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
device.net.fwmark,
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}

View file

@ -1,25 +1,25 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* *
* This is based heavily on timers.c from the kernel implementation. * This is based heavily on timers.c from the kernel implementation.
*/ */
package main package device
import ( import (
"math/rand"
"sync" "sync"
"sync/atomic"
"time" "time"
_ "unsafe"
) )
/* This Timer structure and related functions should roughly copy the interface of //go:linkname fastrandn runtime.fastrandn
* the Linux kernel's struct timer_list. func fastrandn(n uint32) uint32
*/
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct { type Timer struct {
timer *time.Timer *time.Timer
modifyingLock sync.RWMutex modifyingLock sync.RWMutex
runningLock sync.Mutex runningLock sync.Mutex
isPending bool isPending bool
@ -27,36 +27,35 @@ type Timer struct {
func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{} timer := &Timer{}
timer.timer = time.AfterFunc(time.Hour, func() { timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock() timer.runningLock.Lock()
defer timer.runningLock.Unlock()
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
if !timer.isPending { if !timer.isPending {
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
timer.runningLock.Unlock()
return return
} }
timer.isPending = false timer.isPending = false
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
expirationFunction(peer) expirationFunction(peer)
timer.runningLock.Unlock()
}) })
timer.timer.Stop() timer.Stop()
return timer return timer
} }
func (timer *Timer) Mod(d time.Duration) { func (timer *Timer) Mod(d time.Duration) {
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
timer.isPending = true timer.isPending = true
timer.timer.Reset(d) timer.Reset(d)
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
} }
func (timer *Timer) Del() { func (timer *Timer) Del() {
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
timer.isPending = false timer.isPending = false
timer.timer.Stop() timer.Stop()
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
} }
@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool {
} }
func (peer *Peer) timersActive() bool { func (peer *Peer) timersActive() bool {
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
} }
func expiredRetransmitHandshake(peer *Peer) { func expiredRetransmitHandshake(peer *Peer) {
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Del() peer.timers.sendKeepalive.Del()
@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) {
/* We drop all packets without a keypair and don't try again, /* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake. * if we try unsuccessfully for too long to make a handshake.
*/ */
peer.FlushNonceQueue() peer.FlushStagedPackets()
/* We set a timer for destroying any residue that might be left /* We set a timer for destroying any residue that might be left
* of a partial exchange. * of a partial exchange.
@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
} }
} else { } else {
atomic.AddUint32(&peer.timers.handshakeAttempts, 1) peer.timers.handshakeAttempts.Add(1)
peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
peer.SendHandshakeInitiation(true) peer.SendHandshakeInitiation(true)
} }
@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) { func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive() peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Get() { if peer.timers.needAnotherKeepalive.Load() {
peer.timers.needAnotherKeepalive.Set(false) peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} }
@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) {
} }
func expiredNewHandshake(peer *Peer) { func expiredNewHandshake(peer *Peer) {
peer.device.log.Debug.Printf("%s: Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Debug.Printf("%s: Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
func expiredPersistentKeepalive(peer *Peer) { func expiredPersistentKeepalive(peer *Peer) {
if peer.persistentKeepaliveInterval > 0 { if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() { if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else { } else {
peer.timers.needAnotherKeepalive.Set(true) peer.timers.needAnotherKeepalive.Store(true)
} }
} }
} }
@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */ /* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() { func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Del() peer.timers.retransmitHandshake.Del()
} }
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Set(false) peer.timers.sentLastMinuteHandshake.Store(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) peer.lastHandshakeNano.Store(time.Now().UnixNano())
} }
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { keepalive := peer.persistentKeepaliveInterval.Load()
peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
} }
} }
@ -213,9 +204,12 @@ func (peer *Peer) timersInit() {
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) }
peer.timers.sentLastMinuteHandshake.Set(false)
peer.timers.needAnotherKeepalive.Set(false) func (peer *Peer) timersStart() {
peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Store(false)
peer.timers.needAnotherKeepalive.Store(false)
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {

53
device/tun.go Normal file
View file

@ -0,0 +1,53 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"fmt"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
const DefaultMTU = 1420
func (device *Device) RoutineTUNEventReader() {
device.log.Verbosef("Routine: event worker - started")
for event := range device.tun.device.Events() {
if event&tun.EventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU()
if err != nil {
device.log.Errorf("Failed to load updated MTU of device: %v", err)
continue
}
if mtu < 0 {
device.log.Errorf("MTU not updated to negative value: %v", mtu)
continue
}
var tooLarge string
if mtu > MaxContentSize {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
}
if event&tun.EventUp != 0 {
device.log.Verbosef("Interface up requested")
device.Up()
}
if event&tun.EventDown != 0 {
device.log.Verbosef("Interface down requested")
device.Down()
}
}
device.log.Verbosef("Routine: event worker - stopped")
}

583
device/uapi.go Normal file
View file

@ -0,0 +1,583 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
type IPCError struct {
code int64 // error code
err error // underlying/wrapped error
}
func (s IPCError) Error() string {
return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
}
func (s IPCError) Unwrap() error {
return s.err
}
func (s IPCError) ErrorCode() int64 {
return s.code
}
func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
var byteBufferPool = &sync.Pool{
New: func() any { return new(bytes.Buffer) },
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcGetOperation(w io.Writer) error {
device.ipcMutex.RLock()
defer device.ipcMutex.RUnlock()
buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer byteBufferPool.Put(buf)
sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n')
}
keyf := func(prefix string, key *[32]byte) {
buf.Grow(len(key)*2 + 2 + len(prefix))
buf.WriteString(prefix)
buf.WriteByte('=')
const hex = "0123456789abcdef"
for i := 0; i < len(key); i++ {
buf.WriteByte(hex[key[i]>>4])
buf.WriteByte(hex[key[i]&0xf])
}
buf.WriteByte('\n')
}
func() {
// lock required resources
device.net.RLock()
defer device.net.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
device.peers.RLock()
defer device.peers.RUnlock()
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
}
if device.net.port != 0 {
sendf("listen_port=%d", device.net.port)
}
if device.net.fwmark != 0 {
sendf("fwmark=%d", device.net.fwmark)
}
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
for _, peer := range device.peers.keyMap {
// Serialize peer state.
peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}
}()
// send lines (does not require resource locks)
if _, err := w.Write(buf.Bytes()); err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
}
return nil
}
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
defer func() {
if err != nil {
device.log.Errorf("%v", err)
}
}()
peer := new(ipcSetPeer)
deviceConfig := true
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
return nil
}
key, value, ok := strings.Cut(line, "=")
if !ok {
return ipcErrorf(
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
}
if key == "public_key" {
if deviceConfig {
deviceConfig = false
}
peer.handlePostConfig()
// Load/create the peer we are now configuring.
err := device.handlePublicKeyLine(peer, value)
if err != nil {
return err
}
continue
}
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg)
} else {
err = device.handlePeerLine(peer, key, value)
}
if err != nil {
return err
}
}
err = device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
if err := scanner.Err(); err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
}
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
}
device.log.Verbosef("UAPI: Updating private key")
device.SetPrivateKey(sk)
case "listen_port":
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
}
// update port and rebind
device.log.Verbosef("UAPI: Updating listen port")
device.net.Lock()
device.net.port = uint16(port)
device.net.Unlock()
if err := device.BindUpdate(); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
}
case "fwmark":
mark, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
}
device.log.Verbosef("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(mark)); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
}
case "replace_peers":
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
return nil
}
// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
}
func (peer *ipcSetPeer) handlePostConfig() {
if peer.Peer == nil || peer.dummy {
return
}
if peer.created {
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
if peer.pkaOn {
peer.SendKeepalive()
}
peer.SendStagedPackets()
}
}
func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
}
// Ignore peer with the same public key as this device.
device.staticIdentity.RLock()
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock()
if peer.dummy {
peer.Peer = &Peer{}
} else {
peer.Peer = device.LookupPeer(publicKey)
}
peer.created = peer.Peer == nil
if peer.created {
peer.Peer, err = device.NewPeer(publicKey)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
}
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
}
return nil
}
func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
switch key {
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer.Peer = &Peer{}
peer.dummy = true
}
case "remove":
// remove currently selected peer from device
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
}
if !peer.dummy {
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
device.RemovePeer(peer.handshake.remoteStatic)
}
peer.Peer = &Peer{}
peer.dummy = true
case "preshared_key":
device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.mutex.Unlock()
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
}
case "endpoint":
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on.
peer.pkaOn = old == 0 && secs != 0
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
}
if peer.dummy {
return nil
}
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
return nil
}
func (device *Device) IpcGet() (string, error) {
buf := new(strings.Builder)
if err := device.IpcGetOperation(buf); err != nil {
return "", err
}
return buf.String(), nil
}
func (device *Device) IpcSet(uapiConf string) error {
return device.IpcSetOperation(strings.NewReader(uapiConf))
}
func (device *Device) IpcHandle(socket net.Conn) {
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s)
writer := bufio.NewWriter(s)
return bufio.NewReadWriter(reader, writer)
}(socket)
for {
op, err := buffered.ReadString('\n')
if err != nil {
return
}
// handle operation
switch op {
case "set=1\n":
err = device.IpcSetOperation(buffered.Reader)
case "get=1\n":
var nextByte byte
nextByte, err = buffered.ReadByte()
if err != nil {
return
}
if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
break
}
err = device.IpcGetOperation(buffered.Writer)
default:
device.log.Errorf("invalid UAPI operation: %v", op)
return
}
// write status
var status *IPCError
if err != nil && !errors.As(err, &status) {
// shouldn't happen
status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
}
if status != nil {
device.log.Errorf("%v", status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
}
}

View file

@ -1,48 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import "testing"
func TestDevice(t *testing.T) {
// prepare tun devices for generating traffic
tun1, err := CreateDummyTUN("tun1")
if err != nil {
t.Error("failed to create tun:", err.Error())
}
tun2, err := CreateDummyTUN("tun2")
if err != nil {
t.Error("failed to create tun:", err.Error())
}
_ = tun1
_ = tun2
// prepare endpoints
end1, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}

View file

@ -1,15 +0,0 @@
// +build !android
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
// --------------------------------------------------------
// Do not use this on Linux. Instead use the kernel module.
// See wireguard.com/install for more information.
// --------------------------------------------------------

View file

@ -1,53 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"math/rand"
"net"
)
type DummyEndpoint struct {
src [16]byte
dst [16]byte
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint
if _, err := rand.Read(end.src[:]); err != nil {
return nil, err
}
_, err := rand.Read(end.dst[:])
return &end, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
}
func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
}
func (e *DummyEndpoint) SrcToBytes() []byte {
return e.src[:]
}
func (e *DummyEndpoint) DstIP() net.IP {
return e.dst[:]
}
func (e *DummyEndpoint) SrcIP() net.IP {
return e.src[:]
}

51
format_test.go Normal file
View file

@ -0,0 +1,51 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"go/format"
"io/fs"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestFormatting(t *testing.T) {
var wg sync.WaitGroup
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Errorf("unable to walk %s: %v", path, err)
return nil
}
if d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
wg.Add(1)
go func(path string) {
defer wg.Done()
src, err := os.ReadFile(path)
if err != nil {
t.Errorf("unable to read %s: %v", path, err)
return
}
if runtime.GOOS == "windows" {
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
}
formatted, err := format.Source(src)
if err != nil {
t.Errorf("unable to format %s: %v", path, err)
return
}
if !bytes.Equal(src, formatted) {
t.Errorf("unformatted code: %s", path)
}
}(path)
return nil
})
wg.Wait()
}

18
go.mod
View file

@ -1,7 +1,17 @@
module git.zx2c4.com/wireguard-go module github.com/amnezia-vpn/amneziawg-go
go 1.24
require ( require (
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 github.com/tevino/abool/v2 v2.1.0
golang.org/x/net v0.0.0-20181005035420-146acd28ed58 golang.org/x/crypto v0.36.0
golang.org/x/sys v0.0.0-20181005133103-4497e2df6f9e golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
require (
github.com/google/btree v1.1.3 // indirect
golang.org/x/time v0.9.0 // indirect
) )

26
go.sum
View file

@ -1,6 +1,20 @@
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 h1:Vk3wNqEZwyGyei9yq5ekj7frek2u7HUfffJ1/opblzc= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
golang.org/x/net v0.0.0-20181005035420-146acd28ed58 h1:otZG8yDCO4LVps5+9bxOeNiCvgmOyt96J3roHTYs7oE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/sys v0.0.0-20181005133103-4497e2df6f9e h1:EfdBzeKbFSvOjoIqSZcfS8wp0FBLokGBEs9lz1OtSg0= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
golang.org/x/sys v0.0.0-20181005133103-4497e2df6f9e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=

View file

@ -1,92 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"errors"
"git.zx2c4.com/wireguard-go/tun"
"os"
"testing"
)
/* Helpers for writing unit tests
*/
type DummyTUN struct {
name string
mtu int
packets chan []byte
events chan tun.TUNEvent
}
func (tun *DummyTUN) File() *os.File {
return nil
}
func (tun *DummyTUN) Name() (string, error) {
return tun.name, nil
}
func (tun *DummyTUN) MTU() (int, error) {
return tun.mtu, nil
}
func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
tun.packets <- d[offset:]
return len(d), nil
}
func (tun *DummyTUN) Close() error {
close(tun.events)
close(tun.packets)
return nil
}
func (tun *DummyTUN) Events() chan tun.TUNEvent {
return tun.events
}
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
t, ok := <-tun.packets
if !ok {
return 0, errors.New("device closed")
}
copy(d[offset:], t)
return len(t), nil
}
func CreateDummyTUN(name string) (tun.TUNDevice, error) {
var dummy DummyTUN
dummy.mtu = 0
dummy.packets = make(chan []byte, 100)
dummy.events = make(chan tun.TUNEvent, 10)
return &dummy, nil
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

287
ipc/namedpipe/file.go Normal file
View file

@ -0,0 +1,287 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package namedpipe
import (
"io"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type timeoutChan chan struct{}
var (
ioInitOnce sync.Once
ioCompletionPort windows.Handle
)
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
err error
}
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o windows.Overlapped
ch chan ioResult
}
func initIo() {
h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
panic(err)
}
ioCompletionPort = h
go ioCompletionProcessor(h)
}
// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type file struct {
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomic.Bool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
type deadlineHandler struct {
setLock sync.Mutex
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomic.Bool
}
// makeFile makes a new file from an existing file handle
func makeFile(h windows.Handle) (*file, error) {
f := &file{handle: h}
ioInitOnce.Do(initIo)
_, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
if err != nil {
return nil, err
}
err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
f.readDeadline.channel = make(timeoutChan)
f.writeDeadline.channel = make(timeoutChan)
return f, nil
}
// closeHandle closes the resources associated with a Win32 handle
func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if f.closing.Swap(true) == false {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
windows.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
}
}
// Close closes a file.
func (f *file) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.Load() {
f.wgLock.RUnlock()
return nil, os.ErrClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
c := &ioOperation{}
c.ch = make(chan ioResult)
return c, nil
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h windows.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
if op == nil {
panic(err)
}
op.ch <- ioResult{bytes, err}
}
}
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != windows.ERROR_IO_PENDING {
return int(bytes), err
}
if f.closing.Load() {
windows.CancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}
var r ioResult
select {
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.Load() {
err = os.ErrClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
windows.CancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
err = os.ErrDeadlineExceeded
}
}
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive
// until the channel read is complete.
runtime.KeepAlive(c)
return int(r.bytes), err
}
// Read reads from a file handle.
func (f *file) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.readDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
var bytes uint32
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == windows.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
}
}
// Write writes to a file handle.
func (f *file) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.writeDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
var bytes uint32
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
func (f *file) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
func (f *file) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
func (f *file) Flush() error {
return windows.FlushFileBuffers(f.handle)
}
func (f *file) Fd() uintptr {
return uintptr(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()
if d.timer != nil {
if !d.timer.Stop() {
<-d.channel
}
d.timer = nil
}
d.timedout.Store(false)
select {
case <-d.channel:
d.channelLock.Lock()
d.channel = make(chan struct{})
d.channelLock.Unlock()
default:
}
if deadline.IsZero() {
return nil
}
timeoutIO := func() {
d.timedout.Store(true)
close(d.channel)
}
now := time.Now()
duration := deadline.Sub(now)
if deadline.After(now) {
// Deadline is in the future, set a timer to wait
d.timer = time.AfterFunc(duration, timeoutIO)
} else {
// Deadline is in the past. Cancel all pending IO now.
timeoutIO()
}
return nil
}

485
ipc/namedpipe/namedpipe.go Normal file
View file

@ -0,0 +1,485 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe
import (
"context"
"io"
"net"
"os"
"runtime"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type pipe struct {
*file
path string
}
type messageBytePipe struct {
pipe
writeClosed atomic.Bool
readEOF bool
}
type pipeAddress string
func (f *pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
f.writeClosed.Store(false)
return err
}
_, err = f.file.Write(nil)
if err != nil {
f.writeClosed.Store(false)
return err
}
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed.Load() {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {
return 0, nil
}
return f.file.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *messageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.file.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read calls
// also return EOF.
f.readEOF = true
} else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
func (f *pipe) Handle() windows.Handle {
return f.handle
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
path16, err := windows.UTF16PtrFromString(*path)
if err != nil {
return 0, err
}
h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(10 * time.Millisecond)
}
}
}
// DialConfig exposes various options for use in Dial and DialContext.
type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
}
// DialTimeout connects to the specified named pipe by path, timing out if the
// connection takes longer than the specified duration. If timeout is zero, then
// we use a default timeout of 2 seconds.
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
if timeout == 0 {
timeout = time.Second * 2
}
absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded
}
return conn, err
}
// DialContext attempts to connect to the specified named pipe by path.
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
var err error
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if config.ExpectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(config.ExpectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
windows.Close(h)
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite.
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
return &messageBytePipe{
pipe: pipe{file: f, path: path},
}, nil
}
return &pipe{file: f, path: path}, nil
}
var defaultDialer DialConfig
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(path, timeout)
}
// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
return defaultDialer.DialContext(ctx, path)
}
type acceptResponse struct {
f *file
err error
}
type pipeListener struct {
firstHandle windows.Handle
path string
config ListenConfig
acceptCh chan chan acceptResponse
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa windows.OBJECT_ATTRIBUTES
oa.Length = uint32(unsafe.Sizeof(oa))
var ntPath windows.NTUnicodeString
if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if isFirstPipe {
if sd != nil {
oa.SecurityDescriptor = sd
} else {
// Construct the default named pipe security descriptor.
var acl *windows.ACL
if err := windows.RtlDefaultNpAcl(&acl); err != nil {
return 0, err
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
sd, err = windows.NewSecurityDescriptor()
if err != nil {
return 0, err
}
if err = sd.SetDACL(acl, true, false); err != nil {
return 0, err
}
oa.SecurityDescriptor = sd
}
}
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= windows.FILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if isFirstPipe {
disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h windows.Handle
iosb windows.IO_STATUS_BLOCK
)
err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
if err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *pipeListener) makeServerPipe() (*file, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
return f, nil
}
func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *file) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == os.ErrClosed {
err = net.ErrClosed
}
}
return p, err
}
func (l *pipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *file
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != windows.ERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == net.ErrClosed
}
}
windows.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close and Accept callers that the handle has been closed.
close(l.doneCh)
}
// ListenConfig contains configuration for the pipe listener.
type ListenConfig struct {
// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite is only supported for message mode pipes;
// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
InputBufferSize int32
// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
OutputBufferSize int32
}
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
// The pipe must not already exist.
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
}
l := &pipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int),
doneCh: make(chan int),
}
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path)
if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
windows.CloseHandle(h)
}
}
}
go l.listenerRoutine()
return l, nil
}
var defaultListener ListenConfig
// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
return defaultListener.Listen(path)
}
func connectPipe(p *file) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = windows.ConnectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *pipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &messageBytePipe{
pipe: pipe{file: response.f, path: l.path},
}, nil
}
return &pipe{file: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, net.ErrClosed
}
}
func (l *pipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *pipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

View file

@ -0,0 +1,674 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package namedpipe_test
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net"
"os"
"sync"
"syscall"
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
func randomPipePath() string {
guid, err := windows.GenerateGUID()
if err != nil {
panic(err)
}
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
}
func TestPingPong(t *testing.T) {
const (
ping = 42
pong = 24
)
pipePath := randomPipePath()
listener, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatalf("unable to listen on pipe: %v", err)
}
defer listener.Close()
go func() {
incoming, err := listener.Accept()
if err != nil {
t.Fatalf("unable to accept pipe connection: %v", err)
}
defer incoming.Close()
var data [1]byte
_, err = incoming.Read(data[:])
if err != nil {
t.Fatalf("unable to read ping from pipe: %v", err)
}
if data[0] != ping {
t.Fatalf("expected ping, got %d", data[0])
}
data[0] = pong
_, err = incoming.Write(data[:])
if err != nil {
t.Fatalf("unable to write pong to pipe: %v", err)
}
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatalf("unable to dial pipe: %v", err)
}
defer client.Close()
client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte
data[0] = ping
_, err = client.Write(data[:])
if err != nil {
t.Fatalf("unable to write ping to pipe: %v", err)
}
_, err = client.Read(data[:])
if err != nil {
t.Fatalf("unable to read pong from pipe: %v", err)
}
if data[0] != pong {
t.Fatalf("expected pong, got %d", data[0])
}
}
func TestDialUnknownFailsImmediately(t *testing.T) {
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err)
}
}
func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d)
pipe, err := namedpipe.DialContext(ctx, pipePath)
if err == nil {
pipe.Close()
}
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
}
func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background())
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan error)
go func(ctx context.Context, ch chan error) {
_, err := namedpipe.DialContext(ctx, pipePath)
ch <- err
}(ctx, ch)
time.Sleep(time.Millisecond * 30)
cancel()
err = <-ch
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
t.Skip("dacls on named pipes are broken on wine")
}
pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:")
l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
pipe.Close()
}
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
}
}
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
pipePath := randomPipePath()
if cfg == nil {
cfg = &namedpipe.ListenConfig{}
}
l, err := cfg.Listen(pipePath)
if err != nil {
return
}
defer l.Close()
type response struct {
c net.Conn
err error
}
ch := make(chan response)
go func() {
c, err := l.Accept()
ch <- response{c, err}
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
return
}
r := <-ch
if err = r.err; err != nil {
c.Close()
return
}
client = c
server = r.c
return
}
func TestReadTimeout(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
buf := make([]byte, 10)
_, err = c.Read(buf)
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func server(l net.Listener, ch chan int) {
c, err := l.Accept()
if err != nil {
panic(err)
}
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
s, err := rw.ReadString('\n')
if err != nil {
panic(err)
}
_, err = rw.WriteString("got " + s)
if err != nil {
panic(err)
}
err = rw.Flush()
if err != nil {
panic(err)
}
c.Close()
ch <- 1
}
func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan int)
go server(l, ch)
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
_, err = rw.WriteString("hello world\n")
if err != nil {
t.Fatal(err)
}
err = rw.Flush()
if err != nil {
t.Fatal(err)
}
s, err := rw.ReadString('\n')
if err != nil {
t.Fatal(err)
}
ms := "got hello world\n"
if s != ms {
t.Errorf("expected '%s', got '%s'", ms, s)
}
<-ch
}
func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
ch := make(chan error)
go func() {
_, err := l.Accept()
ch <- err
}()
time.Sleep(30 * time.Millisecond)
l.Close()
err = <-ch
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
b := make([]byte, 10)
w.Close()
n, err := r.Read(b)
if n > 0 {
t.Errorf("unexpected byte count %d", n)
}
if err != io.EOF {
t.Errorf("expected EOF: %v", err)
}
}
func TestCloseClientEOFServer(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, c, s)
}
func TestCloseServerEOFClient(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, s, c)
}
func TestCloseWriteEOF(t *testing.T) {
cfg := &namedpipe.ListenConfig{
MessageMode: true,
}
c, s, err := getConnection(cfg)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
type closeWriter interface {
CloseWrite() error
}
err = c.(closeWriter).CloseWrite()
if err != nil {
t.Fatal(err)
}
b := make([]byte, 10)
_, err = s.Read(b)
if err != io.EOF {
t.Fatal(err)
}
}
func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
l.Close()
_, err = l.Accept()
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
buf := make([]byte, 10)
_, err = client.Read(buf)
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
client.SetReadDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for read to cancel")
<-clientErr
}
<-serverDone
}
func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
_, err = client.Write([]byte("this should timeout"))
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
client.SetWriteDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for write to cancel")
<-clientErr
}
<-serverDone
}
type CloseWriter interface {
CloseWrite() error
}
func TestEchoWithMessaging(t *testing.T) {
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
listenerDone := make(chan bool)
clientDone := make(chan bool)
go func() {
// server echo
conn, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
_, err = io.Copy(conn, conn)
if err != nil {
t.Fatal(err)
}
conn.(CloseWriter).CloseWrite()
close(listenerDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Second)
if err != nil {
t.Fatal(err)
}
defer client.Close()
go func() {
// client read back
bytes := make([]byte, 2)
n, e := client.Read(bytes)
if e != nil {
t.Fatal(e)
}
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n)
}
close(clientDone)
}()
payload := make([]byte, 2)
payload[0] = 0
payload[1] = 1
n, err := client.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 bytes, got %v", n)
}
client.(CloseWriter).CloseWrite()
<-listenerDone
<-clientDone
}
func TestConnectRace(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
go func() {
for {
s, err := l.Accept()
if err == net.ErrClosed {
return
}
if err != nil {
t.Fatal(err)
}
s.Close()
}
}()
for i := 0; i < 1000; i++ {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
c.Close()
}
}
func TestMessageReadMode(t *testing.T) {
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
t.Skipf("Skipping on Windows %d", maj)
}
var wg sync.WaitGroup
defer wg.Wait()
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
msg := ([]byte)("hello world")
wg.Add(1)
go func() {
defer wg.Done()
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
_, err = s.Write(msg)
if err != nil {
t.Fatal(err)
}
s.Close()
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
mode := uint32(windows.PIPE_READMODE_MESSAGE)
err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
if err != nil {
t.Fatal(err)
}
ch := make([]byte, 1)
var vmsg []byte
for {
n, err := c.Read(ch)
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("expected 1, got %d", n)
}
vmsg = append(vmsg, ch[0])
}
if !bytes.Equal(msg, vmsg) {
t.Fatalf("expected %s, got %s", msg, vmsg)
}
}
func TestListenConnectRace(t *testing.T) {
if testing.Short() {
t.Skip("Skipping long race test")
}
pipePath := randomPipePath()
for i := 0; i < 50 && !t.Failed(); i++ {
var wg sync.WaitGroup
wg.Add(1)
go func() {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
c.Close()
}
wg.Done()
}()
s, err := namedpipe.Listen(pipePath)
if err != nil {
t.Error(i, err)
} else {
s.Close()
}
wg.Wait()
}
}

View file

@ -1,30 +1,19 @@
// +build darwin freebsd openbsd //go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package ipc
import ( import (
"errors" "errors"
"fmt"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"path"
"unsafe" "unsafe"
)
var socketDirectory = "/var/run/wireguard" "golang.org/x/sys/unix"
const (
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
) )
type UAPIListener struct { type UAPIListener struct {
@ -65,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -83,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
unixListener.SetUnlinkOnClose(true) unixListener.SetUnlinkOnClose(true)
} }
socketPath := path.Join( socketPath := sockPath(name)
socketDirectory,
fmt.Sprintf(socketName, name),
)
// watch for deletion of socket // watch for deletion of socket
@ -118,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err l.connErr <- err
return return
} }
if kerr != nil || n != 1 { if (kerr != nil || n != 1) && kerr != unix.EINTR {
if kerr != nil { if kerr != nil {
l.connErr <- kerr l.connErr <- kerr
} else { } else {
@ -145,58 +130,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

View file

@ -1,28 +1,16 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package ipc
import ( import (
"errors"
"fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"path"
)
var socketDirectory = "/var/run/wireguard" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix"
const (
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
) )
type UAPIListener struct { type UAPIListener struct {
@ -63,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -83,10 +70,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
// watch for deletion of socket // watch for deletion of socket
socketPath := path.Join( socketPath := sockPath(name)
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit() uapi.inotifyFd, err = unix.InotifyInit()
if err != nil { if err != nil {
@ -112,14 +96,15 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
} }
go func(l *UAPIListener) { go func(l *UAPIListener) {
var buff [0]byte var buf [0]byte
for { for {
defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) { if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err l.connErr <- err
return return
} }
_, err := uapi.inotifyRWCancel.Read(buff[:]) _, err := uapi.inotifyRWCancel.Read(buf[:])
if err != nil { if err != nil {
l.connErr <- err l.connErr <- err
return return
@ -142,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

66
ipc/uapi_unix.go Normal file
View file

@ -0,0 +1,66 @@
//go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
"fmt"
"net"
"os"
"golang.org/x/sys/unix"
)
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
IpcErrorUnknown = -55 // ENOANO
)
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
}
func UAPIOpen(name string) (*os.File, error) {
if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
return nil, err
}
socketPath := sockPath(name)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0o077)
defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener.File()
}
// Test socket, if not in use cleanup and try again.
if _, err := net.Dial("unix", socketPath); err == nil {
return nil, errors.New("unix socket in use")
}
if err := os.Remove(socketPath); err != nil {
return nil, err
}
listener, err = net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return listener.File()
}

15
ipc/uapi_wasm.go Normal file
View file

@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
// Made up sentinel error codes for {js,wasip1}/wasm.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2
IpcErrorPortInUse = 3
IpcErrorUnknown = 4
IpcErrorProtocol = 5
)

88
ipc/uapi_windows.go Normal file
View file

@ -0,0 +1,88 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"net"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
// TODO: replace these with actual standard windows error numbers from the win package
const (
IpcErrorIO = -int64(5)
IpcErrorProtocol = -int64(71)
IpcErrorInvalid = -int64(22)
IpcErrorPortInUse = -int64(98)
IpcErrorUnknown = -int64(55)
)
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
connErr chan error
kqueueFd int
keventFd int
}
func (l *UAPIListener) Accept() (net.Conn, error) {
for {
select {
case conn := <-l.connNew:
return conn, nil
case err := <-l.connErr:
return nil, err
}
}
}
func (l *UAPIListener) Close() error {
return l.listener.Close()
}
func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr()
}
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() {
var err error
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
if err != nil {
return nil, err
}
uapi := &UAPIListener{
listener: listener,
connNew: make(chan net.Conn, 1),
connErr: make(chan error, 1),
}
go func(l *UAPIListener) {
for {
conn, err := l.listener.Accept()
if err != nil {
l.connErr <- err
break
}
l.connNew <- conn
}
}(uapi)
return uapi, nil
}

View file

@ -1,59 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"io"
"io/ioutil"
"log"
"os"
)
const (
LogLevelSilent = iota
LogLevelError
LogLevelInfo
LogLevelDebug
)
type Logger struct {
Debug *log.Logger
Info *log.Logger
Error *log.Logger
}
func NewLogger(level int, prepend string) *Logger {
output := os.Stdout
logger := new(Logger)
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
if level >= LogLevelDebug {
return output, output, output
}
if level >= LogLevelInfo {
return output, output, ioutil.Discard
}
if level >= LogLevelError {
return output, ioutil.Discard, ioutil.Discard
}
return ioutil.Discard, ioutil.Discard, ioutil.Discard
}()
logger.Debug = log.New(logDebug,
"DEBUG: "+prepend,
log.Ldate|log.Ltime,
)
logger.Info = log.New(logInfo,
"INFO: "+prepend,
log.Ldate|log.Ltime,
)
logger.Error = log.New(logErr,
"ERROR: "+prepend,
log.Ldate|log.Ltime,
)
return logger
}

133
main.go
View file

@ -1,18 +1,24 @@
/* SPDX-License-Identifier: GPL-2.0 //go:build !windows
/* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
import ( import (
"fmt" "fmt"
"git.zx2c4.com/wireguard-go/tun"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/sys/unix"
) )
const ( const (
@ -27,61 +33,38 @@ const (
) )
func printUsage() { func printUsage() {
fmt.Printf("usage:\n") fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
} }
func warning() { func warning() {
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { switch runtime.GOOS {
case "linux", "freebsd", "openbsd":
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return
}
default:
return return
} }
shouldQuit := false fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "W do what it is supposed to do, and things may go G") fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G") fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
fmt.Fprintln(os.Stderr, "W at your own risk. G") fmt.Fprintln(os.Stderr, "│ │")
if runtime.GOOS == "linux" { fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G")
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G")
fmt.Fprintln(os.Stderr, "W foolish. This is because the Linux kernel has G")
fmt.Fprintln(os.Stderr, "W built-in first class support for WireGuard, and G")
fmt.Fprintln(os.Stderr, "W this support is much more refined than this G")
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
}
}
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
if shouldQuit {
os.Exit(1)
}
} }
func main() { func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" { if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", WireGuardGoVersion, runtime.GOOS, runtime.GOARCH) fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
return return
} }
warning() warning()
// parse arguments
var foreground bool var foreground bool
var interfaceName string var interfaceName string
if len(os.Args) < 2 || len(os.Args) > 3 { if len(os.Args) < 2 || len(os.Args) > 3 {
@ -116,24 +99,22 @@ func main() {
logLevel := func() int { logLevel := func() int {
switch os.Getenv("LOG_LEVEL") { switch os.Getenv("LOG_LEVEL") {
case "debug": case "verbose", "debug":
return LogLevelDebug return device.LogLevelVerbose
case "info":
return LogLevelInfo
case "error": case "error":
return LogLevelError return device.LogLevelError
case "silent": case "silent":
return LogLevelSilent return device.LogLevelSilent
} }
return LogLevelInfo return device.LogLevelError
}() }()
// open TUN device (or use supplied fd) // open TUN device (or use supplied fd)
tun, err := func() (tun.TUNDevice, error) { tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return tun.CreateTUN(interfaceName, DefaultMTU) return tun.CreateTUN(interfaceName, device.DefaultMTU)
} }
// construct tun device from supplied fd // construct tun device from supplied fd
@ -143,28 +124,31 @@ func main() {
return nil, err return nil, err
} }
err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "") file := os.NewFile(uintptr(fd), "")
return tun.CreateTUNFromFile(file, DefaultMTU) return tun.CreateTUNFromFile(file, device.DefaultMTU)
}() }()
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tdev.Name()
if err2 == nil { if err2 == nil {
interfaceName = realInterfaceName interfaceName = realInterfaceName
} }
} }
logger := NewLogger( logger := device.NewLogger(
logLevel, logLevel,
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion) logger.Verbosef("Starting amneziawg-go version %s", Version)
logger.Debug.Println("Debug log enabled")
if err != nil { if err != nil {
logger.Error.Println("Failed to create TUN device:", err) logger.Errorf("Failed to create TUN device: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
@ -173,7 +157,7 @@ func main() {
fileUAPI, err := func() (*os.File, error) { fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" { if uapiFdStr == "" {
return UAPIOpen(interfaceName) return ipc.UAPIOpen(interfaceName)
} }
// use supplied fd // use supplied fd
@ -185,9 +169,8 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil return os.NewFile(uintptr(fd), ""), nil
}() }()
if err != nil { if err != nil {
logger.Error.Println("UAPI listen error:", err) logger.Errorf("UAPI listen error: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
return return
} }
@ -199,7 +182,7 @@ func main() {
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
files := [3]*os.File{} files := [3]*os.File{}
if os.Getenv("LOG_LEVEL") != "" && logLevel != LogLevelSilent { if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent {
files[0], _ = os.Open(os.DevNull) files[0], _ = os.Open(os.DevNull)
files[1] = os.Stdout files[1] = os.Stdout
files[2] = os.Stderr files[2] = os.Stderr
@ -213,7 +196,7 @@ func main() {
files[0], // stdin files[0], // stdin
files[1], // stdout files[1], // stdout
files[2], // stderr files[2], // stderr
tun.File(), tdev.File(),
fileUAPI, fileUAPI,
}, },
Dir: ".", Dir: ".",
@ -222,7 +205,7 @@ func main() {
path, err := os.Executable() path, err := os.Executable()
if err != nil { if err != nil {
logger.Error.Println("Failed to determine executable:", err) logger.Errorf("Failed to determine executable: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
@ -232,23 +215,23 @@ func main() {
attr, attr,
) )
if err != nil { if err != nil {
logger.Error.Println("Failed to daemonize:", err) logger.Errorf("Failed to daemonize: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
process.Release() process.Release()
return return
} }
device := NewDevice(tun, logger) device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Info.Println("Device started") logger.Verbosef("Device started")
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal, 1) term := make(chan os.Signal, 1)
uapi, err := UAPIListen(interfaceName, fileUAPI) uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
if err != nil { if err != nil {
logger.Error.Println("Failed to listen on uapi socket:", err) logger.Errorf("Failed to listen on uapi socket: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
@ -259,15 +242,15 @@ func main() {
errs <- err errs <- err
return return
} }
go ipcHandle(device, conn) go device.IpcHandle(conn)
} }
}() }()
logger.Info.Println("UAPI listener started") logger.Verbosef("UAPI listener started")
// wait for program to terminate // wait for program to terminate
signal.Notify(term, syscall.SIGTERM) signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
select { select {
@ -281,5 +264,5 @@ func main() {
uapi.Close() uapi.Close()
device.Close() device.Close()
logger.Info.Println("Shutting down") logger.Verbosef("Shutting down")
} }

99
main_windows.go Normal file
View file

@ -0,0 +1,99 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"fmt"
"os"
"os/signal"
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
const (
ExitSetupSuccess = 0
ExitSetupFailed = 1
)
func main() {
if len(os.Args) != 2 {
os.Exit(ExitSetupFailed)
}
interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
logger := device.NewLogger(
device.LogLevelVerbose,
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting amneziawg-go version %s", Version)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {
realInterfaceName, err2 := tun.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
} else {
logger.Errorf("Failed to create TUN device: %v", err)
os.Exit(ExitSetupFailed)
}
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
err = device.Up()
if err != nil {
logger.Errorf("Failed to bring up device: %v", err)
os.Exit(ExitSetupFailed)
}
logger.Verbosef("Device started")
uapi, err := ipc.UAPIListen(interfaceName)
if err != nil {
logger.Errorf("Failed to listen on uapi socket: %v", err)
os.Exit(ExitSetupFailed)
}
errs := make(chan error)
term := make(chan os.Signal, 1)
go func() {
for {
conn, err := uapi.Accept()
if err != nil {
errs <- err
return
}
go device.IpcHandle(conn)
}
}()
logger.Verbosef("UAPI listener started")
// wait for program to terminate
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
signal.Notify(term, windows.SIGTERM)
select {
case <-term:
case <-errs:
case <-device.Wait():
}
// clean up
uapi.Close()
device.Close()
logger.Verbosef("Shutting down")
}

48
misc.go
View file

@ -1,48 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
flag int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.flag, flag)
}
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

270
peer.go
View file

@ -1,270 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"encoding/base64"
"errors"
"fmt"
"sync"
"time"
)
const (
PeerRoutineNumber = 3
)
type Peer struct {
isRunning AtomicBool
mutex sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
endpoint Endpoint
persistentKeepaliveInterval uint16
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
handshakeAttempts uint32
needAnotherKeepalive AtomicBool
sentLastMinuteHandshake AtomicBool
}
signals struct {
newKeypairArrived chan struct{}
flushNonceQueue chan struct{}
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
packetInNonceQueueIsAwaitingKey AtomicBool
}
routines struct {
mutex sync.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop chan struct{} // size 0, stop all go routines in peer
}
cookieGenerator CookieGenerator
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() {
return nil, errors.New("device closed")
}
// lock resources
device.staticIdentity.mutex.RLock()
defer device.staticIdentity.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
peer := new(Peer)
peer.mutex.Lock()
defer peer.mutex.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.isRunning.Set(false)
// map public key
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
device.peers.keyMap[pk] = peer
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
// start peer
if peer.device.isUp.Get() {
peer.Start()
}
return peer, nil
}
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
if peer.device.net.bind == nil {
return errors.New("no bind")
}
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
return errors.New("no known endpoint for peer")
}
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
func (peer *Peer) String() string {
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
abbreviatedKey := "invalid"
if len(base64Key) == 44 {
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
}
return fmt.Sprintf("peer(%s)", abbreviatedKey)
}
func (peer *Peer) Start() {
// should never start a peer on a closed device
if peer.device.isClosed.Get() {
return
}
// prevent simultaneous start/stop operations
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
if peer.isRunning.Get() {
return
}
device := peer.device
device.log.Debug.Println(peer, "- Starting...")
// reset routine state
peer.routines.starting.Wait()
peer.routines.stopping.Wait()
peer.routines.stop = make(chan struct{})
peer.routines.starting.Add(PeerRoutineNumber)
peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.timersInit()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
peer.signals.newKeypairArrived = make(chan struct{}, 1)
peer.signals.flushNonceQueue = make(chan struct{}, 1)
// wait for routines to start
go peer.RoutineNonce()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
peer.routines.starting.Wait()
peer.isRunning.Set(true)
}
func (peer *Peer) ZeroAndFlushAll() {
device := peer.device
// clear key pairs
keypairs := &peer.keypairs
keypairs.mutex.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.mutex.Unlock()
// clear handshake state
handshake := &peer.handshake
handshake.mutex.Lock()
device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
handshake.mutex.Unlock()
peer.FlushNonceQueue()
}
func (peer *Peer) Stop() {
// prevent simultaneous start/stop operations
if !peer.isRunning.Swap(false) {
return
}
peer.routines.starting.Wait()
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println(peer, "- Stopping...")
peer.timersStop()
// stop & wait for ongoing peer routines
close(peer.routines.stop)
peer.routines.stopping.Wait()
// close queues
close(peer.queue.nonce)
close(peer.queue.outbound)
close(peer.queue.inbound)
peer.ZeroAndFlushAll()
}
var roamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if roamingDisabled {
return
}
peer.mutex.Lock()
peer.endpoint = endpoint
peer.mutex.Unlock()
}

View file

@ -1,89 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import "sync"
func (device *Device) PopulatePools() {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
device.pool.inboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueInboundElement)
},
}
device.pool.outboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueOutboundElement)
},
}
} else {
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
}
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
}
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if PreallocatedBuffersPerPool == 0 {
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
}
func (device *Device) GetInboundElement() *QueueInboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
}
func (device *Device) PutInboundElement(msg *QueueInboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.inboundElementPool.Put(msg)
} else {
device.pool.inboundElementReuseChan <- msg
}
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
}
func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.outboundElementPool.Put(msg)
} else {
device.pool.outboundElementReuseChan <- msg
}
}

View file

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"sync" "sync"
"time" "time"
) )
@ -20,21 +20,22 @@ const (
) )
type RatelimiterEntry struct { type RatelimiterEntry struct {
mutex sync.Mutex mu sync.Mutex
lastTime time.Time lastTime time.Time
tokens int64 tokens int64
} }
type Ratelimiter struct { type Ratelimiter struct {
mutex sync.RWMutex mu sync.RWMutex
stopReset chan struct{} timeNow func() time.Time
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry stopReset chan struct{} // send to reset, close to stop
table map[netip.Addr]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
rate.mutex.Lock() rate.mu.Lock()
defer rate.mutex.Unlock() defer rate.mu.Unlock()
if rate.stopReset != nil { if rate.stopReset != nil {
close(rate.stopReset) close(rate.stopReset)
@ -42,111 +43,83 @@ func (rate *Ratelimiter) Close() {
} }
func (rate *Ratelimiter) Init() { func (rate *Ratelimiter) Init() {
rate.mutex.Lock() rate.mu.Lock()
defer rate.mutex.Unlock() defer rate.mu.Unlock()
if rate.timeNow == nil {
rate.timeNow = time.Now
}
// stop any ongoing garbage collection routine // stop any ongoing garbage collection routine
if rate.stopReset != nil { if rate.stopReset != nil {
close(rate.stopReset) close(rate.stopReset)
} }
rate.stopReset = make(chan struct{}) rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.table = make(map[netip.Addr]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
// start garbage collection routine stopReset := rate.stopReset // store in case Init is called again.
// Start garbage collection routine.
go func() { go func() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
ticker.Stop() ticker.Stop()
for { for {
select { select {
case _, ok := <-rate.stopReset: case _, ok := <-stopReset:
ticker.Stop() ticker.Stop()
if ok { if !ok {
ticker = time.NewTicker(time.Second)
} else {
return return
} }
ticker = time.NewTicker(time.Second)
case <-ticker.C: case <-ticker.C:
func() { if rate.cleanup() {
rate.mutex.Lock() ticker.Stop()
defer rate.mutex.Unlock() }
for key, entry := range rate.tableIPv4 {
entry.mutex.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key)
}
entry.mutex.Unlock()
}
for key, entry := range rate.tableIPv6 {
entry.mutex.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mutex.Unlock()
}
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
ticker.Stop()
}
}()
} }
} }
}() }()
} }
func (rate *Ratelimiter) Allow(ip net.IP) bool { func (rate *Ratelimiter) cleanup() (empty bool) {
var entry *RatelimiterEntry rate.mu.Lock()
var keyIPv4 [net.IPv4len]byte defer rate.mu.Unlock()
var keyIPv6 [net.IPv6len]byte
// lookup entry for key, entry := range rate.table {
entry.mu.Lock()
IPv4 := ip.To4() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
IPv6 := ip.To16() delete(rate.table, key)
}
rate.mutex.RLock() entry.mu.Unlock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
} }
rate.mutex.RUnlock() return len(rate.table) == 0
}
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
// lookup entry
rate.mu.RLock()
entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found // make new entry if not found
if entry == nil { if entry == nil {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now() entry.lastTime = rate.timeNow()
rate.mutex.Lock() rate.mu.Lock()
if IPv4 != nil { rate.table[ip] = entry
rate.tableIPv4[keyIPv4] = entry if len(rate.table) == 1 {
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { rate.stopReset <- struct{}{}
rate.stopReset <- struct{}{}
}
} else {
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
} }
rate.mutex.Unlock() rate.mu.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.mu.Lock()
entry.mutex.Lock() now := rate.timeNow()
now := time.Now()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now entry.lastTime = now
if entry.tokens > maxTokens { if entry.tokens > maxTokens {
@ -154,12 +127,11 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
} }
// subtract cost of packet // subtract cost of packet
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.mutex.Unlock() entry.mu.Unlock()
return true return true
} }
entry.mutex.Unlock() entry.mu.Unlock()
return false return false
} }

View file

@ -1,32 +1,31 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
) )
type RatelimiterResult struct { type result struct {
allowed bool allowed bool
text string text string
wait time.Duration wait time.Duration
} }
func TestRatelimiter(t *testing.T) { func TestRatelimiter(t *testing.T) {
var rate Ratelimiter
var expectedResults []result
var ratelimiter Ratelimiter nano := func(nano int64) time.Duration {
var expectedResults []RatelimiterResult
Nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano) return time.Nanosecond * time.Duration(nano)
} }
Add := func(res RatelimiterResult) { add := func(res result) {
expectedResults = append( expectedResults = append(
expectedResults, expectedResults,
res, res,
@ -34,69 +33,86 @@ func TestRatelimiter(t *testing.T) {
} }
for i := 0; i < packetsBurstable; i++ { for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
text: "inital burst", text: "initial burst",
}) })
} }
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "after burst", text: "after burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet", text: "filling tokens for single packet",
}) })
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "not having refilled enough", text: "not having refilled enough",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst", text: "filling tokens for two packet burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
text: "second packet in 2 packet burst", text: "second packet in 2 packet burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "packet following 2 packet burst", text: "packet following 2 packet burst",
}) })
ips := []net.IP{ ips := []netip.Addr{
net.ParseIP("127.0.0.1"), netip.MustParseAddr("127.0.0.1"),
net.ParseIP("192.168.1.1"), netip.MustParseAddr("192.168.1.1"),
net.ParseIP("172.167.2.3"), netip.MustParseAddr("172.167.2.3"),
net.ParseIP("97.231.252.215"), netip.MustParseAddr("97.231.252.215"),
net.ParseIP("248.97.91.167"), netip.MustParseAddr("248.97.91.167"),
net.ParseIP("188.208.233.47"), netip.MustParseAddr("188.208.233.47"),
net.ParseIP("104.2.183.179"), netip.MustParseAddr("104.2.183.179"),
net.ParseIP("72.129.46.120"), netip.MustParseAddr("72.129.46.120"),
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
} }
ratelimiter.Init() now := time.Now()
rate.timeNow = func() time.Time {
return now
}
defer func() {
// Lock to avoid data race with cleanup goroutine from Init.
rate.mu.Lock()
defer rate.mu.Unlock()
rate.timeNow = time.Now
}()
timeSleep := func(d time.Duration) {
now = now.Add(d + 1)
rate.cleanup()
}
rate.Init()
defer rate.Close()
for i, res := range expectedResults { for i, res := range expectedResults {
time.Sleep(res.wait) timeSleep(res.wait)
for _, ip := range ips { for _, ip := range ips {
allowed := ratelimiter.Allow(ip) allowed := rate.Allow(ip)
if allowed != res.allowed { if allowed != res.allowed {
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
} }
} }
} }

Some files were not shown because too many files have changed in this diff Show more