mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-25 10:27:44 +02:00
Compare commits
183 commits
0.0.202104
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
27e661d68e | ||
|
71be0eb3a6 | ||
|
e3f1273f8a | ||
|
c97b5b7615 | ||
|
668ddfd455 | ||
|
b8da08c106 | ||
|
2e3f7d122c | ||
|
2e7780471a | ||
|
87d8c00f86 | ||
|
c00bda9200 | ||
|
d2b0fc9789 | ||
|
77d39ff3b9 | ||
|
e433d13df6 | ||
|
3ddf952973 | ||
|
3f0a3bcfa0 | ||
|
4dddf62e57 | ||
|
827ec6e14b | ||
|
92e28a0d14 | ||
|
52fed4d362 | ||
|
9c6b3ff332 | ||
|
7de7a9a754 | ||
|
0c347529b8 | ||
|
6705978fc8 | ||
|
032e33f577 | ||
|
59101fd202 | ||
|
8bcfbac230 | ||
|
f0dfb5eacc | ||
|
9195025d8f | ||
|
cbd414dfec | ||
|
7155d20913 | ||
|
bfeb3954f6 | ||
|
e3c9ec8012 | ||
|
ce9d3866a3 | ||
|
e5f355e843 | ||
|
c05b2ee2a3 | ||
|
180c9284f3 | ||
|
015e11875d | ||
|
12269c2761 | ||
|
542e565baa | ||
|
7c20311b3d | ||
|
4ffa9c2032 | ||
|
d0bc03c707 | ||
|
1cf89f5339 | ||
|
b43118018e | ||
|
7af55a3e6f | ||
|
c493b95f66 | ||
|
2e0774f246 | ||
|
b3df23dcd4 | ||
|
f502ec3fad | ||
|
5d37bd24e1 | ||
|
24ea13351e | ||
|
177caa7e44 | ||
|
b81ca925db | ||
|
42ec952ead | ||
|
ec8f6f82c2 | ||
|
1ec454f253 | ||
|
8a015f7c76 | ||
|
895d6c23cd | ||
|
4201e08f1d | ||
|
6a84778f2c | ||
|
b34974c476 | ||
|
f30419e0d1 | ||
|
8f1a6a10b2 | ||
|
469159ecf7 | ||
|
6e755e132a | ||
|
1f25eac395 | ||
|
25eb973e00 | ||
|
b7cd547315 | ||
|
052af4a807 | ||
|
aad7fca9c5 | ||
|
6f895be10d | ||
|
6a07b2a355 | ||
|
334b605e72 | ||
|
3a9e75374f | ||
|
cc20c08c96 | ||
|
1417a47c8f | ||
|
7f511c3bb1 | ||
|
07a1e55270 | ||
|
fff53afca7 | ||
|
0ad14a89f5 | ||
|
7d327ed35a | ||
|
f41f474466 | ||
|
5819c6af28 | ||
|
6901984f6a | ||
|
2fcdaf9799 | ||
|
dbd949307e | ||
|
f26efb65f2 | ||
|
f67c862a2a | ||
|
9e2f386022 | ||
|
3bb8fec7e4 | ||
|
21636207a6 | ||
|
c7b76d3d9e | ||
|
1e2c3e5a3c | ||
|
ebbd4a4330 | ||
|
0ae4b3177c | ||
|
077ce8ecab | ||
|
bb719d3a6e | ||
|
fde0a9525a | ||
|
b51010ba13 | ||
|
d1d08426b2 | ||
|
3381e21b18 | ||
|
c31a7b1ab4 | ||
|
6a08d81f6b | ||
|
ef5c587f78 | ||
|
193cf8d6a5 | ||
|
ee1c8e0e87 | ||
|
95b48cdb39 | ||
|
5aff28b14c | ||
|
46826fc4e5 | ||
|
42c9af45e1 | ||
|
ae6bc4dd64 | ||
|
2cec4d1a62 | ||
|
3b95c81cc1 | ||
|
b9669b734e | ||
|
e0b8f11489 | ||
|
114a3db918 | ||
|
9c9e7e2724 | ||
|
2dd424e2d8 | ||
|
387f7c461a | ||
|
4d87c9e824 | ||
|
ef8d6804d7 | ||
|
de7c702ace | ||
|
fc4f975a4d | ||
|
9d699ba730 | ||
|
425f7c726b | ||
|
3cae233d69 | ||
|
111e0566dc | ||
|
e3134bf665 | ||
|
63abb5537b | ||
|
851efb1bb6 | ||
|
c07dd60cdb | ||
|
eb6302c7eb | ||
|
60683d7361 | ||
|
e42c6c4bc2 | ||
|
828a885a71 | ||
|
f1f626090e | ||
|
82e0b734e5 | ||
|
fdf57a1fa4 | ||
|
f87e87af0d | ||
|
ba9e364dab | ||
|
dfd688b6aa | ||
|
c01d52b66a | ||
|
82d2aa87aa | ||
|
982d5d2e84 | ||
|
642a56e165 | ||
|
bb745b2ea3 | ||
|
fcc601dbf0 | ||
|
217ac1016b | ||
|
eae5e0f3a3 | ||
|
2ef39d4754 | ||
|
3957e9b9dd | ||
|
bad6caeb82 | ||
|
c89f5ca665 | ||
|
15b24b6179 | ||
|
f9b48a961c | ||
|
d0cf96114f | ||
|
841756e328 | ||
|
c382222eab | ||
|
b41f4cc768 | ||
|
4a57024b94 | ||
|
64cb82f2b3 | ||
|
c27ff9b9f6 | ||
|
99e8b4ba60 | ||
|
bd83f0ac99 | ||
|
50d779833e | ||
|
a9b377e9e1 | ||
|
9087e444e6 | ||
|
25ad08a591 | ||
|
5846b62283 | ||
|
9844c74f67 | ||
|
4e9e5dad09 | ||
|
39e0b6dade | ||
|
7121927b87 | ||
|
326aec10af | ||
|
efb8818550 | ||
|
69b39db0b4 | ||
|
db733ccd65 | ||
|
a7aec4449f | ||
|
60a26371f4 | ||
|
a544776d70 | ||
|
69a42a4eef | ||
|
097af6e135 | ||
|
8246d251ea |
130 changed files with 7184 additions and 4997 deletions
41
.github/workflows/build-if-tag.yml
vendored
Normal file
41
.github/workflows/build-if-tag.yml
vendored
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
name: build-if-tag
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||||
|
|
||||||
|
env:
|
||||||
|
APP: amneziawg-go
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
name: build
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ github.ref_name }}
|
||||||
|
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Setup metadata
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
id: metadata
|
||||||
|
with:
|
||||||
|
images: amneziavpn/${{ env.APP }}
|
||||||
|
tags: type=semver,pattern={{version}}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.metadata.outputs.tags }}
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1 @@
|
||||||
wireguard-go
|
amneziawg-go
|
17
Dockerfile
Normal file
17
Dockerfile
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
FROM golang:1.24 as awg
|
||||||
|
COPY . /awg
|
||||||
|
WORKDIR /awg
|
||||||
|
RUN go mod download && \
|
||||||
|
go mod verify && \
|
||||||
|
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
||||||
|
|
||||||
|
FROM alpine:3.19
|
||||||
|
ARG AWGTOOLS_RELEASE="1.0.20241018"
|
||||||
|
RUN apk --no-cache add iproute2 iptables bash && \
|
||||||
|
cd /usr/bin/ && \
|
||||||
|
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
||||||
|
unzip -j alpine-3.19-amneziawg-tools.zip && \
|
||||||
|
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
|
||||||
|
ln -s /usr/bin/awg /usr/bin/wg && \
|
||||||
|
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
|
||||||
|
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go
|
16
Makefile
16
Makefile
|
@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
|
||||||
|
|
||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package main\nconst Version = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > version.go && \
|
echo "$$ver" > version.go && \
|
||||||
git update-index --assume-unchanged version.go || true
|
git update-index --assume-unchanged version.go || true
|
||||||
@$(MAKE) wireguard-go
|
@$(MAKE) amneziawg-go
|
||||||
|
|
||||||
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
||||||
go build -v -o "$@"
|
go build -v -o "$@"
|
||||||
|
|
||||||
install: wireguard-go
|
install: amneziawg-go
|
||||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test -v ./...
|
go test ./...
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f wireguard-go
|
rm -f amneziawg-go
|
||||||
|
|
||||||
.PHONY: all clean test install generate-version-and-build
|
.PHONY: all clean test install generate-version-and-build
|
||||||
|
|
61
README.md
61
README.md
|
@ -1,24 +1,27 @@
|
||||||
# Go Implementation of [WireGuard](https://www.wireguard.com/)
|
# Go Implementation of AmneziaWG
|
||||||
|
|
||||||
This is an implementation of WireGuard in Go.
|
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
|
||||||
|
|
||||||
|
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
|
||||||
|
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
|
||||||
|
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
|
Simply run:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ wireguard-go wg0
|
$ amneziawg-go wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
|
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
|
||||||
|
|
||||||
To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
|
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ wireguard-go -f wg0
|
$ amneziawg-go -f wg0
|
||||||
```
|
```
|
||||||
|
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
|
@ -26,52 +29,24 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
### Linux
|
### Linux
|
||||||
|
|
||||||
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
||||||
|
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
|
||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
|
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
|
||||||
|
|
||||||
### FreeBSD
|
|
||||||
|
|
||||||
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
|
|
||||||
|
|
||||||
### OpenBSD
|
|
||||||
|
|
||||||
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
This requires an installation of [go](https://golang.org) ≥ 1.16.
|
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||||
|
|
||||||
```
|
```
|
||||||
$ git clone https://git.zx2c4.com/wireguard-go
|
$ git clone https://github.com/amnezia-vpn/amneziawg-go
|
||||||
$ cd wireguard-go
|
$ cd amneziawg-go
|
||||||
$ make
|
$ make
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
||||||
of the Software, and to permit persons to whom the Software is furnished to do
|
|
||||||
so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
|
@ -1,579 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ipv4Source struct {
|
|
||||||
Src [4]byte
|
|
||||||
Ifindex int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipv6Source struct {
|
|
||||||
src [16]byte
|
|
||||||
// ifindex belongs in dst.ZoneId
|
|
||||||
}
|
|
||||||
|
|
||||||
type LinuxSocketEndpoint struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
|
||||||
src [unsafe.Sizeof(ipv6Source{})]byte
|
|
||||||
isV6 bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
|
|
||||||
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
|
||||||
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
|
|
||||||
|
|
||||||
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
|
|
||||||
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
|
|
||||||
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
|
|
||||||
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
|
|
||||||
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
|
|
||||||
type LinuxSocketBind struct {
|
|
||||||
// mu guards sock4 and sock6 and the associated fds.
|
|
||||||
// As long as someone holds mu (read or write), the associated fds are valid.
|
|
||||||
mu sync.RWMutex
|
|
||||||
sock4 int
|
|
||||||
sock6 int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
|
|
||||||
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
|
|
||||||
|
|
||||||
var _ Endpoint = (*LinuxSocketEndpoint)(nil)
|
|
||||||
var _ Bind = (*LinuxSocketBind)(nil)
|
|
||||||
|
|
||||||
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
||||||
var end LinuxSocketEndpoint
|
|
||||||
addr, err := parseEndpoint(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := addr.IP.To4()
|
|
||||||
if ipv4 != nil {
|
|
||||||
dst := end.dst4()
|
|
||||||
end.isV6 = false
|
|
||||||
dst.Port = addr.Port
|
|
||||||
copy(dst.Addr[:], ipv4)
|
|
||||||
end.ClearSrc()
|
|
||||||
return &end, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv6 := addr.IP.To16()
|
|
||||||
if ipv6 != nil {
|
|
||||||
zone, err := zoneToUint32(addr.Zone)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dst := end.dst6()
|
|
||||||
end.isV6 = true
|
|
||||||
dst.Port = addr.Port
|
|
||||||
dst.ZoneId = zone
|
|
||||||
copy(dst.Addr[:], ipv6[:])
|
|
||||||
end.ClearSrc()
|
|
||||||
return &end, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("invalid IP address")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
var newPort uint16
|
|
||||||
var tries int
|
|
||||||
|
|
||||||
if bind.sock4 != -1 || bind.sock6 != -1 {
|
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
|
||||||
}
|
|
||||||
|
|
||||||
originalPort := port
|
|
||||||
|
|
||||||
again:
|
|
||||||
port = originalPort
|
|
||||||
var sock4, sock6 int
|
|
||||||
// Attempt ipv6 bind, update port if successful.
|
|
||||||
sock6, newPort, err = create6(port)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
port = newPort
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt ipv4 bind, update port if successful.
|
|
||||||
sock4, newPort, err = create4(port)
|
|
||||||
if err != nil {
|
|
||||||
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
|
||||||
unix.Close(sock6)
|
|
||||||
tries++
|
|
||||||
goto again
|
|
||||||
}
|
|
||||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
unix.Close(sock6)
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
port = newPort
|
|
||||||
}
|
|
||||||
|
|
||||||
var fns []ReceiveFunc
|
|
||||||
if sock4 != -1 {
|
|
||||||
fns = append(fns, bind.makeReceiveIPv4(sock4))
|
|
||||||
bind.sock4 = sock4
|
|
||||||
}
|
|
||||||
if sock6 != -1 {
|
|
||||||
fns = append(fns, bind.makeReceiveIPv6(sock6))
|
|
||||||
bind.sock6 = sock6
|
|
||||||
}
|
|
||||||
if len(fns) == 0 {
|
|
||||||
return nil, 0, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
return fns, port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
|
|
||||||
if bind.sock6 != -1 {
|
|
||||||
err := unix.SetsockoptInt(
|
|
||||||
bind.sock6,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if bind.sock4 != -1 {
|
|
||||||
err := unix.SetsockoptInt(
|
|
||||||
bind.sock4,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Close() error {
|
|
||||||
// Take a readlock to shut down the sockets...
|
|
||||||
bind.mu.RLock()
|
|
||||||
if bind.sock6 != -1 {
|
|
||||||
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
|
||||||
}
|
|
||||||
if bind.sock4 != -1 {
|
|
||||||
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
|
||||||
}
|
|
||||||
bind.mu.RUnlock()
|
|
||||||
// ...and a write lock to close the fd.
|
|
||||||
// This ensures that no one else is using the fd.
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
var err1, err2 error
|
|
||||||
if bind.sock6 != -1 {
|
|
||||||
err1 = unix.Close(bind.sock6)
|
|
||||||
bind.sock6 = -1
|
|
||||||
}
|
|
||||||
if bind.sock4 != -1 {
|
|
||||||
err2 = unix.Close(bind.sock4)
|
|
||||||
bind.sock4 = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
|
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
|
||||||
var end LinuxSocketEndpoint
|
|
||||||
n, err := receive6(sock, buff, &end)
|
|
||||||
return n, &end, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
|
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
|
||||||
var end LinuxSocketEndpoint
|
|
||||||
n, err := receive4(sock, buff, &end)
|
|
||||||
return n, &end, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
|
||||||
nend, ok := end.(*LinuxSocketEndpoint)
|
|
||||||
if !ok {
|
|
||||||
return ErrWrongEndpointType
|
|
||||||
}
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
if !nend.isV6 {
|
|
||||||
if bind.sock4 == -1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
return send4(bind.sock4, nend, buff)
|
|
||||||
} else {
|
|
||||||
if bind.sock6 == -1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
return send6(bind.sock6, nend, buff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
|
|
||||||
if !end.isV6 {
|
|
||||||
return net.IPv4(
|
|
||||||
end.src4().Src[0],
|
|
||||||
end.src4().Src[1],
|
|
||||||
end.src4().Src[2],
|
|
||||||
end.src4().Src[3],
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
return end.src6().src[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) DstIP() net.IP {
|
|
||||||
if !end.isV6 {
|
|
||||||
return net.IPv4(
|
|
||||||
end.dst4().Addr[0],
|
|
||||||
end.dst4().Addr[1],
|
|
||||||
end.dst4().Addr[2],
|
|
||||||
end.dst4().Addr[3],
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
return end.dst6().Addr[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
|
|
||||||
if !end.isV6 {
|
|
||||||
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
|
||||||
} else {
|
|
||||||
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) SrcToString() string {
|
|
||||||
return end.SrcIP().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) DstToString() string {
|
|
||||||
var udpAddr net.UDPAddr
|
|
||||||
udpAddr.IP = end.DstIP()
|
|
||||||
if !end.isV6 {
|
|
||||||
udpAddr.Port = end.dst4().Port
|
|
||||||
} else {
|
|
||||||
udpAddr.Port = end.dst6().Port
|
|
||||||
}
|
|
||||||
return udpAddr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) ClearDst() {
|
|
||||||
for i := range end.dst {
|
|
||||||
end.dst[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) ClearSrc() {
|
|
||||||
for i := range end.src {
|
|
||||||
end.src[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func zoneToUint32(zone string) (uint32, error) {
|
|
||||||
if zone == "" {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if intr, err := net.InterfaceByName(zone); err == nil {
|
|
||||||
return uint32(intr.Index), nil
|
|
||||||
}
|
|
||||||
n, err := strconv.ParseUint(zone, 10, 32)
|
|
||||||
return uint32(n), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func create4(port uint16) (int, uint16, error) {
|
|
||||||
|
|
||||||
// create socket
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return -1, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := unix.SockaddrInet4{
|
|
||||||
Port: int(port),
|
|
||||||
}
|
|
||||||
|
|
||||||
// set sockopts and bind
|
|
||||||
|
|
||||||
if err := func() error {
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IP,
|
|
||||||
unix.IP_PKTINFO,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.Bind(fd, &addr)
|
|
||||||
}(); err != nil {
|
|
||||||
unix.Close(fd)
|
|
||||||
return -1, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, err := unix.Getsockname(fd)
|
|
||||||
if err == nil {
|
|
||||||
addr.Port = sa.(*unix.SockaddrInet4).Port
|
|
||||||
}
|
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func create6(port uint16) (int, uint16, error) {
|
|
||||||
|
|
||||||
// create socket
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET6,
|
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return -1, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set sockopts and bind
|
|
||||||
|
|
||||||
addr := unix.SockaddrInet6{
|
|
||||||
Port: int(port),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := func() error {
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IPV6,
|
|
||||||
unix.IPV6_RECVPKTINFO,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IPV6,
|
|
||||||
unix.IPV6_V6ONLY,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.Bind(fd, &addr)
|
|
||||||
|
|
||||||
}(); err != nil {
|
|
||||||
unix.Close(fd)
|
|
||||||
return -1, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, err := unix.Getsockname(fd)
|
|
||||||
if err == nil {
|
|
||||||
addr.Port = sa.(*unix.SockaddrInet6).Port
|
|
||||||
}
|
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
cmsg := struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet4Pktinfo
|
|
||||||
}{
|
|
||||||
unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IP,
|
|
||||||
Type: unix.IP_PKTINFO,
|
|
||||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
|
||||||
},
|
|
||||||
unix.Inet4Pktinfo{
|
|
||||||
Spec_dst: end.src4().Src,
|
|
||||||
Ifindex: end.src4().Ifindex,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
end.mu.Lock()
|
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
||||||
end.mu.Unlock()
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear src and retry
|
|
||||||
|
|
||||||
if err == unix.EINVAL {
|
|
||||||
end.ClearSrc()
|
|
||||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
|
||||||
end.mu.Lock()
|
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
||||||
end.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
cmsg := struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet6Pktinfo
|
|
||||||
}{
|
|
||||||
unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IPV6,
|
|
||||||
Type: unix.IPV6_PKTINFO,
|
|
||||||
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
|
||||||
},
|
|
||||||
unix.Inet6Pktinfo{
|
|
||||||
Addr: end.src6().src,
|
|
||||||
Ifindex: end.dst6().ZoneId,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmsg.pktinfo.Addr == [16]byte{} {
|
|
||||||
cmsg.pktinfo.Ifindex = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
end.mu.Lock()
|
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
||||||
end.mu.Unlock()
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear src and retry
|
|
||||||
|
|
||||||
if err == unix.EINVAL {
|
|
||||||
end.ClearSrc()
|
|
||||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
|
||||||
end.mu.Lock()
|
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
||||||
end.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
var cmsg struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet4Pktinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
end.isV6 = false
|
|
||||||
|
|
||||||
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
|
||||||
*end.dst4() = *newDst4
|
|
||||||
}
|
|
||||||
|
|
||||||
// update source cache
|
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
|
||||||
end.src4().Src = cmsg.pktinfo.Spec_dst
|
|
||||||
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
|
||||||
}
|
|
||||||
|
|
||||||
return size, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
var cmsg struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet6Pktinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
end.isV6 = true
|
|
||||||
|
|
||||||
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
|
||||||
*end.dst6() = *newDst6
|
|
||||||
}
|
|
||||||
|
|
||||||
// update source cache
|
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
|
||||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
|
||||||
end.src6().src = cmsg.pktinfo.Addr
|
|
||||||
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
|
||||||
}
|
|
||||||
|
|
||||||
return size, nil
|
|
||||||
}
|
|
524
conn/bind_std.go
524
conn/bind_std.go
|
@ -1,72 +1,126 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdNetBind is meant to be a temporary solution on platforms for which
|
var (
|
||||||
// the sticky socket / source caching behavior has not yet been implemented.
|
_ Bind = (*StdNetBind)(nil)
|
||||||
// It uses the Go's net package to implement networking.
|
)
|
||||||
// See LinuxSocketBind for a proper implementation on the Linux platform.
|
|
||||||
|
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||||
|
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||||
|
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||||
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||||
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||||
type StdNetBind struct {
|
type StdNetBind struct {
|
||||||
mu sync.Mutex // protects following fields
|
mu sync.Mutex // protects all fields except as specified
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
|
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||||
|
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||||
|
ipv4TxOffload bool
|
||||||
|
ipv4RxOffload bool
|
||||||
|
ipv6TxOffload bool
|
||||||
|
ipv6RxOffload bool
|
||||||
|
|
||||||
|
// these two fields are not guarded by mu
|
||||||
|
udpAddrPool sync.Pool
|
||||||
|
msgsPool sync.Pool
|
||||||
|
|
||||||
blackhole4 bool
|
blackhole4 bool
|
||||||
blackhole6 bool
|
blackhole6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStdNetBind() Bind { return &StdNetBind{} }
|
func NewStdNetBind() Bind {
|
||||||
|
return &StdNetBind{
|
||||||
|
udpAddrPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: make([]byte, 16),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
type StdNetEndpoint net.UDPAddr
|
msgsPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||||
|
// both aliases for x/net/internal/socket.Message.
|
||||||
|
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var _ Bind = (*StdNetBind)(nil)
|
type StdNetEndpoint struct {
|
||||||
var _ Endpoint = (*StdNetEndpoint)(nil)
|
// AddrPort is the endpoint destination.
|
||||||
|
netip.AddrPort
|
||||||
|
// src is the current sticky source address and interface index, if
|
||||||
|
// supported. Typically this is a PKTINFO structure from/for control
|
||||||
|
// messages, see unix.PKTINFO for an example.
|
||||||
|
src []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*StdNetBind)(nil)
|
||||||
|
_ Endpoint = &StdNetEndpoint{}
|
||||||
|
)
|
||||||
|
|
||||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
addr, err := parseEndpoint(s)
|
e, err := netip.ParseAddrPort(s)
|
||||||
return (*StdNetEndpoint)(addr), err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &StdNetEndpoint{
|
||||||
|
AddrPort: e,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*StdNetEndpoint) ClearSrc() {}
|
func (e *StdNetEndpoint) ClearSrc() {
|
||||||
|
if e.src != nil {
|
||||||
func (e *StdNetEndpoint) DstIP() net.IP {
|
// Truncate src, no need to reallocate.
|
||||||
return (*net.UDPAddr)(e).IP
|
e.src = e.src[:0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() net.IP {
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
return nil // not supported
|
return e.AddrPort.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
addr := (*net.UDPAddr)(e)
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
out := addr.IP.To4()
|
return b
|
||||||
if out == nil {
|
|
||||||
out = addr.IP
|
|
||||||
}
|
|
||||||
out = append(out, byte(addr.Port&0xff))
|
|
||||||
out = append(out, byte((addr.Port>>8)&0xff))
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToString() string {
|
func (e *StdNetEndpoint) DstToString() string {
|
||||||
return (*net.UDPAddr)(e).String()
|
return e.AddrPort.String()
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcToString() string {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
@ -80,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
return conn, uaddr.Port, nil
|
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
bind.mu.Lock()
|
s.mu.Lock()
|
||||||
defer bind.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var tries int
|
var tries int
|
||||||
|
|
||||||
if bind.ipv4 != nil || bind.ipv6 != nil {
|
if s.ipv4 != nil || s.ipv6 != nil {
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,92 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
// If uport is 0, we can retry on failure.
|
// If uport is 0, we can retry on failure.
|
||||||
again:
|
again:
|
||||||
port := int(uport)
|
port := int(uport)
|
||||||
var ipv4, ipv6 *net.UDPConn
|
var v4conn, v6conn *net.UDPConn
|
||||||
|
var v4pc *ipv4.PacketConn
|
||||||
|
var v6pc *ipv6.PacketConn
|
||||||
|
|
||||||
ipv4, port, err = listenNet("udp4", port)
|
v4conn, port, err = listenNet("udp4", port)
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen on the same port as we're using for ipv4.
|
// Listen on the same port as we're using for ipv4.
|
||||||
ipv6, port, err = listenNet("udp6", port)
|
v6conn, port, err = listenNet("udp6", port)
|
||||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||||
ipv4.Close()
|
v4conn.Close()
|
||||||
tries++
|
tries++
|
||||||
goto again
|
goto again
|
||||||
}
|
}
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
ipv4.Close()
|
v4conn.Close()
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
var fns []ReceiveFunc
|
var fns []ReceiveFunc
|
||||||
if ipv4 != nil {
|
if v4conn != nil {
|
||||||
fns = append(fns, bind.makeReceiveIPv4(ipv4))
|
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||||
bind.ipv4 = ipv4
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v4pc = ipv4.NewPacketConn(v4conn)
|
||||||
|
s.ipv4PC = v4pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||||
|
s.ipv4 = v4conn
|
||||||
}
|
}
|
||||||
if ipv6 != nil {
|
if v6conn != nil {
|
||||||
fns = append(fns, bind.makeReceiveIPv6(ipv6))
|
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||||
bind.ipv6 = ipv6
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v6pc = ipv6.NewPacketConn(v6conn)
|
||||||
|
s.ipv6PC = v6pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||||
|
s.ipv6 = v6conn
|
||||||
}
|
}
|
||||||
if len(fns) == 0 {
|
if len(fns) == 0 {
|
||||||
return nil, 0, syscall.EAFNOSUPPORT
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
|
||||||
return fns, uint16(port), nil
|
return fns, uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Close() error {
|
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||||
bind.mu.Lock()
|
for i := range *msgs {
|
||||||
defer bind.mu.Unlock()
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
|
}
|
||||||
|
s.msgsPool.Put(msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||||
|
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// If compilation fails here these are no longer the same underlying type.
|
||||||
|
_ ipv6.Message = ipv4.Message{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type batchReader interface {
|
||||||
|
ReadBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchWriter interface {
|
||||||
|
WriteBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) receiveIP(
|
||||||
|
br batchReader,
|
||||||
|
conn *net.UDPConn,
|
||||||
|
rxOffload bool,
|
||||||
|
bufs [][]byte,
|
||||||
|
sizes []int,
|
||||||
|
eps []Endpoint,
|
||||||
|
) (n int, err error) {
|
||||||
|
msgs := s.getMessages()
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
|
}
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
if rxOffload {
|
||||||
|
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||||
|
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
|
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
|
func (s *StdNetBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var err1, err2 error
|
var err1, err2 error
|
||||||
if bind.ipv4 != nil {
|
if s.ipv4 != nil {
|
||||||
err1 = bind.ipv4.Close()
|
err1 = s.ipv4.Close()
|
||||||
bind.ipv4 = nil
|
s.ipv4 = nil
|
||||||
|
s.ipv4PC = nil
|
||||||
}
|
}
|
||||||
if bind.ipv6 != nil {
|
if s.ipv6 != nil {
|
||||||
err2 = bind.ipv6.Close()
|
err2 = s.ipv6.Close()
|
||||||
bind.ipv6 = nil
|
s.ipv6 = nil
|
||||||
|
s.ipv6PC = nil
|
||||||
}
|
}
|
||||||
bind.blackhole4 = false
|
s.blackhole4 = false
|
||||||
bind.blackhole6 = false
|
s.blackhole6 = false
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
s.ipv4RxOffload = false
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
s.ipv6RxOffload = false
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
|
type ErrUDPGSODisabled struct {
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
onLaddr string
|
||||||
n, endpoint, err := conn.ReadFromUDP(buff)
|
RetryErr error
|
||||||
if endpoint != nil {
|
|
||||||
endpoint.IP = endpoint.IP.To4()
|
|
||||||
}
|
|
||||||
return n, (*StdNetEndpoint)(endpoint), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
|
func (e ErrUDPGSODisabled) Error() string {
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
|
||||||
n, endpoint, err := conn.ReadFromUDP(buff)
|
|
||||||
return n, (*StdNetEndpoint)(endpoint), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||||
var err error
|
return e.RetryErr
|
||||||
nend, ok := endpoint.(*StdNetEndpoint)
|
}
|
||||||
if !ok {
|
|
||||||
return ErrWrongEndpointType
|
|
||||||
}
|
|
||||||
|
|
||||||
bind.mu.Lock()
|
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
blackhole := bind.blackhole4
|
s.mu.Lock()
|
||||||
conn := bind.ipv4
|
blackhole := s.blackhole4
|
||||||
if nend.IP.To4() == nil {
|
conn := s.ipv4
|
||||||
blackhole = bind.blackhole6
|
offload := s.ipv4TxOffload
|
||||||
conn = bind.ipv6
|
br := batchWriter(s.ipv4PC)
|
||||||
|
is6 := false
|
||||||
|
if endpoint.DstIP().Is6() {
|
||||||
|
blackhole = s.blackhole6
|
||||||
|
conn = s.ipv6
|
||||||
|
br = s.ipv6PC
|
||||||
|
is6 = true
|
||||||
|
offload = s.ipv6TxOffload
|
||||||
}
|
}
|
||||||
bind.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
if blackhole {
|
if blackhole {
|
||||||
return nil
|
return nil
|
||||||
|
@ -191,6 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
|
||||||
|
msgs := s.getMessages()
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
defer s.udpAddrPool.Put(ua)
|
||||||
|
if is6 {
|
||||||
|
as16 := endpoint.DstIP().As16()
|
||||||
|
copy(ua.IP, as16[:])
|
||||||
|
ua.IP = ua.IP[:16]
|
||||||
|
} else {
|
||||||
|
as4 := endpoint.DstIP().As4()
|
||||||
|
copy(ua.IP, as4[:])
|
||||||
|
ua.IP = ua.IP[:4]
|
||||||
|
}
|
||||||
|
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||||
|
var (
|
||||||
|
retried bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
retry:
|
||||||
|
if offload {
|
||||||
|
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||||
|
err = s.send(conn, br, (*msgs)[:n])
|
||||||
|
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||||
|
offload = false
|
||||||
|
s.mu.Lock()
|
||||||
|
if is6 {
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
} else {
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||||
|
}
|
||||||
|
if retried {
|
||||||
|
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
start int
|
||||||
|
)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
for {
|
||||||
|
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||||
|
if err != nil || n == len(msgs[start:]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, msg := range msgs {
|
||||||
|
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
||||||
|
// layer4 headers. IPv6 does not need to account for itself as the payload
|
||||||
|
// length field is self excluding.
|
||||||
|
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||||
|
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||||
|
|
||||||
|
// This is a hard limit imposed by the kernel.
|
||||||
|
udpSegmentMaxDatagrams = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||||
|
|
||||||
|
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||||
|
var (
|
||||||
|
base = -1 // index of msg we are currently coalescing into
|
||||||
|
gsoSize int // segmentation size of msgs[base]
|
||||||
|
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||||
|
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||||
|
)
|
||||||
|
maxPayloadLen := maxIPv4PayloadLen
|
||||||
|
if ep.DstIP().Is6() {
|
||||||
|
maxPayloadLen = maxIPv6PayloadLen
|
||||||
|
}
|
||||||
|
for i, buf := range bufs {
|
||||||
|
if i > 0 {
|
||||||
|
msgLen := len(buf)
|
||||||
|
baseLenBefore := len(msgs[base].Buffers[0])
|
||||||
|
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||||
|
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||||
|
msgLen <= gsoSize &&
|
||||||
|
msgLen <= freeBaseCap &&
|
||||||
|
dgramCnt < udpSegmentMaxDatagrams &&
|
||||||
|
!endBatch {
|
||||||
|
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||||
|
if i == len(bufs)-1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
dgramCnt++
|
||||||
|
if msgLen < gsoSize {
|
||||||
|
// A smaller than gsoSize packet on the tail is legal, but
|
||||||
|
// it must end the batch.
|
||||||
|
endBatch = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dgramCnt > 1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
// Reset prior to incrementing base since we are preparing to start a
|
||||||
|
// new potential batch.
|
||||||
|
endBatch = false
|
||||||
|
base++
|
||||||
|
gsoSize = len(buf)
|
||||||
|
setSrcControl(&msgs[base].OOB, ep)
|
||||||
|
msgs[base].Buffers[0] = buf
|
||||||
|
msgs[base].Addr = addr
|
||||||
|
dgramCnt = 1
|
||||||
|
}
|
||||||
|
return base + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type getGSOFunc func(control []byte) (int, error)
|
||||||
|
|
||||||
|
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||||
|
for i := firstMsgAt; i < len(msgs); i++ {
|
||||||
|
msg := &msgs[i]
|
||||||
|
if msg.N == 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
gsoSize int
|
||||||
|
start int
|
||||||
|
end = msg.N
|
||||||
|
numToSplit = 1
|
||||||
|
)
|
||||||
|
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if gsoSize > 0 {
|
||||||
|
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||||
|
end = gsoSize
|
||||||
|
}
|
||||||
|
for j := 0; j < numToSplit; j++ {
|
||||||
|
if n > i {
|
||||||
|
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||||
|
}
|
||||||
|
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||||
|
msgs[n].N = copied
|
||||||
|
msgs[n].Addr = msg.Addr
|
||||||
|
start = end
|
||||||
|
end += gsoSize
|
||||||
|
if end > msg.N {
|
||||||
|
end = msg.N
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if i != n-1 {
|
||||||
|
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||||
|
// of splitting, so we only zero the source msg len when it is not
|
||||||
|
// the destination of the last split operation above.
|
||||||
|
msg.N = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
250
conn/bind_std_test.go
Normal file
250
conn/bind_std_test.go
Normal file
|
@ -0,0 +1,250 @@
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||||
|
bind := NewStdNetBind().(*StdNetBind)
|
||||||
|
fns, _, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bind.Close()
|
||||||
|
bufs := make([][]byte, 1)
|
||||||
|
bufs[0] = make([]byte, 1)
|
||||||
|
sizes := make([]int, 1)
|
||||||
|
eps := make([]Endpoint, 1)
|
||||||
|
for _, fn := range fns {
|
||||||
|
// The ReceiveFuncs must not access conn-related fields on StdNetBind
|
||||||
|
// unguarded. Close() nils the conn-related fields resulting in a panic
|
||||||
|
// if they violate the mutex.
|
||||||
|
fn(bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_coalesceMessages(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
buffs [][]byte
|
||||||
|
wantLens []int
|
||||||
|
wantGSO []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "one message no coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{1},
|
||||||
|
wantGSO: []int{0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages equal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 2),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{2},
|
||||||
|
wantGSO: []int{1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{3},
|
||||||
|
wantGSO: []int{2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages second unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{3, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages limited cap coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 4),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{4, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1").To4(),
|
||||||
|
Port: 1,
|
||||||
|
}
|
||||||
|
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make([][]byte, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, 2)
|
||||||
|
}
|
||||||
|
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
|
||||||
|
if got != len(tt.wantLens) {
|
||||||
|
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||||
|
}
|
||||||
|
for i := 0; i < got; i++ {
|
||||||
|
if msgs[i].Addr != addr {
|
||||||
|
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||||
|
}
|
||||||
|
gotLen := len(msgs[i].Buffers[0])
|
||||||
|
if gotLen != tt.wantLens[i] {
|
||||||
|
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||||
|
}
|
||||||
|
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||||
|
}
|
||||||
|
if gotGSO != tt.wantGSO[i] {
|
||||||
|
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockGetGSOSize(control []byte) (int, error) {
|
||||||
|
if len(control) < 2 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return int(binary.LittleEndian.Uint16(control)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_splitCoalescedMessages(t *testing.T) {
|
||||||
|
newMsg := func(n, gso int) ipv6.Message {
|
||||||
|
msg := ipv6.Message{
|
||||||
|
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
||||||
|
N: n,
|
||||||
|
OOB: make([]byte, 2),
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||||
|
if gso > 0 {
|
||||||
|
msg.NN = 2
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
msgs []ipv6.Message
|
||||||
|
firstMsgAt int
|
||||||
|
wantNumEval int
|
||||||
|
wantMsgLens []int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "second last split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 3,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 1,
|
||||||
|
wantMsgLens: []int{1, 0, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last no split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 2,
|
||||||
|
wantMsgLens: []int{1, 1, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(2, 1),
|
||||||
|
newMsg(2, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split overflow",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(4, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
||||||
|
if err != nil && !tt.wantErr {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.wantNumEval {
|
||||||
|
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||||
|
}
|
||||||
|
for i, msg := range tt.msgs {
|
||||||
|
if msg.N != tt.wantMsgLens[i] {
|
||||||
|
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
@ -9,6 +9,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -16,7 +17,7 @@ import (
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -73,7 +74,7 @@ type afWinRingBind struct {
|
||||||
type WinRingBind struct {
|
type WinRingBind struct {
|
||||||
v4, v6 afWinRingBind
|
v4, v6 afWinRingBind
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
isOpen uint32
|
isOpen atomic.Uint32 // 0, 1, or 2
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultBind() Bind { return NewWinRingBind() }
|
func NewDefaultBind() Bind { return NewWinRingBind() }
|
||||||
|
@ -90,8 +91,10 @@ type WinRingEndpoint struct {
|
||||||
data [30]byte
|
data [30]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Bind = (*WinRingBind)(nil)
|
var (
|
||||||
var _ Endpoint = (*WinRingEndpoint)(nil)
|
_ Bind = (*WinRingBind)(nil)
|
||||||
|
_ Endpoint = (*WinRingEndpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
host, port, err := net.SplitHostPort(s)
|
host, port, err := net.SplitHostPort(s)
|
||||||
|
@ -121,27 +124,25 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
|
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
|
||||||
return nil, windows.ERROR_INVALID_ADDRESS
|
return nil, windows.ERROR_INVALID_ADDRESS
|
||||||
}
|
}
|
||||||
var src []byte
|
|
||||||
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
|
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
|
||||||
unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
|
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
|
||||||
copy(dst[:], src)
|
|
||||||
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
|
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*WinRingEndpoint) ClearSrc() {}
|
func (*WinRingEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstIP() net.IP {
|
func (e *WinRingEndpoint) DstIP() netip.Addr {
|
||||||
switch e.family {
|
switch e.family {
|
||||||
case windows.AF_INET:
|
case windows.AF_INET:
|
||||||
return append([]byte{}, e.data[2:6]...)
|
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
|
||||||
case windows.AF_INET6:
|
case windows.AF_INET6:
|
||||||
return append([]byte{}, e.data[6:22]...)
|
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
|
||||||
}
|
}
|
||||||
return nil
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) SrcIP() net.IP {
|
func (e *WinRingEndpoint) SrcIP() netip.Addr {
|
||||||
return nil // not supported
|
return netip.Addr{} // not supported
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstToBytes() []byte {
|
func (e *WinRingEndpoint) DstToBytes() []byte {
|
||||||
|
@ -163,15 +164,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
|
||||||
func (e *WinRingEndpoint) DstToString() string {
|
func (e *WinRingEndpoint) DstToString() string {
|
||||||
switch e.family {
|
switch e.family {
|
||||||
case windows.AF_INET:
|
case windows.AF_INET:
|
||||||
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
return addr.String()
|
|
||||||
case windows.AF_INET6:
|
case windows.AF_INET6:
|
||||||
var zone string
|
var zone string
|
||||||
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
||||||
zone = strconv.FormatUint(uint64(scope), 10)
|
zone = strconv.FormatUint(uint64(scope), 10)
|
||||||
}
|
}
|
||||||
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -213,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) closeAndZero() {
|
func (bind *WinRingBind) closeAndZero() {
|
||||||
atomic.StoreUint32(&bind.isOpen, 0)
|
bind.isOpen.Store(0)
|
||||||
bind.v4.CloseAndZero()
|
bind.v4.CloseAndZero()
|
||||||
bind.v6.CloseAndZero()
|
bind.v6.CloseAndZero()
|
||||||
}
|
}
|
||||||
|
@ -277,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
||||||
bind.closeAndZero()
|
bind.closeAndZero()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
if bind.isOpen.Load() != 0 {
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
var sa windows.Sockaddr
|
var sa windows.Sockaddr
|
||||||
|
@ -300,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&bind.isOpen, 1)
|
bind.isOpen.Store(1)
|
||||||
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) Close() error {
|
func (bind *WinRingBind) Close() error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
bind.mu.RUnlock()
|
bind.mu.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&bind.isOpen, 2)
|
bind.isOpen.Store(2)
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
||||||
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
||||||
|
@ -322,6 +321,13 @@ func (bind *WinRingBind) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
|
func (bind *WinRingBind) BatchSize() int {
|
||||||
|
// TODO: implement batching in and out of the ring
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -346,17 +352,21 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
|
||||||
//go:linkname procyield runtime.procyield
|
//go:linkname procyield runtime.procyield
|
||||||
func procyield(cycles uint32)
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
|
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
bind.rx.mu.Lock()
|
bind.rx.mu.Lock()
|
||||||
defer bind.rx.mu.Unlock()
|
defer bind.rx.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
var count uint32
|
var count uint32
|
||||||
var results [1]winrio.Result
|
var results [1]winrio.Result
|
||||||
|
retry:
|
||||||
|
count = 0
|
||||||
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
||||||
if tries > 0 {
|
if tries > 0 {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
procyield(1)
|
procyield(1)
|
||||||
|
@ -364,7 +374,7 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
|
||||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||||
}
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
err := winrio.Notify(bind.rx.cq)
|
err = winrio.Notify(bind.rx.cq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
@ -375,20 +385,28 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return 0, nil, io.ErrNoProgress
|
return 0, nil, io.ErrNoProgress
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bind.rx.Return(1)
|
bind.rx.Return(1)
|
||||||
err := bind.InsertReceiveRequest()
|
err = bind.InsertReceiveRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
|
||||||
|
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
||||||
|
// attacker bandwidth, just like the rest of the receive path.
|
||||||
|
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
if results[0].Status != 0 {
|
if results[0].Status != 0 {
|
||||||
return 0, nil, windows.Errno(results[0].Status)
|
return 0, nil, windows.Errno(results[0].Status)
|
||||||
}
|
}
|
||||||
|
@ -398,20 +416,26 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
|
||||||
return n, &ep, nil
|
return n, &ep, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
return bind.v4.Receive(buf, &bind.isOpen)
|
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
|
||||||
|
sizes[0] = n
|
||||||
|
eps[0] = ep
|
||||||
|
return 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
return bind.v6.Receive(buf, &bind.isOpen)
|
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
|
||||||
|
sizes[0] = n
|
||||||
|
eps[0] = ep
|
||||||
|
return 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
|
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
if len(buf) > bytesPerPacket {
|
if len(buf) > bytesPerPacket {
|
||||||
|
@ -433,7 +457,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
||||||
|
@ -462,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
|
||||||
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
nend, ok := endpoint.(*WinRingEndpoint)
|
nend, ok := endpoint.(*WinRingEndpoint)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrWrongEndpointType
|
return ErrWrongEndpointType
|
||||||
}
|
}
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
switch nend.family {
|
for _, buf := range bufs {
|
||||||
case windows.AF_INET:
|
switch nend.family {
|
||||||
if bind.v4.blackhole {
|
case windows.AF_INET:
|
||||||
return nil
|
if bind.v4.blackhole {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case windows.AF_INET6:
|
||||||
|
if bind.v6.blackhole {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return bind.v4.Send(buf, nend, &bind.isOpen)
|
|
||||||
case windows.AF_INET6:
|
|
||||||
if bind.v6.blackhole {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return bind.v6.Send(buf, nend, &bind.isOpen)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.Lock()
|
s.mu.Lock()
|
||||||
defer bind.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
sysconn, err := bind.ipv4.SyscallConn()
|
sysconn, err := s.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -500,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
bind.blackhole4 = blackhole
|
s.blackhole4 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.Lock()
|
s.mu.Lock()
|
||||||
defer bind.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
sysconn, err := bind.ipv6.SyscallConn()
|
sysconn, err := s.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -520,13 +550,14 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
bind.blackhole6 = blackhole
|
s.blackhole6 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
||||||
|
@ -540,7 +571,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
||||||
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
||||||
|
@ -568,21 +599,3 @@ func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error
|
||||||
const IPV6_UNICAST_IF = 31
|
const IPV6_UNICAST_IF = 31
|
||||||
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
|
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
|
||||||
}
|
}
|
||||||
|
|
||||||
// unsafeSlice updates the slice slicePtr to be a slice
|
|
||||||
// referencing the provided data with its length & capacity set to
|
|
||||||
// lenCap.
|
|
||||||
//
|
|
||||||
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
|
|
||||||
// update callers to use unsafe.Slice instead of this.
|
|
||||||
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
|
|
||||||
type sliceHeader struct {
|
|
||||||
Data unsafe.Pointer
|
|
||||||
Len int
|
|
||||||
Cap int
|
|
||||||
}
|
|
||||||
h := (*sliceHeader)(slicePtr)
|
|
||||||
h.Data = data
|
|
||||||
h.Len = lenCap
|
|
||||||
h.Cap = lenCap
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package bindtest
|
package bindtest
|
||||||
|
@ -9,10 +9,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChannelBind struct {
|
type ChannelBind struct {
|
||||||
|
@ -25,8 +25,10 @@ type ChannelBind struct {
|
||||||
|
|
||||||
type ChannelEndpoint uint16
|
type ChannelEndpoint uint16
|
||||||
|
|
||||||
var _ conn.Bind = (*ChannelBind)(nil)
|
var (
|
||||||
var _ conn.Endpoint = (*ChannelEndpoint)(nil)
|
_ conn.Bind = (*ChannelBind)(nil)
|
||||||
|
_ conn.Endpoint = (*ChannelEndpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
func NewChannelBinds() [2]conn.Bind {
|
func NewChannelBinds() [2]conn.Bind {
|
||||||
arx4 := make(chan []byte, 8192)
|
arx4 := make(chan []byte, 8192)
|
||||||
|
@ -61,9 +63,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
|
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
||||||
|
|
||||||
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
|
||||||
|
|
||||||
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
c.closeSignal = make(chan bool)
|
c.closeSignal = make(chan bool)
|
||||||
|
@ -87,45 +89,48 @@ func (c *ChannelBind) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||||
|
|
||||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||||
return func(b []byte) (n int, ep conn.Endpoint, err error) {
|
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||||
select {
|
select {
|
||||||
case <-c.closeSignal:
|
case <-c.closeSignal:
|
||||||
return 0, nil, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
case rx := <-ch:
|
case rx := <-ch:
|
||||||
return copy(b, rx), c.target6, nil
|
copied := copy(bufs[0], rx)
|
||||||
|
sizes[0] = copied
|
||||||
|
eps[0] = c.target6
|
||||||
|
return 1, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
select {
|
for _, b := range bufs {
|
||||||
case <-c.closeSignal:
|
select {
|
||||||
return net.ErrClosed
|
case <-c.closeSignal:
|
||||||
default:
|
return net.ErrClosed
|
||||||
bc := make([]byte, len(b))
|
default:
|
||||||
copy(bc, b)
|
bc := make([]byte, len(b))
|
||||||
if ep.(ChannelEndpoint) == c.target4 {
|
copy(bc, b)
|
||||||
*c.tx4 <- bc
|
if ep.(ChannelEndpoint) == c.target4 {
|
||||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
*c.tx4 <- bc
|
||||||
*c.tx6 <- bc
|
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||||
} else {
|
*c.tx6 <- bc
|
||||||
return os.ErrInvalid
|
} else {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
_, port, err := net.SplitHostPort(s)
|
addr, err := netip.ParseAddrPort(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
i, err := strconv.ParseUint(port, 10, 16)
|
return ChannelEndpoint(addr.Port()), nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ChannelEndpoint(i), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
|
||||||
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||||
sysconn, err := bind.ipv4.SyscallConn()
|
sysconn, err := s.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||||
sysconn, err := bind.ipv6.SyscallConn()
|
sysconn, err := s.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
62
conn/conn.go
62
conn/conn.go
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package conn implements WireGuard's network connections.
|
// Package conn implements WireGuard's network connections.
|
||||||
|
@ -9,16 +9,23 @@ package conn
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A ReceiveFunc receives a single inbound packet from the network.
|
const (
|
||||||
// It writes the data into b. n is the length of the packet.
|
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||||
// ep is the remote endpoint.
|
)
|
||||||
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
|
|
||||||
|
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||||
|
// into packets. On a successful read it returns the number of elements of
|
||||||
|
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||||
|
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||||
|
// and eps slice with a length greater than or equal to the length of packets.
|
||||||
|
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||||
|
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||||
|
|
||||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
//
|
//
|
||||||
|
@ -38,11 +45,16 @@ type Bind interface {
|
||||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
SetMark(mark uint32) error
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
// Send writes a packet b to address ep.
|
// Send writes one or more packets in bufs to address ep. The length of
|
||||||
Send(b []byte, ep Endpoint) error
|
// bufs must not exceed BatchSize().
|
||||||
|
Send(bufs [][]byte, ep Endpoint) error
|
||||||
|
|
||||||
// ParseEndpoint creates a new endpoint from a string.
|
// ParseEndpoint creates a new endpoint from a string.
|
||||||
ParseEndpoint(s string) (Endpoint, error)
|
ParseEndpoint(s string) (Endpoint, error)
|
||||||
|
|
||||||
|
// BatchSize is the number of buffers expected to be passed to
|
||||||
|
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||||
|
BatchSize() int
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindSocketToInterface is implemented by Bind objects that support being
|
// BindSocketToInterface is implemented by Bind objects that support being
|
||||||
|
@ -68,8 +80,8 @@ type Endpoint interface {
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
DstToString() string // returns the destination address (ip:port)
|
DstToString() string // returns the destination address (ip:port)
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
DstIP() net.IP
|
DstIP() netip.Addr
|
||||||
SrcIP() net.IP
|
SrcIP() netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -119,33 +131,3 @@ func (fn ReceiveFunc) PrettyName() string {
|
||||||
}
|
}
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|
||||||
// ensure that the host is an IP address
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
|
||||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
|
||||||
// trying to make sure with a small sanity test that this is a real IP address and
|
|
||||||
// not something that's likely to incur DNS lookups.
|
|
||||||
host = host[:i]
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
|
||||||
return nil, errors.New("Failed to parse IP address: " + host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse address and port
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ip4 := addr.IP.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
addr.IP = ip4
|
|
||||||
}
|
|
||||||
return addr, err
|
|
||||||
}
|
|
||||||
|
|
24
conn/conn_test.go
Normal file
24
conn/conn_test.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrettyName(t *testing.T) {
|
||||||
|
var (
|
||||||
|
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
|
||||||
|
)
|
||||||
|
|
||||||
|
const want = "TestPrettyName"
|
||||||
|
|
||||||
|
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
|
||||||
|
if got := recvFunc.PrettyName(); got != want {
|
||||||
|
t.Errorf("PrettyName() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
43
conn/controlfns.go
Normal file
43
conn/controlfns.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||||
|
// the max supported by a default configuration of macOS. Some platforms will
|
||||||
|
// silently clamp the value to other maximums, such as linux clamping to
|
||||||
|
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||||
|
// around this limitation)
|
||||||
|
const socketBufferSize = 7 << 20
|
||||||
|
|
||||||
|
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||||
|
// It is used to apply platform specific configuration to the socket prior to
|
||||||
|
// bind.
|
||||||
|
type controlFn func(network, address string, c syscall.RawConn) error
|
||||||
|
|
||||||
|
// controlFns is a list of functions that are called from the listen config
|
||||||
|
// that can apply socket options.
|
||||||
|
var controlFns = []controlFn{}
|
||||||
|
|
||||||
|
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||||
|
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||||
|
// information OOB configuration for sticky sockets.
|
||||||
|
func listenConfig() *net.ListenConfig {
|
||||||
|
return &net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
for _, fn := range controlFns {
|
||||||
|
if err := fn(network, address, c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
61
conn/controlfns_linux.go
Normal file
61
conn/controlfns_linux.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
|
||||||
|
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||||
|
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||||
|
// fail silently - the result of failure is lower performance on very fast
|
||||||
|
// links or high latency links.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
// Set up to *mem_max
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||||
|
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||||
|
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
switch network {
|
||||||
|
case "udp4":
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "udp6":
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
35
conn/controlfns_unix.go
Normal file
35
conn/controlfns_unix.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
//go:build !windows && !linux && !wasm
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
if network == "udp6" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
23
conn/controlfns_windows.go
Normal file
23
conn/controlfns_windows.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
|
@ -1,8 +1,8 @@
|
||||||
// +build !linux,!windows
|
//go:build !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
|
12
conn/errors_default.go
Normal file
12
conn/errors_default.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
return false
|
||||||
|
}
|
28
conn/errors_linux.go
Normal file
28
conn/errors_linux.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
var serr *os.SyscallError
|
||||||
|
if errors.As(err, &serr) {
|
||||||
|
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||||
|
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||||
|
// See:
|
||||||
|
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||||
|
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||||
|
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
|
||||||
|
// It occurs when the peer mtu + wg headers is greater than path mtu.
|
||||||
|
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
15
conn/features_default.go
Normal file
15
conn/features_default.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
return
|
||||||
|
}
|
31
conn/features_linux.go
Normal file
31
conn/features_linux.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
rc, err := conn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = rc.Control(func(fd uintptr) {
|
||||||
|
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||||
|
txOffload = errSyscall == nil
|
||||||
|
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
|
||||||
|
// use setsockopt workaround
|
||||||
|
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||||
|
rxOffload = errSyscall == nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return txOffload, rxOffload
|
||||||
|
}
|
21
conn/gso_default.go
Normal file
21
conn/gso_default.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||||
|
// offloading control data.
|
||||||
|
const gsoControlSize = 0
|
65
conn/gso_linux.go
Normal file
65
conn/gso_linux.go
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
//go:build linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sizeOfGSOData = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||||
|
}
|
||||||
|
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||||
|
var gso uint16
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||||
|
return int(gso), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||||
|
// data in control untouched.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
existingLen := len(*control)
|
||||||
|
avail := cap(*control) - existingLen
|
||||||
|
space := unix.CmsgSpace(sizeOfGSOData)
|
||||||
|
if avail < space {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
gsoControl := (*control)[existingLen:]
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||||
|
hdr.Level = unix.SOL_UDP
|
||||||
|
hdr.Type = unix.UDP_SEGMENT
|
||||||
|
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||||
|
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||||
|
*control = (*control)[:existingLen+space]
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||||
|
// offloading control data.
|
||||||
|
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
|
@ -1,12 +1,12 @@
|
||||||
// +build !linux,!openbsd,!freebsd
|
//go:build !linux && !openbsd && !freebsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
|
||||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
// +build linux openbsd freebsd
|
//go:build linux || openbsd || freebsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package conn
|
package conn
|
||||||
|
@ -26,13 +26,13 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||||
var operr error
|
var operr error
|
||||||
if fwmarkIoctl == 0 {
|
if fwmarkIoctl == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if bind.ipv4 != nil {
|
if s.ipv4 != nil {
|
||||||
fd, err := bind.ipv4.SyscallConn()
|
fd, err := s.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if bind.ipv6 != nil {
|
if s.ipv6 != nil {
|
||||||
fd, err := bind.ipv6.SyscallConn()
|
fd, err := s.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
42
conn/sticky_default.go
Normal file
42
conn/sticky_default.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||||
|
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||||
|
// ports and require testing.
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
const stickyControlSize = 0
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = false
|
112
conn/sticky_linux.go
Normal file
112
conn/sticky_linux.go
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return netip.AddrFrom4(info.Spec_dst)
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
// TODO: set zone. in order to do so we need to check if the address is
|
||||||
|
// link local, and if it is perform a syscall to turn the ifindex into a
|
||||||
|
// zone string because netip uses string zones.
|
||||||
|
return netip.AddrFrom16(info.Addr)
|
||||||
|
}
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return info.Ifindex
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return int32(info.Ifindex)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return e.SrcIP().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
ep.ClearSrc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem []byte = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IP &&
|
||||||
|
hdr.Type == unix.IP_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
}
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
|
hdr.Type == unix.IPV6_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||||
|
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||||
|
// that ep is a default value.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
if cap(*control) < len(ep.src) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:0]
|
||||||
|
*control = append(*control, ep.src...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = true
|
266
conn/sticky_linux_test.go
Normal file
266
conn/sticky_linux_test.go
Normal file
|
@ -0,0 +1,266 @@
|
||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
|
||||||
|
var buf []byte
|
||||||
|
if addr.Is4() {
|
||||||
|
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
hdr := unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IP,
|
||||||
|
Type: unix.IP_PKTINFO,
|
||||||
|
}
|
||||||
|
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||||
|
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||||
|
|
||||||
|
info := unix.Inet4Pktinfo{
|
||||||
|
Ifindex: ifidx,
|
||||||
|
Spec_dst: addr.As4(),
|
||||||
|
}
|
||||||
|
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
|
||||||
|
} else {
|
||||||
|
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||||
|
hdr := unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IPV6,
|
||||||
|
Type: unix.IPV6_PKTINFO,
|
||||||
|
}
|
||||||
|
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||||
|
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||||
|
|
||||||
|
info := unix.Inet6Pktinfo{
|
||||||
|
Ifindex: uint32(ifidx),
|
||||||
|
Addr: addr.As16(),
|
||||||
|
}
|
||||||
|
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.src = buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_setSrcControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||||
|
}
|
||||||
|
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
||||||
|
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IP {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IP_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||||
|
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||||
|
}
|
||||||
|
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||||
|
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IPV6_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Addr != ep.SrcIP().As16() {
|
||||||
|
t.Errorf("unexpected address: %v", info.Addr)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = 1
|
||||||
|
hdr.Type = 2
|
||||||
|
hdr.Len = 3
|
||||||
|
|
||||||
|
setSrcControl(&control, &StdNetEndpoint{})
|
||||||
|
|
||||||
|
if len(control) != 0 {
|
||||||
|
t.Errorf("unexpected control: %v", control)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getSrcFromControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IPV6
|
||||||
|
hdr.Type = unix.IPV6_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||||
|
var control []byte
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||||
|
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 0 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Multiple", func(t *testing.T) {
|
||||||
|
zeroControl := make([]byte, unix.CmsgSpace(0))
|
||||||
|
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
|
||||||
|
zeroHdr.SetLen(unix.CmsgLen(0))
|
||||||
|
|
||||||
|
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
combined := make([]byte, 0)
|
||||||
|
combined = append(combined, zeroControl...)
|
||||||
|
combined = append(combined, control...)
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(combined, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_listenConfig(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IP_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IPV6_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package winrio
|
package winrio
|
||||||
|
@ -84,8 +84,10 @@ type iocpNotificationCompletion struct {
|
||||||
overlapped *windows.Overlapped
|
overlapped *windows.Overlapped
|
||||||
}
|
}
|
||||||
|
|
||||||
var initialized sync.Once
|
var (
|
||||||
var available bool
|
initialized sync.Once
|
||||||
|
available bool
|
||||||
|
)
|
||||||
|
|
||||||
func Initialize() bool {
|
func Initialize() bool {
|
||||||
initialized.Do(func() {
|
initialized.Do(func() {
|
||||||
|
@ -108,7 +110,7 @@ func Initialize() bool {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer windows.CloseHandle(socket)
|
defer windows.CloseHandle(socket)
|
||||||
var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
|
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
|
||||||
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
|
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
|
||||||
ob := uint32(0)
|
ob := uint32(0)
|
||||||
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
|
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
|
||||||
|
|
|
@ -1,68 +1,55 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"math/bits"
|
"math/bits"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type parentIndirection struct {
|
||||||
|
parentBit **trieEntry
|
||||||
|
parentBitType uint8
|
||||||
|
}
|
||||||
|
|
||||||
type trieEntry struct {
|
type trieEntry struct {
|
||||||
child [2]*trieEntry
|
peer *Peer
|
||||||
peer *Peer
|
child [2]*trieEntry
|
||||||
bits net.IP
|
parent parentIndirection
|
||||||
cidr uint
|
cidr uint8
|
||||||
bit_at_byte uint
|
bitAtByte uint8
|
||||||
bit_at_shift uint
|
bitAtShift uint8
|
||||||
perPeerElem *list.Element
|
bits []byte
|
||||||
|
perPeerElem *list.Element
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLittleEndian() bool {
|
func commonBits(ip1, ip2 []byte) uint8 {
|
||||||
one := uint32(1)
|
|
||||||
return *(*byte)(unsafe.Pointer(&one)) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func swapU32(i uint32) uint32 {
|
|
||||||
if !isLittleEndian() {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
return bits.ReverseBytes32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func swapU64(i uint64) uint64 {
|
|
||||||
if !isLittleEndian() {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
return bits.ReverseBytes64(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
|
||||||
size := len(ip1)
|
size := len(ip1)
|
||||||
if size == net.IPv4len {
|
if size == net.IPv4len {
|
||||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
a := binary.BigEndian.Uint32(ip1)
|
||||||
b := (*uint32)(unsafe.Pointer(&ip2[0]))
|
b := binary.BigEndian.Uint32(ip2)
|
||||||
x := *a ^ *b
|
x := a ^ b
|
||||||
return uint(bits.LeadingZeros32(swapU32(x)))
|
return uint8(bits.LeadingZeros32(x))
|
||||||
} else if size == net.IPv6len {
|
} else if size == net.IPv6len {
|
||||||
a := (*uint64)(unsafe.Pointer(&ip1[0]))
|
a := binary.BigEndian.Uint64(ip1)
|
||||||
b := (*uint64)(unsafe.Pointer(&ip2[0]))
|
b := binary.BigEndian.Uint64(ip2)
|
||||||
x := *a ^ *b
|
x := a ^ b
|
||||||
if x != 0 {
|
if x != 0 {
|
||||||
return uint(bits.LeadingZeros64(swapU64(x)))
|
return uint8(bits.LeadingZeros64(x))
|
||||||
}
|
}
|
||||||
a = (*uint64)(unsafe.Pointer(&ip1[8]))
|
a = binary.BigEndian.Uint64(ip1[8:])
|
||||||
b = (*uint64)(unsafe.Pointer(&ip2[8]))
|
b = binary.BigEndian.Uint64(ip2[8:])
|
||||||
x = *a ^ *b
|
x = a ^ b
|
||||||
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
|
return 64 + uint8(bits.LeadingZeros64(x))
|
||||||
} else {
|
} else {
|
||||||
panic("Wrong size bit string")
|
panic("Wrong size bit string")
|
||||||
}
|
}
|
||||||
|
@ -79,32 +66,8 @@ func (node *trieEntry) removeFromPeerEntries() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
func (node *trieEntry) choose(ip []byte) byte {
|
||||||
if node == nil {
|
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
// walk recursively
|
|
||||||
|
|
||||||
node.child[0] = node.child[0].removeByPeer(p)
|
|
||||||
node.child[1] = node.child[1].removeByPeer(p)
|
|
||||||
|
|
||||||
if node.peer != p {
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove peer & merge
|
|
||||||
|
|
||||||
node.removeFromPeerEntries()
|
|
||||||
node.peer = nil
|
|
||||||
if node.child[0] == nil {
|
|
||||||
return node.child[1]
|
|
||||||
}
|
|
||||||
return node.child[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) choose(ip net.IP) byte {
|
|
||||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) maskSelf() {
|
func (node *trieEntry) maskSelf() {
|
||||||
|
@ -114,86 +77,125 @@ func (node *trieEntry) maskSelf() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
func (node *trieEntry) zeroizePointers() {
|
||||||
|
// Make the garbage collector's life slightly easier
|
||||||
|
node.peer = nil
|
||||||
|
node.child[0] = nil
|
||||||
|
node.child[1] = nil
|
||||||
|
node.parent.parentBit = nil
|
||||||
|
}
|
||||||
|
|
||||||
// at leaf
|
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
||||||
|
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||||
|
parent = node
|
||||||
|
if parent.cidr == cidr {
|
||||||
|
exact = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bit := node.choose(ip)
|
||||||
|
node = node.child[bit]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if node == nil {
|
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
||||||
|
if *trie.parentBit == nil {
|
||||||
node := &trieEntry{
|
node := &trieEntry{
|
||||||
bits: ip,
|
peer: peer,
|
||||||
peer: peer,
|
parent: trie,
|
||||||
cidr: cidr,
|
bits: ip,
|
||||||
bit_at_byte: cidr / 8,
|
cidr: cidr,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bitAtByte: cidr / 8,
|
||||||
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
node.maskSelf()
|
node.maskSelf()
|
||||||
node.addToPeerEntries()
|
node.addToPeerEntries()
|
||||||
return node
|
*trie.parentBit = node
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
|
||||||
// traverse deeper
|
if exact {
|
||||||
|
node.removeFromPeerEntries()
|
||||||
common := commonBits(node.bits, ip)
|
node.peer = peer
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
node.addToPeerEntries()
|
||||||
if node.cidr == cidr {
|
return
|
||||||
node.removeFromPeerEntries()
|
|
||||||
node.peer = peer
|
|
||||||
node.addToPeerEntries()
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
bit := node.choose(ip)
|
|
||||||
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
|
||||||
return node
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// split node
|
|
||||||
|
|
||||||
newNode := &trieEntry{
|
newNode := &trieEntry{
|
||||||
bits: ip,
|
peer: peer,
|
||||||
peer: peer,
|
bits: ip,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
newNode.maskSelf()
|
newNode.maskSelf()
|
||||||
newNode.addToPeerEntries()
|
newNode.addToPeerEntries()
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
var down *trieEntry
|
||||||
|
if node == nil {
|
||||||
// check for shorter prefix
|
down = *trie.parentBit
|
||||||
|
} else {
|
||||||
|
bit := node.choose(ip)
|
||||||
|
down = node.child[bit]
|
||||||
|
if down == nil {
|
||||||
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common := commonBits(down.bits, ip)
|
||||||
|
if common < cidr {
|
||||||
|
cidr = common
|
||||||
|
}
|
||||||
|
parent := node
|
||||||
|
|
||||||
if newNode.cidr == cidr {
|
if newNode.cidr == cidr {
|
||||||
bit := newNode.choose(node.bits)
|
bit := newNode.choose(down.bits)
|
||||||
newNode.child[bit] = node
|
down.parent = parentIndirection{&newNode.child[bit], bit}
|
||||||
return newNode
|
newNode.child[bit] = down
|
||||||
|
if parent == nil {
|
||||||
|
newNode.parent = trie
|
||||||
|
*trie.parentBit = newNode
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(newNode.bits)
|
||||||
|
newNode.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = newNode
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new parent for node & newNode
|
node = &trieEntry{
|
||||||
|
bits: append([]byte{}, newNode.bits...),
|
||||||
parent := &trieEntry{
|
cidr: cidr,
|
||||||
bits: append([]byte{}, ip...),
|
bitAtByte: cidr / 8,
|
||||||
peer: nil,
|
bitAtShift: 7 - (cidr % 8),
|
||||||
cidr: cidr,
|
|
||||||
bit_at_byte: cidr / 8,
|
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
|
||||||
}
|
}
|
||||||
parent.maskSelf()
|
node.maskSelf()
|
||||||
|
|
||||||
bit := parent.choose(ip)
|
bit := node.choose(down.bits)
|
||||||
parent.child[bit] = newNode
|
down.parent = parentIndirection{&node.child[bit], bit}
|
||||||
parent.child[bit^1] = node
|
node.child[bit] = down
|
||||||
|
bit = node.choose(newNode.bits)
|
||||||
return parent
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
if parent == nil {
|
||||||
|
node.parent = trie
|
||||||
|
*trie.parentBit = node
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(node.bits)
|
||||||
|
node.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = node
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
func (node *trieEntry) lookup(ip []byte) *Peer {
|
||||||
var found *Peer
|
var found *Peer
|
||||||
size := uint(len(ip))
|
size := uint8(len(ip))
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
if node.peer != nil {
|
if node.peer != nil {
|
||||||
found = node.peer
|
found = node.peer
|
||||||
}
|
}
|
||||||
if node.bit_at_byte == size {
|
if node.bitAtByte == size {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
|
@ -208,13 +210,14 @@ type AllowedIPs struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
||||||
node := elem.Value.(*trieEntry)
|
node := elem.Value.(*trieEntry)
|
||||||
if !cb(node.bits, node.cidr) {
|
a, _ := netip.AddrFromSlice(node.bits)
|
||||||
|
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -224,32 +227,68 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
table.IPv4 = table.IPv4.removeByPeer(peer)
|
var next *list.Element
|
||||||
table.IPv6 = table.IPv6.removeByPeer(peer)
|
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||||
|
next = elem.Next()
|
||||||
|
node := elem.Value.(*trieEntry)
|
||||||
|
|
||||||
|
node.removeFromPeerEntries()
|
||||||
|
node.peer = nil
|
||||||
|
if node.child[0] != nil && node.child[1] != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bit := 0
|
||||||
|
if node.child[0] == nil {
|
||||||
|
bit = 1
|
||||||
|
}
|
||||||
|
child := node.child[bit]
|
||||||
|
if child != nil {
|
||||||
|
child.parent = node.parent
|
||||||
|
}
|
||||||
|
*node.parent.parentBit = child
|
||||||
|
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||||
|
node.zeroizePointers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||||
|
if parent.peer != nil {
|
||||||
|
node.zeroizePointers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
child = parent.child[node.parent.parentBitType^1]
|
||||||
|
if child != nil {
|
||||||
|
child.parent = parent.parent
|
||||||
|
}
|
||||||
|
*parent.parent.parentBit = child
|
||||||
|
node.zeroizePointers()
|
||||||
|
parent.zeroizePointers()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
switch len(ip) {
|
if prefix.Addr().Is6() {
|
||||||
case net.IPv6len:
|
ip := prefix.Addr().As16()
|
||||||
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
case net.IPv4len:
|
} else if prefix.Addr().Is4() {
|
||||||
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
ip := prefix.Addr().As4()
|
||||||
default:
|
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
|
} else {
|
||||||
panic(errors.New("inserting unknown address type"))
|
panic(errors.New("inserting unknown address type"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
|
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
return table.IPv4.lookup(address)
|
switch len(ip) {
|
||||||
}
|
case net.IPv6len:
|
||||||
|
return table.IPv6.lookup(ip)
|
||||||
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
|
case net.IPv4len:
|
||||||
table.mutex.RLock()
|
return table.IPv4.lookup(ip)
|
||||||
defer table.mutex.RUnlock()
|
default:
|
||||||
return table.IPv6.lookup(address)
|
panic(errors.New("looking up unknown address type"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +1,28 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NumberOfPeers = 100
|
NumberOfPeers = 100
|
||||||
NumberOfAddresses = 250
|
NumberOfPeerRemovals = 4
|
||||||
NumberOfTests = 10000
|
NumberOfAddresses = 250
|
||||||
|
NumberOfTests = 10000
|
||||||
)
|
)
|
||||||
|
|
||||||
type SlowNode struct {
|
type SlowNode struct {
|
||||||
peer *Peer
|
peer *Peer
|
||||||
cidr uint
|
cidr uint8
|
||||||
bits []byte
|
bits []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
|
||||||
r[i], r[j] = r[j], r[i]
|
r[i], r[j] = r[j], r[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
|
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
|
||||||
for _, t := range r {
|
for _, t := range r {
|
||||||
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
||||||
t.peer = peer
|
t.peer = peer
|
||||||
|
@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv4(t *testing.T) {
|
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
|
||||||
var trie *trieEntry
|
n := 0
|
||||||
var slow SlowRouter
|
for _, x := range r {
|
||||||
|
if x.peer != peer {
|
||||||
|
r[n] = x
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrieRandom(t *testing.T) {
|
||||||
|
var slow4, slow6 SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
const AddressLength = 4
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfPeers; n++ {
|
for n := 0; n < NumberOfPeers; n++ {
|
||||||
peers = append(peers, &Peer{})
|
peers = append(peers, &Peer{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < NumberOfAddresses; n++ {
|
for n := 0; n < NumberOfAddresses; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr4 [4]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr4[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Intn(32) + 1)
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Intn(NumberOfPeers)
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||||
|
|
||||||
|
var addr6 [16]byte
|
||||||
|
rand.Read(addr6[:])
|
||||||
|
cidr = uint8(rand.Intn(128) + 1)
|
||||||
|
index = rand.Intn(NumberOfPeers)
|
||||||
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||||
|
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < NumberOfTests; n++ {
|
var p int
|
||||||
var addr [AddressLength]byte
|
for p = 0; ; p++ {
|
||||||
rand.Read(addr[:])
|
for n := 0; n < NumberOfTests; n++ {
|
||||||
peer1 := slow.Lookup(addr[:])
|
var addr4 [4]byte
|
||||||
peer2 := trie.lookup(addr[:])
|
rand.Read(addr4[:])
|
||||||
if peer1 != peer2 {
|
peer1 := slow4.Lookup(addr4[:])
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
peer2 := allowedIPs.Lookup(addr4[:])
|
||||||
}
|
if peer1 != peer2 {
|
||||||
}
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv6(t *testing.T) {
|
var addr6 [16]byte
|
||||||
var trie *trieEntry
|
rand.Read(addr6[:])
|
||||||
var slow SlowRouter
|
peer1 = slow6.Lookup(addr6[:])
|
||||||
var peers []*Peer
|
peer2 = allowedIPs.Lookup(addr6[:])
|
||||||
|
if peer1 != peer2 {
|
||||||
rand.Seed(1)
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
|
||||||
|
}
|
||||||
const AddressLength = 16
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfPeers; n++ {
|
|
||||||
peers = append(peers, &Peer{})
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfAddresses; n++ {
|
|
||||||
var addr [AddressLength]byte
|
|
||||||
rand.Read(addr[:])
|
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
|
||||||
index := rand.Int() % NumberOfPeers
|
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfTests; n++ {
|
|
||||||
var addr [AddressLength]byte
|
|
||||||
rand.Read(addr[:])
|
|
||||||
peer1 := slow.Lookup(addr[:])
|
|
||||||
peer2 := trie.lookup(addr[:])
|
|
||||||
if peer1 != peer2 {
|
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
|
||||||
}
|
}
|
||||||
|
if p >= len(peers) || p >= NumberOfPeerRemovals {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
allowedIPs.RemoveByPeer(peers[p])
|
||||||
|
slow4 = slow4.RemoveByPeer(peers[p])
|
||||||
|
slow6 = slow6.RemoveByPeer(peers[p])
|
||||||
|
}
|
||||||
|
for ; p < len(peers); p++ {
|
||||||
|
allowedIPs.RemoveByPeer(peers[p])
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
||||||
|
t.Error("Failed to remove all nodes from trie by peer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -8,20 +8,17 @@ package device
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Todo: More comprehensive
|
|
||||||
*/
|
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
s1 []byte
|
s1 []byte
|
||||||
s2 []byte
|
s2 []byte
|
||||||
match uint
|
match uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonBits(t *testing.T) {
|
func TestCommonBits(t *testing.T) {
|
||||||
|
|
||||||
tests := []testPairCommonBits{
|
tests := []testPairCommonBits{
|
||||||
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
||||||
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
||||||
|
@ -42,9 +39,10 @@ func TestCommonBits(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||||
var trie *trieEntry
|
var trie *trieEntry
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
root := parentIndirection{&trie, 2}
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
|
@ -57,9 +55,9 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
|
||||||
for n := 0; n < addressNumber; n++ {
|
for n := 0; n < addressNumber; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
root.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
|
@ -97,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||||
if p == peer {
|
if p == peer {
|
||||||
t.Error("Assert NEQ failed")
|
t.Error("Assert NEQ failed")
|
||||||
}
|
}
|
||||||
|
@ -153,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
assertEQ(a, 192, 0, 0, 0)
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
assertEQ(a, 255, 0, 0, 0)
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 1, 0, 0, 0)
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
assertNEQ(a, 64, 0, 0, 0)
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
|
@ -161,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
assertNEQ(a, 192, 0, 0, 0)
|
assertNEQ(a, 192, 0, 0, 0)
|
||||||
assertNEQ(a, 255, 0, 0, 0)
|
assertNEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = nil
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
allowedIPs.RemoveByPeer(b)
|
||||||
|
allowedIPs.RemoveByPeer(c)
|
||||||
|
allowedIPs.RemoveByPeer(d)
|
||||||
|
allowedIPs.RemoveByPeer(e)
|
||||||
|
allowedIPs.RemoveByPeer(g)
|
||||||
|
allowedIPs.RemoveByPeer(h)
|
||||||
|
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
||||||
|
t.Error("Expected removing all the peers to empty trie, but it did not")
|
||||||
|
}
|
||||||
|
|
||||||
insert(a, 192, 168, 0, 0, 16)
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
insert(a, 192, 168, 0, 0, 24)
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
}
|
}
|
||||||
|
@ -184,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
expand := func(a uint32) []byte {
|
expand := func(a uint32) []byte {
|
||||||
var out [4]byte
|
var out [4]byte
|
||||||
|
@ -195,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
return out[:]
|
return out[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||||
var addr []byte
|
var addr []byte
|
||||||
addr = append(addr, expand(a)...)
|
addr = append(addr, expand(a)...)
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
trie = trie.insert(addr, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
@ -210,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
p := trie.lookup(addr)
|
p := allowedIPs.Lookup(addr)
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -8,7 +8,7 @@ package device
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
|
@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in6
|
datagram, ok := <-b.in6
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buff, datagram.msg)
|
copy(buf, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in4
|
datagram, ok := <-b.in4
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buff, datagram.msg)
|
copy(buf, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -19,13 +19,13 @@ import (
|
||||||
// call wg.Done to remove the initial reference.
|
// call wg.Done to remove the initial reference.
|
||||||
// When the refcount hits 0, the queue's channel is closed.
|
// When the refcount hits 0, the queue's channel is closed.
|
||||||
type outboundQueue struct {
|
type outboundQueue struct {
|
||||||
c chan *QueueOutboundElement
|
c chan *QueueOutboundElementsContainer
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOutboundQueue() *outboundQueue {
|
func newOutboundQueue() *outboundQueue {
|
||||||
q := &outboundQueue{
|
q := &outboundQueue{
|
||||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
}
|
}
|
||||||
q.wg.Add(1)
|
q.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
|
||||||
|
|
||||||
// A inboundQueue is similar to an outboundQueue; see those docs.
|
// A inboundQueue is similar to an outboundQueue; see those docs.
|
||||||
type inboundQueue struct {
|
type inboundQueue struct {
|
||||||
c chan *QueueInboundElement
|
c chan *QueueInboundElementsContainer
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newInboundQueue() *inboundQueue {
|
func newInboundQueue() *inboundQueue {
|
||||||
q := &inboundQueue{
|
q := &inboundQueue{
|
||||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
}
|
}
|
||||||
q.wg.Add(1)
|
q.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
|
||||||
}
|
}
|
||||||
|
|
||||||
type autodrainingInboundQueue struct {
|
type autodrainingInboundQueue struct {
|
||||||
c chan *QueueInboundElement
|
c chan *QueueInboundElementsContainer
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
|
||||||
// some other means, such as sending a sentinel nil values.
|
// some other means, such as sending a sentinel nil values.
|
||||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||||
q := &autodrainingInboundQueue{
|
q := &autodrainingInboundQueue{
|
||||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||||
return q
|
return q
|
||||||
|
@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elem := <-q.c:
|
case elemsContainer := <-q.c:
|
||||||
elem.Lock()
|
elemsContainer.Lock()
|
||||||
device.PutMessageBuffer(elem.buffer)
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutInboundElement(elem)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutInboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type autodrainingOutboundQueue struct {
|
type autodrainingOutboundQueue struct {
|
||||||
c chan *QueueOutboundElement
|
c chan *QueueOutboundElementsContainer
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
|
||||||
// All sends to the channel must be best-effort, because there may be no receivers.
|
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||||
q := &autodrainingOutboundQueue{
|
q := &autodrainingOutboundQueue{
|
||||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||||
return q
|
return q
|
||||||
|
@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elem := <-q.c:
|
case elemsContainer := <-q.c:
|
||||||
elem.Lock()
|
elemsContainer.Lock()
|
||||||
device.PutMessageBuffer(elem.buffer)
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutOutboundElement(elem)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -35,7 +35,6 @@ const (
|
||||||
/* Implementation constants */
|
/* Implementation constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnderLoadQueueSize = QueueHandshakeSize / 8
|
|
||||||
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
||||||
MaxPeers = 1 << 16 // maximum number of configured peers
|
MaxPeers = 1 << 16 // maximum number of configured peers
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
||||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
|
||||||
st.RLock()
|
st.RLock()
|
||||||
defer st.RUnlock()
|
defer st.RUnlock()
|
||||||
|
|
||||||
|
@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
|
||||||
recv uint32,
|
recv uint32,
|
||||||
src []byte,
|
src []byte,
|
||||||
) (*MessageCookieReply, error) {
|
) (*MessageCookieReply, error) {
|
||||||
|
|
||||||
st.RLock()
|
st.RLock()
|
||||||
|
|
||||||
// refresh cookie secret
|
// refresh cookie secret
|
||||||
|
@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
||||||
|
|
||||||
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
||||||
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
|
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieGenerator) AddMacs(msg []byte) {
|
func (st *CookieGenerator) AddMacs(msg []byte) {
|
||||||
|
|
||||||
size := len(msg)
|
size := len(msg)
|
||||||
|
|
||||||
smac2 := size - blake2s.Size128
|
smac2 := size - blake2s.Size128
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -10,7 +10,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCookieMAC1(t *testing.T) {
|
func TestCookieMAC1(t *testing.T) {
|
||||||
|
|
||||||
// setup generator / checker
|
// setup generator / checker
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
|
||||||
|
|
||||||
msg[5] ^= 0x20
|
msg[5] ^= 0x20
|
||||||
|
|
||||||
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
|
srcBad1 := []byte{192, 168, 13, 37, 40, 1}
|
||||||
if checker.CheckMAC2(msg, srcBad1) {
|
if checker.CheckMAC2(msg, srcBad1) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
|
srcBad2 := []byte{192, 168, 13, 38, 40, 1}
|
||||||
if checker.CheckMAC2(msg, srcBad2) {
|
if checker.CheckMAC2(msg, srcBad2) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|
379
device/device.go
379
device/device.go
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -11,10 +11,12 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
"github.com/tevino/abool/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
|
@ -30,7 +32,7 @@ type Device struct {
|
||||||
// will become the actual state; Up can fail.
|
// will become the actual state; Up can fail.
|
||||||
// The device can also change state multiple times between time of check and time of use.
|
// The device can also change state multiple times between time of check and time of use.
|
||||||
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
||||||
state uint32 // actually a deviceState, but typed uint32 for convenience
|
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
||||||
// stopping blocks until all inputs to Device have been closed.
|
// stopping blocks until all inputs to Device have been closed.
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
// mu protects state changes.
|
// mu protects state changes.
|
||||||
|
@ -44,6 +46,7 @@ type Device struct {
|
||||||
netlinkCancel *rwcancel.RWCancel
|
netlinkCancel *rwcancel.RWCancel
|
||||||
port uint16 // listening port
|
port uint16 // listening port
|
||||||
fwmark uint32 // mark value (0 = disabled)
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
|
brokenRoaming bool
|
||||||
}
|
}
|
||||||
|
|
||||||
staticIdentity struct {
|
staticIdentity struct {
|
||||||
|
@ -52,24 +55,26 @@ type Device struct {
|
||||||
publicKey NoisePublicKey
|
publicKey NoisePublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
rate struct {
|
|
||||||
underLoadUntil int64
|
|
||||||
limiter ratelimiter.Ratelimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
peers struct {
|
peers struct {
|
||||||
sync.RWMutex // protects keyMap
|
sync.RWMutex // protects keyMap
|
||||||
keyMap map[NoisePublicKey]*Peer
|
keyMap map[NoisePublicKey]*Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rate struct {
|
||||||
|
underLoadUntil atomic.Int64
|
||||||
|
limiter ratelimiter.Ratelimiter
|
||||||
|
}
|
||||||
|
|
||||||
allowedips AllowedIPs
|
allowedips AllowedIPs
|
||||||
indexTable IndexTable
|
indexTable IndexTable
|
||||||
cookieChecker CookieChecker
|
cookieChecker CookieChecker
|
||||||
|
|
||||||
pool struct {
|
pool struct {
|
||||||
messageBuffers *WaitPool
|
inboundElementsContainer *WaitPool
|
||||||
inboundElements *WaitPool
|
outboundElementsContainer *WaitPool
|
||||||
outboundElements *WaitPool
|
messageBuffers *WaitPool
|
||||||
|
inboundElements *WaitPool
|
||||||
|
outboundElements *WaitPool
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
|
@ -80,22 +85,39 @@ type Device struct {
|
||||||
|
|
||||||
tun struct {
|
tun struct {
|
||||||
device tun.Device
|
device tun.Device
|
||||||
mtu int32
|
mtu atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
ipcMutex sync.RWMutex
|
ipcMutex sync.RWMutex
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
log *Logger
|
log *Logger
|
||||||
|
|
||||||
|
isASecOn abool.AtomicBool
|
||||||
|
aSecMux sync.RWMutex
|
||||||
|
aSecCfg aSecCfgType
|
||||||
|
junkCreator junkCreator
|
||||||
|
}
|
||||||
|
|
||||||
|
type aSecCfgType struct {
|
||||||
|
isSet bool
|
||||||
|
junkPacketCount int
|
||||||
|
junkPacketMinSize int
|
||||||
|
junkPacketMaxSize int
|
||||||
|
initPacketJunkSize int
|
||||||
|
responsePacketJunkSize int
|
||||||
|
initPacketMagicHeader uint32
|
||||||
|
responsePacketMagicHeader uint32
|
||||||
|
underloadPacketMagicHeader uint32
|
||||||
|
transportPacketMagicHeader uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// deviceState represents the state of a Device.
|
// deviceState represents the state of a Device.
|
||||||
// There are three states: down, up, closed.
|
// There are three states: down, up, closed.
|
||||||
// Transitions:
|
// Transitions:
|
||||||
//
|
//
|
||||||
// down -----+
|
// down -----+
|
||||||
// ↑↓ ↓
|
// ↑↓ ↓
|
||||||
// up -> closed
|
// up -> closed
|
||||||
//
|
|
||||||
type deviceState uint32
|
type deviceState uint32
|
||||||
|
|
||||||
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
||||||
|
@ -108,7 +130,7 @@ const (
|
||||||
// deviceState returns device.state.state as a deviceState
|
// deviceState returns device.state.state as a deviceState
|
||||||
// See those docs for how to interpret this value.
|
// See those docs for how to interpret this value.
|
||||||
func (device *Device) deviceState() deviceState {
|
func (device *Device) deviceState() deviceState {
|
||||||
return deviceState(atomic.LoadUint32(&device.state.state))
|
return deviceState(device.state.state.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// isClosed reports whether the device is closed (or is closing).
|
// isClosed reports whether the device is closed (or is closing).
|
||||||
|
@ -147,20 +169,21 @@ func (device *Device) changeState(want deviceState) (err error) {
|
||||||
case old:
|
case old:
|
||||||
return nil
|
return nil
|
||||||
case deviceStateUp:
|
case deviceStateUp:
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
|
device.state.state.Store(uint32(deviceStateUp))
|
||||||
err = device.upLocked()
|
err = device.upLocked()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
fallthrough // up failed; bring the device all the way back down
|
fallthrough // up failed; bring the device all the way back down
|
||||||
case deviceStateDown:
|
case deviceStateDown:
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
errDown := device.downLocked()
|
errDown := device.downLocked()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = errDown
|
err = errDown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
device.log.Verbosef(
|
||||||
|
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,10 +195,15 @@ func (device *Device) upLocked() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The IPC set operation waits for peers to be created before calling Start() on them,
|
||||||
|
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
|
||||||
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
|
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Start()
|
peer.Start()
|
||||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -210,13 +238,13 @@ func (device *Device) Down() error {
|
||||||
func (device *Device) IsUnderLoad() bool {
|
func (device *Device) IsUnderLoad() bool {
|
||||||
// check if currently under load
|
// check if currently under load
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
|
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
||||||
if underLoad {
|
if underLoad {
|
||||||
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// check if recently under load
|
// check if recently under load
|
||||||
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
|
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
|
@ -260,7 +288,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
expiredPeers = append(expiredPeers, peer)
|
expiredPeers = append(expiredPeers, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,7 +304,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
|
|
||||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
device.state.state = uint32(deviceStateDown)
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
device.closed = make(chan struct{})
|
device.closed = make(chan struct{})
|
||||||
device.log = logger
|
device.log = logger
|
||||||
device.net.bind = bind
|
device.net.bind = bind
|
||||||
|
@ -286,10 +314,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
||||||
mtu = DefaultMTU
|
mtu = DefaultMTU
|
||||||
}
|
}
|
||||||
device.tun.mtu = int32(mtu)
|
device.tun.mtu.Store(int32(mtu))
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
device.rate.limiter.Init()
|
device.rate.limiter.Init()
|
||||||
device.indexTable.Init()
|
device.indexTable.Init()
|
||||||
|
|
||||||
device.PopulatePools()
|
device.PopulatePools()
|
||||||
|
|
||||||
// create queues
|
// create queues
|
||||||
|
@ -304,9 +333,9 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
|
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
|
||||||
for i := 0; i < cpus; i++ {
|
for i := 0; i < cpus; i++ {
|
||||||
go device.RoutineEncryption()
|
go device.RoutineEncryption(i + 1)
|
||||||
go device.RoutineDecryption()
|
go device.RoutineDecryption(i + 1)
|
||||||
go device.RoutineHandshake()
|
go device.RoutineHandshake(i + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
device.state.stopping.Add(1) // RoutineReadFromTUN
|
device.state.stopping.Add(1) // RoutineReadFromTUN
|
||||||
|
@ -317,6 +346,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||||
|
// the bind batch size and the tun batch size. The batch size reported by device
|
||||||
|
// is the size used to construct memory pools, and is the allowed batch size for
|
||||||
|
// the lifetime of the device.
|
||||||
|
func (device *Device) BatchSize() int {
|
||||||
|
size := device.net.bind.BatchSize()
|
||||||
|
dSize := device.tun.device.BatchSize()
|
||||||
|
if size < dSize {
|
||||||
|
size = dSize
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
defer device.peers.RUnlock()
|
defer device.peers.RUnlock()
|
||||||
|
@ -349,10 +391,12 @@ func (device *Device) RemoveAllPeers() {
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
device.state.Lock()
|
device.state.Lock()
|
||||||
defer device.state.Unlock()
|
defer device.state.Unlock()
|
||||||
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
if device.isClosed() {
|
if device.isClosed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
|
device.state.state.Store(uint32(deviceStateClosed))
|
||||||
device.log.Verbosef("Device closing")
|
device.log.Verbosef("Device closing")
|
||||||
|
|
||||||
device.tun.device.Close()
|
device.tun.device.Close()
|
||||||
|
@ -372,6 +416,8 @@ func (device *Device) Close() {
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
device.rate.limiter.Close()
|
||||||
|
|
||||||
|
device.resetProtocol()
|
||||||
|
|
||||||
device.log.Verbosef("Device closed")
|
device.log.Verbosef("Device closed")
|
||||||
close(device.closed)
|
close(device.closed)
|
||||||
}
|
}
|
||||||
|
@ -438,11 +484,7 @@ func (device *Device) BindSetMark(mark uint32) error {
|
||||||
// clear cached source addresses
|
// clear cached source addresses
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
@ -467,11 +509,13 @@ func (device *Device) BindUpdate() error {
|
||||||
var err error
|
var err error
|
||||||
var recvFns []conn.ReceiveFunc
|
var recvFns []conn.ReceiveFunc
|
||||||
netc := &device.net
|
netc := &device.net
|
||||||
|
|
||||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
netc.port = 0
|
netc.port = 0
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
netc.bind.Close()
|
netc.bind.Close()
|
||||||
|
@ -490,11 +534,7 @@ func (device *Device) BindUpdate() error {
|
||||||
// clear cached source addresses
|
// clear cached source addresses
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
@ -502,8 +542,9 @@ func (device *Device) BindUpdate() error {
|
||||||
device.net.stopping.Add(len(recvFns))
|
device.net.stopping.Add(len(recvFns))
|
||||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
|
batchSize := netc.bind.BatchSize()
|
||||||
for _, fn := range recvFns {
|
for _, fn := range recvFns {
|
||||||
go device.RoutineReceiveIncoming(fn)
|
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
device.log.Verbosef("UDP bind has been updated")
|
device.log.Verbosef("UDP bind has been updated")
|
||||||
|
@ -516,3 +557,251 @@ func (device *Device) BindClose() error {
|
||||||
device.net.Unlock()
|
device.net.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
func (device *Device) isAdvancedSecurityOn() bool {
|
||||||
|
return device.isASecOn.IsSet()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) resetProtocol() {
|
||||||
|
// restore default message type values
|
||||||
|
MessageInitiationType = 1
|
||||||
|
MessageResponseType = 2
|
||||||
|
MessageCookieReplyType = 3
|
||||||
|
MessageTransportType = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
|
|
||||||
|
if !tempASecCfg.isSet {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isASecOn := false
|
||||||
|
device.aSecMux.Lock()
|
||||||
|
if tempASecCfg.junkPacketCount < 0 {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"JunkPacketCount should be non negative",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
||||||
|
if tempASecCfg.junkPacketCount != 0 {
|
||||||
|
isASecOn = true
|
||||||
|
}
|
||||||
|
|
||||||
|
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
||||||
|
if tempASecCfg.junkPacketMinSize != 0 {
|
||||||
|
isASecOn = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.aSecCfg.junkPacketCount > 0 &&
|
||||||
|
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
||||||
|
|
||||||
|
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||||
|
device.aSecCfg.junkPacketMinSize = 0
|
||||||
|
device.aSecCfg.junkPacketMaxSize = 1
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"maxSize: %d; should be greater than minSize: %d; %w",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
tempASecCfg.junkPacketMinSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"maxSize: %d; should be greater than minSize: %d",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
tempASecCfg.junkPacketMinSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.junkPacketMaxSize != 0 {
|
||||||
|
isASecOn = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||||
|
tempASecCfg.initPacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||||
|
tempASecCfg.initPacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.initPacketJunkSize != 0 {
|
||||||
|
isASecOn = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||||
|
tempASecCfg.responsePacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||||
|
tempASecCfg.responsePacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.responsePacketJunkSize != 0 {
|
||||||
|
isASecOn = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.initPacketMagicHeader > 4 {
|
||||||
|
isASecOn = true
|
||||||
|
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||||
|
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
|
||||||
|
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||||
|
} else {
|
||||||
|
device.log.Verbosef("UAPI: Using default init type")
|
||||||
|
MessageInitiationType = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||||
|
isASecOn = true
|
||||||
|
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||||
|
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
|
||||||
|
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||||
|
} else {
|
||||||
|
device.log.Verbosef("UAPI: Using default response type")
|
||||||
|
MessageResponseType = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||||
|
isASecOn = true
|
||||||
|
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||||
|
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
|
||||||
|
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||||
|
} else {
|
||||||
|
device.log.Verbosef("UAPI: Using default underload type")
|
||||||
|
MessageCookieReplyType = 3
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||||
|
isASecOn = true
|
||||||
|
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||||
|
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
|
||||||
|
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||||
|
} else {
|
||||||
|
device.log.Verbosef("UAPI: Using default transport type")
|
||||||
|
MessageTransportType = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
isSameMap := map[uint32]bool{}
|
||||||
|
isSameMap[MessageInitiationType] = true
|
||||||
|
isSameMap[MessageResponseType] = true
|
||||||
|
isSameMap[MessageCookieReplyType] = true
|
||||||
|
isSameMap[MessageTransportType] = true
|
||||||
|
|
||||||
|
// size will be different if same values
|
||||||
|
if len(isSameMap) != 4 {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
|
||||||
|
MessageInitiationType,
|
||||||
|
MessageResponseType,
|
||||||
|
MessageCookieReplyType,
|
||||||
|
MessageTransportType,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||||
|
MessageInitiationType,
|
||||||
|
MessageResponseType,
|
||||||
|
MessageCookieReplyType,
|
||||||
|
MessageTransportType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
||||||
|
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
||||||
|
|
||||||
|
if newInitSize == newResponseSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`new init size:%d; and new response size:%d; should differ; %w`,
|
||||||
|
newInitSize,
|
||||||
|
newResponseSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`new init size:%d; and new response size:%d; should differ`,
|
||||||
|
newInitSize,
|
||||||
|
newResponseSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
packetSizeToMsgType = map[int]uint32{
|
||||||
|
newInitSize: MessageInitiationType,
|
||||||
|
newResponseSize: MessageResponseType,
|
||||||
|
MessageCookieReplySize: MessageCookieReplyType,
|
||||||
|
MessageTransportSize: MessageTransportType,
|
||||||
|
}
|
||||||
|
|
||||||
|
msgTypeToJunkSize = map[uint32]int{
|
||||||
|
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
||||||
|
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
||||||
|
MessageCookieReplyType: 0,
|
||||||
|
MessageTransportType: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
device.isASecOn.SetTo(isASecOn)
|
||||||
|
device.junkCreator, err = NewJunkCreator(device)
|
||||||
|
device.aSecMux.Unlock()
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -11,7 +11,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -19,9 +20,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
||||||
|
@ -48,7 +50,7 @@ func uapiCfg(cfg ...string) string {
|
||||||
|
|
||||||
// genConfigs generates a pair of configs that connect to each other.
|
// genConfigs generates a pair of configs that connect to each other.
|
||||||
// The configs use distinct, probably-usable ports.
|
// The configs use distinct, probably-usable ports.
|
||||||
func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
|
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
var key1, key2 NoisePrivateKey
|
var key1, key2 NoisePrivateKey
|
||||||
_, err := rand.Read(key1[:])
|
_, err := rand.Read(key1[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -89,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
|
var key1, key2 NoisePrivateKey
|
||||||
|
_, err := rand.Read(key1[:])
|
||||||
|
if err != nil {
|
||||||
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
|
}
|
||||||
|
_, err = rand.Read(key2[:])
|
||||||
|
if err != nil {
|
||||||
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
|
}
|
||||||
|
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||||
|
|
||||||
|
cfgs[0] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key1[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"jc", "5",
|
||||||
|
"jmin", "500",
|
||||||
|
"jmax", "1000",
|
||||||
|
"s1", "30",
|
||||||
|
"s2", "40",
|
||||||
|
"h1", "123456",
|
||||||
|
"h2", "67543",
|
||||||
|
"h4", "32345",
|
||||||
|
"h3", "123123",
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.2/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[0] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
cfgs[1] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key2[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"jc", "5",
|
||||||
|
"jmin", "500",
|
||||||
|
"jmax", "1000",
|
||||||
|
"s1", "30",
|
||||||
|
"s2", "40",
|
||||||
|
"h1", "123456",
|
||||||
|
"h2", "67543",
|
||||||
|
"h4", "32345",
|
||||||
|
"h3", "123123",
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.1/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[1] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// A testPair is a pair of testPeers.
|
// A testPair is a pair of testPeers.
|
||||||
type testPair [2]testPeer
|
type testPair [2]testPeer
|
||||||
|
|
||||||
|
@ -96,7 +157,7 @@ type testPair [2]testPeer
|
||||||
type testPeer struct {
|
type testPeer struct {
|
||||||
tun *tuntest.ChannelTUN
|
tun *tuntest.ChannelTUN
|
||||||
dev *Device
|
dev *Device
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type SendDirection bool
|
type SendDirection bool
|
||||||
|
@ -113,7 +174,11 @@ func (d SendDirection) String() string {
|
||||||
return "pong"
|
return "pong"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
|
func (pair *testPair) Send(
|
||||||
|
tb testing.TB,
|
||||||
|
ping SendDirection,
|
||||||
|
done chan struct{},
|
||||||
|
) {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
p0, p1 := pair[0], pair[1]
|
p0, p1 := pair[0], pair[1]
|
||||||
if !ping {
|
if !ping {
|
||||||
|
@ -147,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// genTestPair creates a testPair.
|
// genTestPair creates a testPair.
|
||||||
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
func genTestPair(
|
||||||
cfg, endpointCfg := genConfigs(tb)
|
tb testing.TB,
|
||||||
|
realSocket, withASecurity bool,
|
||||||
|
) (pair testPair) {
|
||||||
|
var cfg, endpointCfg [2]string
|
||||||
|
if withASecurity {
|
||||||
|
cfg, endpointCfg = genASecurityConfigs(tb)
|
||||||
|
} else {
|
||||||
|
cfg, endpointCfg = genConfigs(tb)
|
||||||
|
}
|
||||||
var binds [2]conn.Bind
|
var binds [2]conn.Bind
|
||||||
if realSocket {
|
if realSocket {
|
||||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||||
|
@ -159,7 +232,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
p := &pair[i]
|
p := &pair[i]
|
||||||
p.tun = tuntest.NewChannelTUN()
|
p.tun = tuntest.NewChannelTUN()
|
||||||
p.ip = net.IPv4(1, 0, 0, byte(i+1))
|
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
|
||||||
level := LogLevelVerbose
|
level := LogLevelVerbose
|
||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
level = LogLevelError
|
level = LogLevelError
|
||||||
|
@ -192,7 +265,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true)
|
pair := genTestPair(t, true, false)
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
pair.Send(t, Ping, nil)
|
||||||
|
})
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
pair.Send(t, Pong, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestASecurityTwoDevicePing(t *testing.T) {
|
||||||
|
goroutineLeakCheck(t)
|
||||||
|
pair := genTestPair(t, true, true)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
})
|
})
|
||||||
|
@ -207,7 +291,7 @@ func TestUpDown(t *testing.T) {
|
||||||
const otrials = 10
|
const otrials = 10
|
||||||
|
|
||||||
for n := 0; n < otrials; n++ {
|
for n := 0; n < otrials; n++ {
|
||||||
pair := genTestPair(t, false)
|
pair := genTestPair(t, false, false)
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||||
|
@ -241,7 +325,7 @@ func TestUpDown(t *testing.T) {
|
||||||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||||
// It is intended to be used with the race detector to catch data races.
|
// It is intended to be used with the race detector to catch data races.
|
||||||
func TestConcurrencySafety(t *testing.T) {
|
func TestConcurrencySafety(t *testing.T) {
|
||||||
pair := genTestPair(t, true)
|
pair := genTestPair(t, true, false)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
const warmupIters = 10
|
const warmupIters = 10
|
||||||
|
@ -307,11 +391,22 @@ func TestConcurrencySafety(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
||||||
|
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
||||||
|
const iters = 10
|
||||||
|
for i := 0; i < iters; i++ {
|
||||||
|
for _, peer := range pair {
|
||||||
|
peer.dev.BindUpdate()
|
||||||
|
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
close(done)
|
close(done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLatency(b *testing.B) {
|
func BenchmarkLatency(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -325,7 +420,7 @@ func BenchmarkLatency(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -333,7 +428,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
|
|
||||||
// Measure how long it takes to receive b.N packets,
|
// Measure how long it takes to receive b.N packets,
|
||||||
// starting when we receive the first packet.
|
// starting when we receive the first packet.
|
||||||
var recv uint64
|
var recv atomic.Uint64
|
||||||
var elapsed time.Duration
|
var elapsed time.Duration
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -342,7 +437,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
var start time.Time
|
var start time.Time
|
||||||
for {
|
for {
|
||||||
<-pair[0].tun.Inbound
|
<-pair[0].tun.Inbound
|
||||||
new := atomic.AddUint64(&recv, 1)
|
new := recv.Add(1)
|
||||||
if new == 1 {
|
if new == 1 {
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
}
|
}
|
||||||
|
@ -358,7 +453,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
||||||
pingc := pair[1].tun.Outbound
|
pingc := pair[1].tun.Outbound
|
||||||
var sent uint64
|
var sent uint64
|
||||||
for atomic.LoadUint64(&recv) != uint64(b.N) {
|
for recv.Load() != uint64(b.N) {
|
||||||
sent++
|
sent++
|
||||||
pingc <- ping
|
pingc <- ping
|
||||||
}
|
}
|
||||||
|
@ -369,7 +464,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUAPIGet(b *testing.B) {
|
func BenchmarkUAPIGet(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
pair.Send(b, Pong, nil)
|
pair.Send(b, Pong, nil)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -405,3 +500,73 @@ func goroutineLeakCheck(t *testing.T) {
|
||||||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeBindSized struct {
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) Open(
|
||||||
|
port uint16,
|
||||||
|
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) Close() error { return nil }
|
||||||
|
|
||||||
|
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
|
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||||
|
|
||||||
|
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
||||||
|
|
||||||
|
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||||
|
|
||||||
|
type fakeTUNDeviceSized struct {
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||||
|
|
||||||
|
func TestBatchSize(t *testing.T) {
|
||||||
|
d := Device{}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{1}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{1}
|
||||||
|
if want, got := 1, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{1}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{128}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{128}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{1}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{128}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{128}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,53 +1,49 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyEndpoint struct {
|
type DummyEndpoint struct {
|
||||||
src [16]byte
|
src, dst netip.Addr
|
||||||
dst [16]byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
||||||
var end DummyEndpoint
|
var src, dst [16]byte
|
||||||
if _, err := rand.Read(end.src[:]); err != nil {
|
if _, err := rand.Read(src[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err := rand.Read(end.dst[:])
|
_, err := rand.Read(dst[:])
|
||||||
return &end, err
|
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) ClearSrc() {}
|
func (e *DummyEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToString() string {
|
func (e *DummyEndpoint) SrcToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
|
||||||
addr.IP = e.SrcIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstToString() string {
|
func (e *DummyEndpoint) DstToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.DstIP(), 1000).String()
|
||||||
addr.IP = e.DstIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToBytes() []byte {
|
func (e *DummyEndpoint) DstToBytes() []byte {
|
||||||
return e.src[:]
|
out := e.DstIP().AsSlice()
|
||||||
|
out = append(out, byte(1000&0xff))
|
||||||
|
out = append(out, byte((1000>>8)&0xff))
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstIP() net.IP {
|
func (e *DummyEndpoint) DstIP() netip.Addr {
|
||||||
return e.dst[:]
|
return e.dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcIP() net.IP {
|
func (e *DummyEndpoint) SrcIP() netip.Addr {
|
||||||
return e.src[:]
|
return e.src
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
69
device/junk_creator.go
Normal file
69
device/junk_creator.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
v2 "math/rand/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type junkCreator struct {
|
||||||
|
device *Device
|
||||||
|
cha8Rand *v2.ChaCha8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJunkCreator(d *Device) (junkCreator, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
_, err := crand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
|
||||||
|
if jc.device.aSecCfg.junkPacketCount == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
|
||||||
|
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
|
||||||
|
packetSize := jc.randomPacketSize()
|
||||||
|
junk, err := jc.randomJunkWithSize(packetSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
|
||||||
|
}
|
||||||
|
junks = append(junks, junk)
|
||||||
|
}
|
||||||
|
return junks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) randomPacketSize() int {
|
||||||
|
return int(
|
||||||
|
jc.cha8Rand.Uint64()%uint64(
|
||||||
|
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
|
||||||
|
),
|
||||||
|
) + jc.device.aSecCfg.junkPacketMinSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
|
||||||
|
headerJunk, err := jc.randomJunkWithSize(size)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create header junk: %v", err)
|
||||||
|
}
|
||||||
|
_, err = writer.Write(headerJunk)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write header junk: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
||||||
|
junk := make([]byte, size)
|
||||||
|
_, err := jc.cha8Rand.Read(junk)
|
||||||
|
return junk, err
|
||||||
|
}
|
124
device/junk_creator_test.go
Normal file
124
device/junk_creator_test.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
||||||
|
cfg, _ := genASecurityConfigs(t)
|
||||||
|
tun := tuntest.NewChannelTUN()
|
||||||
|
binds := bindtest.NewChannelBinds()
|
||||||
|
level := LogLevelVerbose
|
||||||
|
dev := NewDevice(
|
||||||
|
tun.TUN(),
|
||||||
|
binds[0],
|
||||||
|
NewLogger(level, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := dev.IpcSet(cfg[0]); err != nil {
|
||||||
|
t.Errorf("failed to configure device %v", err)
|
||||||
|
dev.Close()
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
jc, err := NewJunkCreator(dev)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create junk creator %v", err)
|
||||||
|
dev.Close()
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return jc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
got, err := jc.createJunkPackets()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.createJunkPackets() = %v; failed",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, junk := range got {
|
||||||
|
key := string(junk)
|
||||||
|
if seen[key] {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||||
|
got,
|
||||||
|
junk,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r1, _ := jc.randomJunkWithSize(10)
|
||||||
|
r2, _ := jc.randomJunkWithSize(10)
|
||||||
|
fmt.Printf("%v\n%v\n", r1, r2)
|
||||||
|
if bytes.Equal(r1, r2) {
|
||||||
|
t.Errorf("same junks %v", err)
|
||||||
|
jc.device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for range [30]struct{}{} {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
|
||||||
|
got > jc.device.aSecCfg.junkPacketMaxSize {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||||
|
got,
|
||||||
|
jc.device.aSecCfg.junkPacketMinSize,
|
||||||
|
jc.device.aSecCfg.junkPacketMaxSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
s := "apple"
|
||||||
|
buffer := bytes.NewBuffer([]byte(s))
|
||||||
|
err := jc.appendJunk(buffer, 30)
|
||||||
|
if err != nil &&
|
||||||
|
buffer.Len() != len(s)+30 {
|
||||||
|
t.Errorf("appendWithJunk() size don't match")
|
||||||
|
}
|
||||||
|
read := make([]byte, 50)
|
||||||
|
buffer.Read(read)
|
||||||
|
fmt.Println(string(read))
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -20,7 +20,7 @@ type KDFTest struct {
|
||||||
t2 string
|
t2 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertEquals(t *testing.T, a string, b string) {
|
func assertEquals(t *testing.T, a, b string) {
|
||||||
if a != b {
|
if a != b {
|
||||||
t.Fatal("expected", a, "=", b)
|
t.Fatal("expected", a, "=", b)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -10,9 +10,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"github.com/amnezia-vpn/amneziawg-go/replay"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Due to limitations in Go and /x/crypto there is currently
|
/* Due to limitations in Go and /x/crypto there is currently
|
||||||
|
@ -23,7 +22,7 @@ import (
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Keypair struct {
|
type Keypair struct {
|
||||||
sendNonce uint64 // accessed atomically
|
sendNonce atomic.Uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
replayFilter replay.Filter
|
replayFilter replay.Filter
|
||||||
|
@ -37,15 +36,7 @@ type Keypairs struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next *Keypair
|
next atomic.Pointer[Keypair]
|
||||||
}
|
|
||||||
|
|
||||||
func (kp *Keypairs) storeNext(next *Keypair) {
|
|
||||||
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kp *Keypairs) loadNext() *Keypair {
|
|
||||||
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -16,8 +16,8 @@ import (
|
||||||
// They do not require a trailing newline in the format.
|
// They do not require a trailing newline in the format.
|
||||||
// If nil, that level of logging will be silent.
|
// If nil, that level of logging will be silent.
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
Verbosef func(format string, args ...interface{})
|
Verbosef func(format string, args ...any)
|
||||||
Errorf func(format string, args ...interface{})
|
Errorf func(format string, args ...any)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log levels for use with NewLogger.
|
// Log levels for use with NewLogger.
|
||||||
|
@ -28,14 +28,14 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Function for use in Logger for discarding logged lines.
|
// Function for use in Logger for discarding logged lines.
|
||||||
func DiscardLogf(format string, args ...interface{}) {}
|
func DiscardLogf(format string, args ...any) {}
|
||||||
|
|
||||||
// NewLogger constructs a Logger that writes to stdout.
|
// NewLogger constructs a Logger that writes to stdout.
|
||||||
// It logs at the specified log level and above.
|
// It logs at the specified log level and above.
|
||||||
// It decorates log lines with the log level, date, time, and prepend.
|
// It decorates log lines with the log level, date, time, and prepend.
|
||||||
func NewLogger(level int, prepend string) *Logger {
|
func NewLogger(level int, prepend string) *Logger {
|
||||||
logger := &Logger{DiscardLogf, DiscardLogf}
|
logger := &Logger{DiscardLogf, DiscardLogf}
|
||||||
logf := func(prefix string) func(string, ...interface{}) {
|
logf := func(prefix string) func(string, ...any) {
|
||||||
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
|
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
|
||||||
}
|
}
|
||||||
if level >= LogLevelVerbose {
|
if level >= LogLevelVerbose {
|
||||||
|
|
|
@ -1,48 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Atomic Boolean */
|
|
||||||
|
|
||||||
const (
|
|
||||||
AtomicFalse = int32(iota)
|
|
||||||
AtomicTrue
|
|
||||||
)
|
|
||||||
|
|
||||||
type AtomicBool struct {
|
|
||||||
int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Get() bool {
|
|
||||||
return atomic.LoadInt32(&a.int32) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Swap(val bool) bool {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Set(val bool) {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
atomic.StoreInt32(&a.int32, flag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a, b uint) uint {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
|
@ -1,16 +1,19 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
|
||||||
|
// though it will try to deal with it, and race maybe, if called after.
|
||||||
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||||
|
device.net.brokenRoaming = true
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.endpoint.Lock()
|
||||||
peer.disableRoaming = peer.endpoint != nil
|
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||||
peer.Unlock()
|
peer.endpoint.Unlock()
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
"hash"
|
"hash"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
|
@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
var errInvalidPublicKey = errors.New("invalid public key")
|
||||||
|
|
||||||
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
|
||||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
curve25519.ScalarMult(&ss, ask, apk)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
return ss
|
if isZero(ss[:]) {
|
||||||
|
return ss, errInvalidPublicKey
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -15,7 +15,7 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"github.com/amnezia-vpn/amneziawg-go/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
type handshakeState int
|
type handshakeState int
|
||||||
|
@ -52,11 +52,11 @@ const (
|
||||||
WGLabelCookie = "cookie--"
|
WGLabelCookie = "cookie--"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
MessageInitiationType = 1
|
MessageInitiationType uint32 = 1
|
||||||
MessageResponseType = 2
|
MessageResponseType uint32 = 2
|
||||||
MessageCookieReplyType = 3
|
MessageCookieReplyType uint32 = 3
|
||||||
MessageTransportType = 4
|
MessageTransportType uint32 = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -75,6 +75,10 @@ const (
|
||||||
MessageTransportOffsetContent = 16
|
MessageTransportOffsetContent = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var packetSizeToMsgType map[int]uint32
|
||||||
|
|
||||||
|
var msgTypeToJunkSize map[uint32]int
|
||||||
|
|
||||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||||
* by marshalling the messages in little-endian byteorder
|
* by marshalling the messages in little-endian byteorder
|
||||||
* we can treat these as a 32-bit unsigned int (for now)
|
* we can treat these as a 32-bit unsigned int (for now)
|
||||||
|
@ -138,11 +142,11 @@ var (
|
||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
|
func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
|
||||||
KDF1(dst, c[:], data)
|
KDF1(dst, c[:], data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
|
func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
|
||||||
hash, _ := blake2s.New256(nil)
|
hash, _ := blake2s.New256(nil)
|
||||||
hash.Write(h[:])
|
hash.Write(h[:])
|
||||||
hash.Write(data)
|
hash.Write(data)
|
||||||
|
@ -175,8 +179,6 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
var errZeroECDHResult = errors.New("ECDH returned all zeros")
|
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
|
|
||||||
|
@ -195,18 +197,20 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||||
|
|
||||||
handshake.mixHash(handshake.remoteStatic[:])
|
handshake.mixHash(handshake.remoteStatic[:])
|
||||||
|
|
||||||
|
device.aSecMux.RLock()
|
||||||
msg := MessageInitiation{
|
msg := MessageInitiation{
|
||||||
Type: MessageInitiationType,
|
Type: MessageInitiationType,
|
||||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||||
}
|
}
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
|
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt static key
|
// encrypt static key
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
if isZero(ss[:]) {
|
if err != nil {
|
||||||
return nil, errZeroECDHResult
|
return nil, err
|
||||||
}
|
}
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
KDF2(
|
KDF2(
|
||||||
|
@ -221,7 +225,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
return nil, errZeroECDHResult
|
return nil, errInvalidPublicKey
|
||||||
}
|
}
|
||||||
KDF2(
|
KDF2(
|
||||||
&handshake.chainKey,
|
&handshake.chainKey,
|
||||||
|
@ -252,9 +256,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
chainKey [blake2s.Size]byte
|
chainKey [blake2s.Size]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device.aSecMux.RLock()
|
||||||
if msg.Type != MessageInitiationType {
|
if msg.Type != MessageInitiationType {
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
|
@ -264,11 +271,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt static key
|
// decrypt static key
|
||||||
var err error
|
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
if isZero(ss[:]) {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||||
|
@ -282,7 +288,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
peer := device.LookupPeer(peerPK)
|
peer := device.LookupPeer(peerPK)
|
||||||
if peer == nil {
|
if peer == nil || !peer.isRunning.Load() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,7 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg MessageResponse
|
var msg MessageResponse
|
||||||
|
device.aSecMux.RLock()
|
||||||
msg.Type = MessageResponseType
|
msg.Type = MessageResponseType
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
msg.Sender = handshake.localIndex
|
msg.Sender = handshake.localIndex
|
||||||
msg.Receiver = handshake.remoteIndex
|
msg.Receiver = handshake.remoteIndex
|
||||||
|
|
||||||
|
@ -384,12 +392,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
if err != nil {
|
||||||
handshake.mixKey(ss[:])
|
return nil, err
|
||||||
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
}
|
||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
}()
|
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handshake.mixKey(ss[:])
|
||||||
|
|
||||||
// add preshared key
|
// add preshared key
|
||||||
|
|
||||||
|
@ -406,11 +418,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
|
|
||||||
handshake.mixHash(tau[:])
|
handshake.mixHash(tau[:])
|
||||||
|
|
||||||
func() {
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
handshake.mixHash(msg.Empty[:])
|
|
||||||
}()
|
|
||||||
|
|
||||||
handshake.state = handshakeResponseCreated
|
handshake.state = handshakeResponseCreated
|
||||||
|
|
||||||
|
@ -418,9 +428,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
|
device.aSecMux.RLock()
|
||||||
if msg.Type != MessageResponseType {
|
if msg.Type != MessageResponseType {
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
|
|
||||||
// lookup handshake by receiver
|
// lookup handshake by receiver
|
||||||
|
|
||||||
|
@ -436,7 +449,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
)
|
)
|
||||||
|
|
||||||
ok := func() bool {
|
ok := func() bool {
|
||||||
|
|
||||||
// lock handshake state
|
// lock handshake state
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
|
@ -456,17 +468,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
||||||
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
return false
|
||||||
setZero(ss[:])
|
}
|
||||||
}()
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
|
setZero(ss[:])
|
||||||
|
|
||||||
func() {
|
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
return false
|
||||||
setZero(ss[:])
|
}
|
||||||
}()
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
|
setZero(ss[:])
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
|
|
||||||
|
@ -484,7 +498,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
// authenticate transcript
|
// authenticate transcript
|
||||||
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -582,12 +596,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.loadNext()
|
next := keypairs.next.Load()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
|
@ -596,7 +610,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.storeNext(keypair)
|
keypairs.next.Store(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
|
@ -608,18 +622,18 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
|
|
||||||
if keypairs.loadNext() != receivedKeypair {
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.loadNext() != receivedKeypair {
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.loadNext()
|
keypairs.current = keypairs.next.Load()
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -10,8 +10,8 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCurveWrappers(t *testing.T) {
|
func TestCurveWrappers(t *testing.T) {
|
||||||
|
@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
|
||||||
pk1 := sk1.publicKey()
|
pk1 := sk1.publicKey()
|
||||||
pk2 := sk2.publicKey()
|
pk2 := sk2.publicKey()
|
||||||
|
|
||||||
ss1 := sk1.sharedSecret(pk2)
|
ss1, err1 := sk1.sharedSecret(pk2)
|
||||||
ss2 := sk2.sharedSecret(pk1)
|
ss2, err2 := sk2.sharedSecret(pk1)
|
||||||
|
|
||||||
if ss1 != ss2 {
|
if ss1 != ss2 || err1 != nil || err2 != nil {
|
||||||
t.Fatal("Failed to compute shared secet")
|
t.Fatal("Failed to compute shared secet")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -71,6 +71,8 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
peer1.Start()
|
||||||
|
peer2.Start()
|
||||||
|
|
||||||
assertEqual(
|
assertEqual(
|
||||||
t,
|
t,
|
||||||
|
@ -146,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.loadNext()
|
key1 := peer1.keypairs.next.Load()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
|
167
device/peer.go
167
device/peer.go
|
@ -1,53 +1,46 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
isRunning AtomicBool
|
isRunning atomic.Bool
|
||||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
keypairs Keypairs
|
||||||
keypairs Keypairs
|
handshake Handshake
|
||||||
handshake Handshake
|
device *Device
|
||||||
device *Device
|
stopping sync.WaitGroup // routines pending stop
|
||||||
endpoint conn.Endpoint
|
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||||
stopping sync.WaitGroup // routines pending stop
|
rxBytes atomic.Uint64 // bytes received from peer
|
||||||
|
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||||
|
|
||||||
// These fields are accessed with atomic operations, which must be
|
endpoint struct {
|
||||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
sync.Mutex
|
||||||
// allocated struct will be 64-bit aligned. So we place
|
val conn.Endpoint
|
||||||
// atomically-accessed fields up front, so that they can share in
|
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||||
// this alignment before smaller fields throw it off.
|
disableRoaming bool
|
||||||
stats struct {
|
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
|
||||||
rxBytes uint64 // bytes received from peer
|
|
||||||
lastHandshakeNano int64 // nano seconds since epoch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
disableRoaming bool
|
|
||||||
|
|
||||||
timers struct {
|
timers struct {
|
||||||
retransmitHandshake *Timer
|
retransmitHandshake *Timer
|
||||||
sendKeepalive *Timer
|
sendKeepalive *Timer
|
||||||
newHandshake *Timer
|
newHandshake *Timer
|
||||||
zeroKeyMaterial *Timer
|
zeroKeyMaterial *Timer
|
||||||
persistentKeepalive *Timer
|
persistentKeepalive *Timer
|
||||||
handshakeAttempts uint32
|
handshakeAttempts atomic.Uint32
|
||||||
needAnotherKeepalive AtomicBool
|
needAnotherKeepalive atomic.Bool
|
||||||
sentLastMinuteHandshake AtomicBool
|
sentLastMinuteHandshake atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
state struct {
|
state struct {
|
||||||
|
@ -55,14 +48,14 @@ type Peer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
staged chan *QueueOutboundElement // staged packets before a handshake is available
|
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
||||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieGenerator CookieGenerator
|
cookieGenerator CookieGenerator
|
||||||
trieEntries list.List
|
trieEntries list.List
|
||||||
persistentKeepaliveInterval uint32 // accessed atomically
|
persistentKeepaliveInterval atomic.Uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
@ -84,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
|
||||||
// create peer
|
// create peer
|
||||||
peer := new(Peer)
|
peer := new(Peer)
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
|
|
||||||
peer.cookieGenerator.Init(pk)
|
peer.cookieGenerator.Init(pk)
|
||||||
peer.device = device
|
peer.device = device
|
||||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||||
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
|
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
||||||
|
|
||||||
// map public key
|
// map public key
|
||||||
_, ok := device.peers.keyMap[pk]
|
_, ok := device.peers.keyMap[pk]
|
||||||
|
@ -102,26 +93,27 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
// pre-compute DH
|
// pre-compute DH
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
handshake.remoteStatic = pk
|
handshake.remoteStatic = pk
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
// reset endpoint
|
// reset endpoint
|
||||||
peer.endpoint = nil
|
peer.endpoint.Lock()
|
||||||
|
peer.endpoint.val = nil
|
||||||
|
peer.endpoint.disableRoaming = false
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
|
// init timers
|
||||||
|
peer.timersInit()
|
||||||
|
|
||||||
// add
|
// add
|
||||||
device.peers.keyMap[pk] = peer
|
device.peers.keyMap[pk] = peer
|
||||||
|
|
||||||
// start peer
|
|
||||||
peer.timersInit()
|
|
||||||
if peer.device.isUp() {
|
|
||||||
peer.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
peer.device.net.RLock()
|
peer.device.net.RLock()
|
||||||
defer peer.device.net.RUnlock()
|
defer peer.device.net.RUnlock()
|
||||||
|
|
||||||
|
@ -129,27 +121,53 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.RLock()
|
peer.endpoint.Lock()
|
||||||
defer peer.RUnlock()
|
endpoint := peer.endpoint.val
|
||||||
|
if endpoint == nil {
|
||||||
if peer.endpoint == nil {
|
peer.endpoint.Unlock()
|
||||||
return errors.New("no known endpoint for peer")
|
return errors.New("no known endpoint for peer")
|
||||||
}
|
}
|
||||||
|
if peer.endpoint.clearSrcOnTx {
|
||||||
|
endpoint.ClearSrc()
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
|
var totalLen uint64
|
||||||
|
for _, b := range buffers {
|
||||||
|
totalLen += uint64(len(b))
|
||||||
|
}
|
||||||
|
peer.txBytes.Add(totalLen)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
// The awful goo that follows is identical to:
|
||||||
abbreviatedKey := "invalid"
|
//
|
||||||
if len(base64Key) == 44 {
|
// base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
||||||
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
|
// abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
|
||||||
|
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
||||||
|
//
|
||||||
|
// except that it is considerably more efficient.
|
||||||
|
src := peer.handshake.remoteStatic
|
||||||
|
b64 := func(input byte) byte {
|
||||||
|
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
b := []byte("peer(____…____)")
|
||||||
|
const first = len("peer(")
|
||||||
|
const second = len("peer(____…")
|
||||||
|
b[first+0] = b64((src[0] >> 2) & 63)
|
||||||
|
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
|
||||||
|
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
|
||||||
|
b[first+3] = b64(src[2] & 63)
|
||||||
|
b[second+0] = b64(src[29] & 63)
|
||||||
|
b[second+1] = b64((src[30] >> 2) & 63)
|
||||||
|
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
|
||||||
|
b[second+3] = b64((src[31] << 2) & 63)
|
||||||
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) Start() {
|
func (peer *Peer) Start() {
|
||||||
|
@ -162,12 +180,12 @@ func (peer *Peer) Start() {
|
||||||
peer.state.Lock()
|
peer.state.Lock()
|
||||||
defer peer.state.Unlock()
|
defer peer.state.Unlock()
|
||||||
|
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
device.log.Verbosef("%v - Starting...", peer)
|
device.log.Verbosef("%v - Starting", peer)
|
||||||
|
|
||||||
// reset routine state
|
// reset routine state
|
||||||
peer.stopping.Wait()
|
peer.stopping.Wait()
|
||||||
|
@ -183,10 +201,14 @@ func (peer *Peer) Start() {
|
||||||
|
|
||||||
device.flushInboundQueue(peer.queue.inbound)
|
device.flushInboundQueue(peer.queue.inbound)
|
||||||
device.flushOutboundQueue(peer.queue.outbound)
|
device.flushOutboundQueue(peer.queue.outbound)
|
||||||
go peer.RoutineSequentialSender()
|
|
||||||
go peer.RoutineSequentialReceiver()
|
|
||||||
|
|
||||||
peer.isRunning.Set(true)
|
// Use the device batch size, not the bind batch size, as the device size is
|
||||||
|
// the size of the batch pools.
|
||||||
|
batchSize := peer.device.BatchSize()
|
||||||
|
go peer.RoutineSequentialSender(batchSize)
|
||||||
|
go peer.RoutineSequentialReceiver(batchSize)
|
||||||
|
|
||||||
|
peer.isRunning.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) ZeroAndFlushAll() {
|
func (peer *Peer) ZeroAndFlushAll() {
|
||||||
|
@ -198,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.loadNext())
|
device.DeleteKeypair(keypairs.next.Load())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
|
@ -226,11 +248,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
if keypairs.current != nil {
|
if keypairs.current != nil {
|
||||||
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
|
keypairs.current.sendNonce.Store(RejectAfterMessages)
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if next := keypairs.next.Load(); next != nil {
|
||||||
next := keypairs.loadNext()
|
next.sendNonce.Store(RejectAfterMessages)
|
||||||
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
|
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
|
@ -243,7 +264,7 @@ func (peer *Peer) Stop() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.device.log.Verbosef("%v - Stopping...", peer)
|
peer.device.log.Verbosef("%v - Stopping", peer)
|
||||||
|
|
||||||
peer.timersStop()
|
peer.timersStop()
|
||||||
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
|
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
|
||||||
|
@ -256,10 +277,20 @@ func (peer *Peer) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||||
if peer.disableRoaming {
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
if peer.endpoint.disableRoaming {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.endpoint.clearSrcOnTx = false
|
||||||
peer.endpoint = endpoint
|
peer.endpoint.val = endpoint
|
||||||
peer.Unlock()
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) markEndpointSrcForClearing() {
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
if peer.endpoint.val == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.endpoint.clearSrcOnTx = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -14,49 +14,85 @@ type WaitPool struct {
|
||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
cond sync.Cond
|
cond sync.Cond
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
count uint32
|
count atomic.Uint32
|
||||||
max uint32
|
max uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
|
func NewWaitPool(max uint32, new func() any) *WaitPool {
|
||||||
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
|
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
|
||||||
p.cond = sync.Cond{L: &p.lock}
|
p.cond = sync.Cond{L: &p.lock}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WaitPool) Get() interface{} {
|
func (p *WaitPool) Get() any {
|
||||||
if p.max != 0 {
|
if p.max != 0 {
|
||||||
p.lock.Lock()
|
p.lock.Lock()
|
||||||
for atomic.LoadUint32(&p.count) >= p.max {
|
for p.count.Load() >= p.max {
|
||||||
p.cond.Wait()
|
p.cond.Wait()
|
||||||
}
|
}
|
||||||
atomic.AddUint32(&p.count, 1)
|
p.count.Add(1)
|
||||||
p.lock.Unlock()
|
p.lock.Unlock()
|
||||||
}
|
}
|
||||||
return p.pool.Get()
|
return p.pool.Get()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WaitPool) Put(x interface{}) {
|
func (p *WaitPool) Put(x any) {
|
||||||
p.pool.Put(x)
|
p.pool.Put(x)
|
||||||
if p.max == 0 {
|
if p.max == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.AddUint32(&p.count, ^uint32(0))
|
p.count.Add(^uint32(0))
|
||||||
p.cond.Signal()
|
p.cond.Signal()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PopulatePools() {
|
func (device *Device) PopulatePools() {
|
||||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
|
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
|
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||||
|
return &QueueInboundElementsContainer{elems: s}
|
||||||
|
})
|
||||||
|
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
|
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||||
|
return &QueueOutboundElementsContainer{elems: s}
|
||||||
|
})
|
||||||
|
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
return new([MaxMessageSize]byte)
|
return new([MaxMessageSize]byte)
|
||||||
})
|
})
|
||||||
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
|
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
return new(QueueInboundElement)
|
return new(QueueInboundElement)
|
||||||
})
|
})
|
||||||
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
|
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
return new(QueueOutboundElement)
|
return new(QueueOutboundElement)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
||||||
|
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
||||||
|
c.Mutex = sync.Mutex{}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
||||||
|
for i := range c.elems {
|
||||||
|
c.elems[i] = nil
|
||||||
|
}
|
||||||
|
c.elems = c.elems[:0]
|
||||||
|
device.pool.inboundElementsContainer.Put(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
||||||
|
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
||||||
|
c.Mutex = sync.Mutex{}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
||||||
|
for i := range c.elems {
|
||||||
|
c.elems[i] = nil
|
||||||
|
}
|
||||||
|
c.elems = c.elems[:0]
|
||||||
|
device.pool.outboundElementsContainer.Put(c)
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||||
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -17,29 +17,31 @@ import (
|
||||||
func TestWaitPool(t *testing.T) {
|
func TestWaitPool(t *testing.T) {
|
||||||
t.Skip("Currently disabled")
|
t.Skip("Currently disabled")
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
trials := int32(100000)
|
var trials atomic.Int32
|
||||||
|
startTrials := int32(100000)
|
||||||
if raceEnabled {
|
if raceEnabled {
|
||||||
// This test can be very slow with -race.
|
// This test can be very slow with -race.
|
||||||
trials /= 10
|
startTrials /= 10
|
||||||
}
|
}
|
||||||
|
trials.Store(startTrials)
|
||||||
workers := runtime.NumCPU() + 2
|
workers := runtime.NumCPU() + 2
|
||||||
if workers-4 <= 0 {
|
if workers-4 <= 0 {
|
||||||
t.Skip("Not enough cores")
|
t.Skip("Not enough cores")
|
||||||
}
|
}
|
||||||
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
|
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||||
wg.Add(workers)
|
wg.Add(workers)
|
||||||
max := uint32(0)
|
var max atomic.Uint32
|
||||||
updateMax := func() {
|
updateMax := func() {
|
||||||
count := atomic.LoadUint32(&p.count)
|
count := p.count.Load()
|
||||||
if count > p.max {
|
if count > p.max {
|
||||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
old := atomic.LoadUint32(&max)
|
old := max.Load()
|
||||||
if count <= old {
|
if count <= old {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if atomic.CompareAndSwapUint32(&max, old, count) {
|
if max.CompareAndSwap(old, count) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
|
||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for atomic.AddInt32(&trials, -1) > 0 {
|
for trials.Add(-1) > 0 {
|
||||||
updateMax()
|
updateMax()
|
||||||
x := p.Get()
|
x := p.Get()
|
||||||
updateMax()
|
updateMax()
|
||||||
|
@ -59,25 +61,74 @@ func TestWaitPool(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
if max != p.max {
|
if max.Load() != p.max {
|
||||||
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWaitPool(b *testing.B) {
|
func BenchmarkWaitPool(b *testing.B) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
trials := int32(b.N)
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
workers := runtime.NumCPU() + 2
|
workers := runtime.NumCPU() + 2
|
||||||
if workers-4 <= 0 {
|
if workers-4 <= 0 {
|
||||||
b.Skip("Not enough cores")
|
b.Skip("Not enough cores")
|
||||||
}
|
}
|
||||||
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
|
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||||
wg.Add(workers)
|
wg.Add(workers)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for atomic.AddInt32(&trials, -1) > 0 {
|
for trials.Add(-1) > 0 {
|
||||||
|
x := p.Get()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
p.Put(x)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
b.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
||||||
|
wg.Add(workers)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
|
x := p.Get()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
p.Put(x)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSyncPool(b *testing.B) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
b.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
||||||
|
wg.Add(workers)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
x := p.Get()
|
x := p.Get()
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
p.Put(x)
|
p.Put(x)
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
|
||||||
/* Reduce memory consumption for Android */
|
/* Reduce memory consumption for Android */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
QueueStagedSize = 128
|
QueueStagedSize = conn.IdealBatchSize
|
||||||
QueueOutboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
MaxSegmentSize = 2200
|
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||||
PreallocatedBuffersPerPool = 4096
|
PreallocatedBuffersPerPool = 4096
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
// +build !android,!ios,!windows
|
//go:build !android && !ios && !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
QueueStagedSize = 128
|
QueueStagedSize = conn.IdealBatchSize
|
||||||
QueueOutboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
|
|
|
@ -1,19 +1,21 @@
|
||||||
// +build ios
|
//go:build ios
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */
|
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
|
||||||
|
// These are vars instead of consts, because heavier network extensions might want to reduce
|
||||||
const (
|
// them further.
|
||||||
QueueStagedSize = 128
|
var (
|
||||||
QueueOutboundSize = 1024
|
QueueStagedSize = 128
|
||||||
QueueInboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueInboundSize = 1024
|
||||||
MaxSegmentSize = 1700
|
QueueHandshakeSize = 1024
|
||||||
PreallocatedBuffersPerPool = 1024
|
PreallocatedBuffersPerPool uint32 = 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MaxSegmentSize = 1700
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
//+build !race
|
//go:build !race
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
//+build race
|
//go:build race
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -11,14 +11,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type QueueHandshakeElement struct {
|
type QueueHandshakeElement struct {
|
||||||
|
@ -29,7 +27,6 @@ type QueueHandshakeElement struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueueInboundElement struct {
|
type QueueInboundElement struct {
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
|
@ -37,6 +34,11 @@ type QueueInboundElement struct {
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueueInboundElementsContainer struct {
|
||||||
|
sync.Mutex
|
||||||
|
elems []*QueueInboundElement
|
||||||
|
}
|
||||||
|
|
||||||
// clearPointers clears elem fields that contain pointers.
|
// clearPointers clears elem fields that contain pointers.
|
||||||
// This makes the garbage collector's life easier and
|
// This makes the garbage collector's life easier and
|
||||||
// avoids accidentally keeping other objects around unnecessarily.
|
// avoids accidentally keeping other objects around unnecessarily.
|
||||||
|
@ -53,12 +55,12 @@ func (elem *QueueInboundElement) clearPointers() {
|
||||||
* NOTE: Not thread safe, but called by sequential receiver!
|
* NOTE: Not thread safe, but called by sequential receiver!
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) keepKeyFreshReceiving() {
|
func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
if peer.timers.sentLastMinuteHandshake.Get() {
|
if peer.timers.sentLastMinuteHandshake.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||||
peer.timers.sentLastMinuteHandshake.Set(true)
|
peer.timers.sentLastMinuteHandshake.Store(true)
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,7 +70,10 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
* Every time the bind is updated a new routine is started for
|
* Every time the bind is updated a new routine is started for
|
||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
func (device *Device) RoutineReceiveIncoming(
|
||||||
|
maxBatchSize int,
|
||||||
|
recv conn.ReceiveFunc,
|
||||||
|
) {
|
||||||
recvName := recv.PrettyName()
|
recvName := recv.PrettyName()
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||||
|
@ -81,168 +86,226 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
// receive datagrams until conn is closed
|
||||||
|
|
||||||
buffer := device.GetMessageBuffer()
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||||
|
bufs = make([][]byte, maxBatchSize)
|
||||||
err error
|
err error
|
||||||
size int
|
sizes = make([]int, maxBatchSize)
|
||||||
endpoint conn.Endpoint
|
count int
|
||||||
|
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||||
deathSpiral int
|
deathSpiral int
|
||||||
|
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for i := range bufsArrs {
|
||||||
size, endpoint, err = recv(buffer[:])
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
|
bufs[i] = bufsArrs[i][:]
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for i := 0; i < maxBatchSize; i++ {
|
||||||
|
if bufsArrs[i] != nil {
|
||||||
|
device.PutMessageBuffer(bufsArrs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
count, err = recv(bufs, sizes, endpoints)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.PutMessageBuffer(buffer)
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
|
||||||
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
|
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
device.log.Errorf("Failed to receive packet: %v", err)
|
|
||||||
if deathSpiral < 10 {
|
if deathSpiral < 10 {
|
||||||
deathSpiral++
|
deathSpiral++
|
||||||
time.Sleep(time.Second / 3)
|
time.Sleep(time.Second / 3)
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
deathSpiral = 0
|
deathSpiral = 0
|
||||||
|
|
||||||
if size < MinMessageSize {
|
device.aSecMux.RLock()
|
||||||
continue
|
// handle each packet in the batch
|
||||||
}
|
for i, size := range sizes[:count] {
|
||||||
|
if size < MinMessageSize {
|
||||||
// check size of packet
|
|
||||||
|
|
||||||
packet := buffer[:size]
|
|
||||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
|
||||||
|
|
||||||
var okay bool
|
|
||||||
|
|
||||||
switch msgType {
|
|
||||||
|
|
||||||
// check if transport
|
|
||||||
|
|
||||||
case MessageTransportType:
|
|
||||||
|
|
||||||
// check size
|
|
||||||
|
|
||||||
if len(packet) < MessageTransportSize {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookup key pair
|
// check size of packet
|
||||||
|
|
||||||
receiver := binary.LittleEndian.Uint32(
|
packet := bufsArrs[i][:size]
|
||||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
var msgType uint32
|
||||||
)
|
if device.isAdvancedSecurityOn() {
|
||||||
value := device.indexTable.Lookup(receiver)
|
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||||
keypair := value.keypair
|
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||||
if keypair == nil {
|
// transport size can align with other header types;
|
||||||
continue
|
// making sure we have the right msgType
|
||||||
}
|
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
|
||||||
|
if msgType == assumedMsgType {
|
||||||
// check keypair expiry
|
packet = packet[junkSize:]
|
||||||
|
} else {
|
||||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
device.log.Verbosef("Transport packet lined up with another msg type")
|
||||||
continue
|
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// create work element
|
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||||
peer := value.peer
|
if msgType != MessageTransportType {
|
||||||
elem := device.GetInboundElement()
|
device.log.Verbosef("ASec: Received message with unknown type")
|
||||||
elem.packet = packet
|
continue
|
||||||
elem.buffer = buffer
|
}
|
||||||
elem.keypair = keypair
|
}
|
||||||
elem.endpoint = endpoint
|
|
||||||
elem.counter = 0
|
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.Lock()
|
|
||||||
|
|
||||||
// add to decryption queues
|
|
||||||
if peer.isRunning.Get() {
|
|
||||||
peer.queue.inbound.c <- elem
|
|
||||||
device.queue.decryption.c <- elem
|
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
} else {
|
} else {
|
||||||
device.PutInboundElement(elem)
|
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||||
}
|
}
|
||||||
continue
|
switch msgType {
|
||||||
|
|
||||||
// otherwise it is a fixed size & handshake related packet
|
// check if transport
|
||||||
|
|
||||||
case MessageInitiationType:
|
case MessageTransportType:
|
||||||
okay = len(packet) == MessageInitiationSize
|
|
||||||
|
|
||||||
case MessageResponseType:
|
// check size
|
||||||
okay = len(packet) == MessageResponseSize
|
|
||||||
|
|
||||||
case MessageCookieReplyType:
|
if len(packet) < MessageTransportSize {
|
||||||
okay = len(packet) == MessageCookieReplySize
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
// lookup key pair
|
||||||
device.log.Verbosef("Received message with unknown type")
|
|
||||||
}
|
receiver := binary.LittleEndian.Uint32(
|
||||||
|
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||||
|
)
|
||||||
|
value := device.indexTable.Lookup(receiver)
|
||||||
|
keypair := value.keypair
|
||||||
|
if keypair == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check keypair expiry
|
||||||
|
|
||||||
|
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// create work element
|
||||||
|
peer := value.peer
|
||||||
|
elem := device.GetInboundElement()
|
||||||
|
elem.packet = packet
|
||||||
|
elem.buffer = bufsArrs[i]
|
||||||
|
elem.keypair = keypair
|
||||||
|
elem.endpoint = endpoints[i]
|
||||||
|
elem.counter = 0
|
||||||
|
|
||||||
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
|
if !ok {
|
||||||
|
elemsForPeer = device.GetInboundElementsContainer()
|
||||||
|
elemsForPeer.Lock()
|
||||||
|
elemsByPeer[peer] = elemsForPeer
|
||||||
|
}
|
||||||
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
|
bufs[i] = bufsArrs[i][:]
|
||||||
|
continue
|
||||||
|
|
||||||
|
// otherwise it is a fixed size & handshake related packet
|
||||||
|
|
||||||
|
case MessageInitiationType:
|
||||||
|
if len(packet) != MessageInitiationSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageResponseType:
|
||||||
|
if len(packet) != MessageResponseSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageCookieReplyType:
|
||||||
|
if len(packet) != MessageCookieReplySize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.log.Verbosef("Received message with unknown type")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if okay {
|
|
||||||
select {
|
select {
|
||||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||||
msgType: msgType,
|
msgType: msgType,
|
||||||
buffer: buffer,
|
buffer: bufsArrs[i],
|
||||||
packet: packet,
|
packet: packet,
|
||||||
endpoint: endpoint,
|
endpoint: endpoints[i],
|
||||||
}:
|
}:
|
||||||
buffer = device.GetMessageBuffer()
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
|
bufs[i] = bufsArrs[i][:]
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
|
for peer, elemsContainer := range elemsByPeer {
|
||||||
|
if peer.isRunning.Load() {
|
||||||
|
peer.queue.inbound.c <- elemsContainer
|
||||||
|
device.queue.decryption.c <- elemsContainer
|
||||||
|
} else {
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutInboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
|
}
|
||||||
|
delete(elemsByPeer, peer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineDecryption() {
|
func (device *Device) RoutineDecryption(id int) {
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
|
|
||||||
defer device.log.Verbosef("Routine: decryption worker - stopped")
|
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||||
device.log.Verbosef("Routine: decryption worker - started")
|
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||||
|
|
||||||
for elem := range device.queue.decryption.c {
|
for elemsContainer := range device.queue.decryption.c {
|
||||||
// split message into fields
|
for _, elem := range elemsContainer.elems {
|
||||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
// split message into fields
|
||||||
content := elem.packet[MessageTransportOffsetContent:]
|
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||||
|
content := elem.packet[MessageTransportOffsetContent:]
|
||||||
|
|
||||||
// decrypt and release to consumer
|
// decrypt and release to consumer
|
||||||
var err error
|
var err error
|
||||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||||
// copy counter to nonce
|
// copy counter to nonce
|
||||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||||
elem.packet, err = elem.keypair.receive.Open(
|
elem.packet, err = elem.keypair.receive.Open(
|
||||||
content[:0],
|
content[:0],
|
||||||
nonce[:],
|
nonce[:],
|
||||||
content,
|
content,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
elem.packet = nil
|
elem.packet = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
elem.Unlock()
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Handles incoming packets related to handshake
|
/* Handles incoming packets related to handshake
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineHandshake() {
|
func (device *Device) RoutineHandshake(id int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: handshake worker - stopped")
|
device.log.Verbosef("Routine: handshake worker %d - stopped", id)
|
||||||
device.queue.encryption.wg.Done()
|
device.queue.encryption.wg.Done()
|
||||||
}()
|
}()
|
||||||
device.log.Verbosef("Routine: handshake worker - started")
|
device.log.Verbosef("Routine: handshake worker %d - started", id)
|
||||||
|
|
||||||
for elem := range device.queue.handshake.c {
|
for elem := range device.queue.handshake.c {
|
||||||
|
|
||||||
|
device.aSecMux.RLock()
|
||||||
|
|
||||||
// handle cookie fields and ratelimiting
|
// handle cookie fields and ratelimiting
|
||||||
|
|
||||||
switch elem.msgType {
|
switch elem.msgType {
|
||||||
|
@ -269,10 +332,15 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
// consume reply
|
// consume reply
|
||||||
|
|
||||||
if peer := entry.peer; peer.isRunning.Get() {
|
if peer := entry.peer; peer.isRunning.Load() {
|
||||||
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
device.log.Verbosef(
|
||||||
|
"Receiving cookie response from %s",
|
||||||
|
elem.endpoint.DstToString(),
|
||||||
|
)
|
||||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
||||||
device.log.Verbosef("Could not decrypt invalid cookie response")
|
device.log.Verbosef(
|
||||||
|
"Could not decrypt invalid cookie response",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -314,9 +382,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
switch elem.msgType {
|
switch elem.msgType {
|
||||||
case MessageInitiationType:
|
case MessageInitiationType:
|
||||||
|
|
||||||
// unmarshal
|
// unmarshal
|
||||||
|
|
||||||
var msg MessageInitiation
|
var msg MessageInitiation
|
||||||
reader := bytes.NewReader(elem.packet)
|
reader := bytes.NewReader(elem.packet)
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||||
|
@ -326,7 +392,6 @@ func (device *Device) RoutineHandshake() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume initiation
|
// consume initiation
|
||||||
|
|
||||||
peer := device.ConsumeMessageInitiation(&msg)
|
peer := device.ConsumeMessageInitiation(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
||||||
|
@ -342,7 +407,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake initiation", peer)
|
device.log.Verbosef("%v - Received handshake initiation", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
peer.SendHandshakeResponse()
|
peer.SendHandshakeResponse()
|
||||||
|
|
||||||
|
@ -370,7 +435,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake response", peer)
|
device.log.Verbosef("%v - Received handshake response", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
|
||||||
|
@ -391,11 +456,12 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
skip:
|
skip:
|
||||||
|
device.aSecMux.RUnlock()
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) RoutineSequentialReceiver() {
|
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||||
device := peer.device
|
device := peer.device
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
||||||
|
@ -403,89 +469,109 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
}()
|
}()
|
||||||
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
||||||
|
|
||||||
for elem := range peer.queue.inbound.c {
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
if elem == nil {
|
|
||||||
|
for elemsContainer := range peer.queue.inbound.c {
|
||||||
|
if elemsContainer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err error
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
validTailPacket := -1
|
||||||
if elem.packet == nil {
|
dataPacketReceived := false
|
||||||
// decryption failed
|
rxBytesLen := uint64(0)
|
||||||
goto skip
|
for i, elem := range elemsContainer.elems {
|
||||||
}
|
if elem.packet == nil {
|
||||||
|
// decryption failed
|
||||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
continue
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
|
||||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
|
||||||
peer.timersHandshakeComplete()
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.keepKeyFreshReceiving()
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
|
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
|
||||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
peer.timersDataReceived()
|
|
||||||
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case ipv4.Version:
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
|
||||||
if device.allowedips.LookupIPv4(src) != peer {
|
|
||||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
|
||||||
goto skip
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case ipv6.Version:
|
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
continue
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
length += ipv6.HeaderLen
|
|
||||||
if int(length) > len(elem.packet) {
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
|
||||||
if device.allowedips.LookupIPv6(src) != peer {
|
|
||||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
|
||||||
goto skip
|
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
validTailPacket = i
|
||||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||||
goto skip
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
peer.timersHandshakeComplete()
|
||||||
|
peer.SendStagedPackets()
|
||||||
|
}
|
||||||
|
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||||
|
|
||||||
|
if len(elem.packet) == 0 {
|
||||||
|
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataPacketReceived = true
|
||||||
|
|
||||||
|
switch elem.packet[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||||
|
length := binary.BigEndian.Uint16(field)
|
||||||
|
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elem.packet = elem.packet[:length]
|
||||||
|
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||||
|
if device.allowedips.Lookup(src) != peer {
|
||||||
|
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
case 6:
|
||||||
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||||
|
length := binary.BigEndian.Uint16(field)
|
||||||
|
length += ipv6.HeaderLen
|
||||||
|
if int(length) > len(elem.packet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elem.packet = elem.packet[:length]
|
||||||
|
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||||
|
if device.allowedips.Lookup(src) != peer {
|
||||||
|
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.log.Verbosef(
|
||||||
|
"Packet with invalid IP version from %v",
|
||||||
|
peer,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bufs = append(
|
||||||
|
bufs,
|
||||||
|
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
|
peer.rxBytes.Add(rxBytesLen)
|
||||||
if err != nil && !device.isClosed() {
|
if validTailPacket >= 0 {
|
||||||
device.log.Errorf("Failed to write packet to TUN device: %v", err)
|
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||||
|
peer.keepKeyFreshReceiving()
|
||||||
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
}
|
}
|
||||||
if len(peer.queue.inbound.c) == 0 {
|
if dataPacketReceived {
|
||||||
err = device.tun.device.Flush()
|
peer.timersDataReceived()
|
||||||
if err != nil {
|
}
|
||||||
peer.device.log.Errorf("Unable to flush packets: %v", err)
|
if len(bufs) > 0 {
|
||||||
|
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||||
|
if err != nil && !device.isClosed() {
|
||||||
|
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
skip:
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutInboundElement(elem)
|
device.PutInboundElement(elem)
|
||||||
|
}
|
||||||
|
bufs = bufs[:0]
|
||||||
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
439
device/send.go
439
device/send.go
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -8,11 +8,14 @@ package device
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
@ -43,7 +46,6 @@ import (
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type QueueOutboundElement struct {
|
type QueueOutboundElement struct {
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||||
packet []byte // slice of "buffer" (always!)
|
packet []byte // slice of "buffer" (always!)
|
||||||
nonce uint64 // nonce for encryption
|
nonce uint64 // nonce for encryption
|
||||||
|
@ -51,10 +53,14 @@ type QueueOutboundElement struct {
|
||||||
peer *Peer // related peer
|
peer *Peer // related peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueueOutboundElementsContainer struct {
|
||||||
|
sync.Mutex
|
||||||
|
elems []*QueueOutboundElement
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||||
elem := device.GetOutboundElement()
|
elem := device.GetOutboundElement()
|
||||||
elem.buffer = device.GetMessageBuffer()
|
elem.buffer = device.GetMessageBuffer()
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.nonce = 0
|
elem.nonce = 0
|
||||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||||
return elem
|
return elem
|
||||||
|
@ -74,14 +80,17 @@ func (elem *QueueOutboundElement) clearPointers() {
|
||||||
/* Queues a keepalive if no packets are queued for peer
|
/* Queues a keepalive if no packets are queued for peer
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) SendKeepalive() {
|
func (peer *Peer) SendKeepalive() {
|
||||||
if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
|
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||||
elem := peer.device.NewOutboundElement()
|
elem := peer.device.NewOutboundElement()
|
||||||
|
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||||
|
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||||
select {
|
select {
|
||||||
case peer.queue.staged <- elem:
|
case peer.queue.staged <- elemsContainer:
|
||||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||||
default:
|
default:
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutOutboundElement(elem)
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
|
@ -89,7 +98,7 @@ func (peer *Peer) SendKeepalive() {
|
||||||
|
|
||||||
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
if !isRetry {
|
if !isRetry {
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.handshake.mutex.RLock()
|
peer.handshake.mutex.RLock()
|
||||||
|
@ -114,17 +123,56 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var sendBuffer [][]byte
|
||||||
|
// so only packet processed for cookie generation
|
||||||
|
var junkedHeader []byte
|
||||||
|
if peer.device.isAdvancedSecurityOn() {
|
||||||
|
peer.device.aSecMux.RLock()
|
||||||
|
junks, err := peer.device.junkCreator.createJunkPackets()
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
|
||||||
var buff [MessageInitiationSize]byte
|
if err != nil {
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(junks) > 0 {
|
||||||
|
err = peer.SendBuffers(junks)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.device.aSecMux.RLock()
|
||||||
|
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
||||||
|
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
|
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
|
if err != nil {
|
||||||
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
junkedHeader = writer.Bytes()
|
||||||
|
}
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf [MessageInitiationSize]byte
|
||||||
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
binary.Write(writer, binary.LittleEndian, msg)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
peer.cookieGenerator.AddMacs(packet)
|
||||||
|
junkedHeader = append(junkedHeader, packet...)
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
err = peer.SendBuffer(packet)
|
sendBuffer = append(sendBuffer, junkedHeader)
|
||||||
|
|
||||||
|
err = peer.SendBuffers(sendBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||||
}
|
}
|
||||||
|
@ -145,12 +193,29 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||||
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var junkedHeader []byte
|
||||||
|
if peer.device.isAdvancedSecurityOn() {
|
||||||
|
peer.device.aSecMux.RLock()
|
||||||
|
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
||||||
|
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
|
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
|
if err != nil {
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
junkedHeader = writer.Bytes()
|
||||||
|
}
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
}
|
||||||
|
var buf [MessageResponseSize]byte
|
||||||
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
|
|
||||||
var buff [MessageResponseSize]byte
|
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
|
||||||
binary.Write(writer, binary.LittleEndian, response)
|
binary.Write(writer, binary.LittleEndian, response)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
peer.cookieGenerator.AddMacs(packet)
|
||||||
|
junkedHeader = append(junkedHeader, packet...)
|
||||||
|
|
||||||
err = peer.BeginSymmetricSession()
|
err = peer.BeginSymmetricSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -162,27 +227,35 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
err = peer.SendBuffer(packet)
|
// TODO: allocation could be avoided
|
||||||
|
err = peer.SendBuffers([][]byte{junkedHeader})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
func (device *Device) SendHandshakeCookie(
|
||||||
|
initiatingElem *QueueHandshakeElement,
|
||||||
|
) error {
|
||||||
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
||||||
|
|
||||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
||||||
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
|
reply, err := device.cookieChecker.CreateReply(
|
||||||
|
initiatingElem.packet,
|
||||||
|
sender,
|
||||||
|
initiatingElem.endpoint.DstToBytes(),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("Failed to create cookie reply: %v", err)
|
device.log.Errorf("Failed to create cookie reply: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var buff [MessageCookieReplySize]byte
|
var buf [MessageCookieReplySize]byte
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
// TODO: allocation could be avoided
|
||||||
|
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,17 +264,12 @@ func (peer *Peer) keepKeyFreshSending() {
|
||||||
if keypair == nil {
|
if keypair == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
nonce := keypair.sendNonce.Load()
|
||||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Reads packets from the TUN and inserts
|
|
||||||
* into staged queue for peer
|
|
||||||
*
|
|
||||||
* Obs. Single instance per TUN device
|
|
||||||
*/
|
|
||||||
func (device *Device) RoutineReadFromTUN() {
|
func (device *Device) RoutineReadFromTUN() {
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: TUN reader - stopped")
|
device.log.Verbosef("Routine: TUN reader - stopped")
|
||||||
|
@ -211,80 +279,123 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
|
|
||||||
device.log.Verbosef("Routine: TUN reader - started")
|
device.log.Verbosef("Routine: TUN reader - started")
|
||||||
|
|
||||||
var elem *QueueOutboundElement
|
var (
|
||||||
|
batchSize = device.BatchSize()
|
||||||
|
readErr error
|
||||||
|
elems = make([]*QueueOutboundElement, batchSize)
|
||||||
|
bufs = make([][]byte, batchSize)
|
||||||
|
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
||||||
|
count = 0
|
||||||
|
sizes = make([]int, batchSize)
|
||||||
|
offset = MessageTransportHeaderSize
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := range elems {
|
||||||
|
elems[i] = device.NewOutboundElement()
|
||||||
|
bufs[i] = elems[i].buffer[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for _, elem := range elems {
|
||||||
|
if elem != nil {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if elem != nil {
|
// read packets
|
||||||
device.PutMessageBuffer(elem.buffer)
|
count, readErr = device.tun.device.Read(bufs, sizes, offset)
|
||||||
device.PutOutboundElement(elem)
|
for i := 0; i < count; i++ {
|
||||||
|
if sizes[i] < 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := elems[i]
|
||||||
|
elem.packet = bufs[i][offset : offset+sizes[i]]
|
||||||
|
|
||||||
|
// lookup peer
|
||||||
|
var peer *Peer
|
||||||
|
switch elem.packet[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
|
peer = device.allowedips.Lookup(dst)
|
||||||
|
|
||||||
|
case 6:
|
||||||
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
|
peer = device.allowedips.Lookup(dst)
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.log.Verbosef("Received packet with unknown IP version")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
|
if !ok {
|
||||||
|
elemsForPeer = device.GetOutboundElementsContainer()
|
||||||
|
elemsByPeer[peer] = elemsForPeer
|
||||||
|
}
|
||||||
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
|
elems[i] = device.NewOutboundElement()
|
||||||
|
bufs[i] = elems[i].buffer[:]
|
||||||
}
|
}
|
||||||
elem = device.NewOutboundElement()
|
|
||||||
|
|
||||||
// read packet
|
for peer, elemsForPeer := range elemsByPeer {
|
||||||
|
if peer.isRunning.Load() {
|
||||||
|
peer.StagePackets(elemsForPeer)
|
||||||
|
peer.SendStagedPackets()
|
||||||
|
} else {
|
||||||
|
for _, elem := range elemsForPeer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsForPeer)
|
||||||
|
}
|
||||||
|
delete(elemsByPeer, peer)
|
||||||
|
}
|
||||||
|
|
||||||
offset := MessageTransportHeaderSize
|
if readErr != nil {
|
||||||
size, err := device.tun.device.Read(elem.buffer[:], offset)
|
if errors.Is(readErr, tun.ErrTooManySegments) {
|
||||||
|
// TODO: record stat for this
|
||||||
if err != nil {
|
// This will happen if MSS is surprisingly small (< 576)
|
||||||
|
// coincident with reasonably high throughput.
|
||||||
|
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !device.isClosed() {
|
if !device.isClosed() {
|
||||||
device.log.Errorf("Failed to read packet from TUN device: %v", err)
|
if !errors.Is(readErr, os.ErrClosed) {
|
||||||
|
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
||||||
|
}
|
||||||
go device.Close()
|
go device.Close()
|
||||||
}
|
}
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if size == 0 || size > MaxContentSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.packet = elem.buffer[offset : offset+size]
|
|
||||||
|
|
||||||
// lookup peer
|
|
||||||
|
|
||||||
var peer *Peer
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case ipv4.Version:
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
|
||||||
peer = device.allowedips.LookupIPv4(dst)
|
|
||||||
|
|
||||||
case ipv6.Version:
|
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
|
||||||
peer = device.allowedips.LookupIPv6(dst)
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Verbosef("Received packet with unknown IP version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if peer.isRunning.Get() {
|
|
||||||
peer.StagePacket(elem)
|
|
||||||
elem = nil
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
|
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case peer.queue.staged <- elem:
|
case peer.queue.staged <- elems:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case tooOld := <-peer.queue.staged:
|
case tooOld := <-peer.queue.staged:
|
||||||
peer.device.PutMessageBuffer(tooOld.buffer)
|
for _, elem := range tooOld.elems {
|
||||||
peer.device.PutOutboundElement(tooOld)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(tooOld)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -297,32 +408,59 @@ top:
|
||||||
}
|
}
|
||||||
|
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
var elemsContainerOOO *QueueOutboundElementsContainer
|
||||||
select {
|
select {
|
||||||
case elem := <-peer.queue.staged:
|
case elemsContainer := <-peer.queue.staged:
|
||||||
elem.peer = peer
|
i := 0
|
||||||
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
|
for _, elem := range elemsContainer.elems {
|
||||||
if elem.nonce >= RejectAfterMessages {
|
elem.peer = peer
|
||||||
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
|
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||||
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
if elem.nonce >= RejectAfterMessages {
|
||||||
|
keypair.sendNonce.Store(RejectAfterMessages)
|
||||||
|
if elemsContainerOOO == nil {
|
||||||
|
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
||||||
|
}
|
||||||
|
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
elemsContainer.elems[i] = elem
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
elem.keypair = keypair
|
||||||
|
}
|
||||||
|
elemsContainer.Lock()
|
||||||
|
elemsContainer.elems = elemsContainer.elems[:i]
|
||||||
|
|
||||||
|
if elemsContainerOOO != nil {
|
||||||
|
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(elemsContainer.elems) == 0 {
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
goto top
|
goto top
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.keypair = keypair
|
|
||||||
elem.Lock()
|
|
||||||
|
|
||||||
// add to parallel and sequential queue
|
// add to parallel and sequential queue
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
peer.queue.outbound.c <- elem
|
peer.queue.outbound.c <- elemsContainer
|
||||||
peer.device.queue.encryption.c <- elem
|
peer.device.queue.encryption.c <- elemsContainer
|
||||||
} else {
|
} else {
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
for _, elem := range elemsContainer.elems {
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if elemsContainerOOO != nil {
|
||||||
|
goto top
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
|
@ -333,9 +471,12 @@ top:
|
||||||
func (peer *Peer) FlushStagedPackets() {
|
func (peer *Peer) FlushStagedPackets() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elem := <-peer.queue.staged:
|
case elemsContainer := <-peer.queue.staged:
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
for _, elem := range elemsContainer.elems {
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -362,48 +503,45 @@ func calculatePaddingSize(packetSize, mtu int) int {
|
||||||
*
|
*
|
||||||
* Obs. One instance per core
|
* Obs. One instance per core
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineEncryption() {
|
func (device *Device) RoutineEncryption(id int) {
|
||||||
var paddingZeros [PaddingMultiple]byte
|
var paddingZeros [PaddingMultiple]byte
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
|
|
||||||
defer device.log.Verbosef("Routine: encryption worker - stopped")
|
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
||||||
device.log.Verbosef("Routine: encryption worker - started")
|
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
||||||
|
|
||||||
for elem := range device.queue.encryption.c {
|
for elemsContainer := range device.queue.encryption.c {
|
||||||
// populate header fields
|
for _, elem := range elemsContainer.elems {
|
||||||
header := elem.buffer[:MessageTransportHeaderSize]
|
// populate header fields
|
||||||
|
header := elem.buffer[:MessageTransportHeaderSize]
|
||||||
|
|
||||||
fieldType := header[0:4]
|
fieldType := header[0:4]
|
||||||
fieldReceiver := header[4:8]
|
fieldReceiver := header[4:8]
|
||||||
fieldNonce := header[8:16]
|
fieldNonce := header[8:16]
|
||||||
|
|
||||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||||
|
|
||||||
// pad content to multiple of 16
|
// pad content to multiple of 16
|
||||||
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
|
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||||
|
|
||||||
// encrypt content and release to consumer
|
// encrypt content and release to consumer
|
||||||
|
|
||||||
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
||||||
elem.packet = elem.keypair.send.Seal(
|
elem.packet = elem.keypair.send.Seal(
|
||||||
header,
|
header,
|
||||||
nonce[:],
|
nonce[:],
|
||||||
elem.packet,
|
elem.packet,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
elem.Unlock()
|
}
|
||||||
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Sequentially reads packets from queue and sends to endpoint
|
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
*
|
|
||||||
* Obs. Single instance per peer.
|
|
||||||
* The routine terminates then the outbound queue is closed.
|
|
||||||
*/
|
|
||||||
func (peer *Peer) RoutineSequentialSender() {
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
defer func() {
|
defer func() {
|
||||||
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
||||||
|
@ -411,36 +549,57 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||||
}()
|
}()
|
||||||
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
||||||
|
|
||||||
for elem := range peer.queue.outbound.c {
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
if elem == nil {
|
|
||||||
|
for elemsContainer := range peer.queue.outbound.c {
|
||||||
|
bufs = bufs[:0]
|
||||||
|
if elemsContainer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
elem.Lock()
|
if !peer.isRunning.Load() {
|
||||||
if !peer.isRunning.Get() {
|
|
||||||
// peer has been stopped; return re-usable elems to the shared pool.
|
// peer has been stopped; return re-usable elems to the shared pool.
|
||||||
// This is an optimization only. It is possible for the peer to be stopped
|
// This is an optimization only. It is possible for the peer to be stopped
|
||||||
// immediately after this check, in which case, elem will get processed.
|
// immediately after this check, in which case, elem will get processed.
|
||||||
// The timers and SendBuffer code are resilient to a few stragglers.
|
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||||
// TODO: rework peer shutdown order to ensure
|
// TODO: rework peer shutdown order to ensure
|
||||||
// that we never accidentally keep timers alive longer than necessary.
|
// that we never accidentally keep timers alive longer than necessary.
|
||||||
device.PutMessageBuffer(elem.buffer)
|
elemsContainer.Lock()
|
||||||
device.PutOutboundElement(elem)
|
for _, elem := range elemsContainer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
dataSent := false
|
||||||
|
elemsContainer.Lock()
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
if len(elem.packet) != MessageKeepaliveSize {
|
||||||
|
dataSent = true
|
||||||
|
}
|
||||||
|
bufs = append(bufs, elem.packet)
|
||||||
|
}
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
// send message and return buffer to pool
|
err := peer.SendBuffers(bufs)
|
||||||
|
if dataSent {
|
||||||
err := peer.SendBuffer(elem.packet)
|
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
|
||||||
peer.timersDataSent()
|
peer.timersDataSent()
|
||||||
}
|
}
|
||||||
device.PutMessageBuffer(elem.buffer)
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutOutboundElement(elem)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
|
var errGSO conn.ErrUDPGSODisabled
|
||||||
|
if errors.As(err, &errGSO) {
|
||||||
|
device.log.Verbosef(err.Error())
|
||||||
|
err = errGSO.RetryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
// +build !linux
|
//go:build !linux
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
*
|
||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
|
@ -20,12 +20,15 @@ import (
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
|
if !conn.StdNetSupportsStickySockets {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if _, ok := bind.(*conn.StdNetBind); !ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
if !ok {
|
if !ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
pePtr.peer.Lock()
|
pePtr.peer.endpoint.Lock()
|
||||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
|
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
|
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
}
|
}
|
||||||
attr = attr[attrhdr.Len:]
|
attr = attr[attrhdr.Len:]
|
||||||
}
|
}
|
||||||
|
@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
i := uint32(1)
|
i := uint32(1)
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.RLock()
|
peer.endpoint.Lock()
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint.val == nil {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
|
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||||
if nativeEP == nil {
|
if nativeEP == nil {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
nlmsg := struct {
|
nlmsg := struct {
|
||||||
|
@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
Len: 8,
|
Len: 8,
|
||||||
Type: unix.RTA_DST,
|
Type: unix.RTA_DST,
|
||||||
},
|
},
|
||||||
nativeEP.Dst4().Addr,
|
nativeEP.DstIP().As4(),
|
||||||
unix.RtAttr{
|
unix.RtAttr{
|
||||||
Len: 8,
|
Len: 8,
|
||||||
Type: unix.RTA_SRC,
|
Type: unix.RTA_SRC,
|
||||||
},
|
},
|
||||||
nativeEP.Src4().Src,
|
nativeEP.SrcIP().As4(),
|
||||||
unix.RtAttr{
|
unix.RtAttr{
|
||||||
Len: 8,
|
Len: 8,
|
||||||
Type: unix.RTA_MARK,
|
Type: unix.RTA_MARK,
|
||||||
|
@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
reqPeerLock.Lock()
|
reqPeerLock.Lock()
|
||||||
reqPeer[i] = peerEndpointPtr{
|
reqPeer[i] = peerEndpointPtr{
|
||||||
peer: peer,
|
peer: peer,
|
||||||
endpoint: &peer.endpoint,
|
endpoint: &peer.endpoint.val,
|
||||||
}
|
}
|
||||||
reqPeerLock.Unlock()
|
reqPeerLock.Unlock()
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
i++
|
i++
|
||||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -204,7 +207,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNetlinkRouteSocket() (int, error) {
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
*
|
||||||
* This is based heavily on timers.c from the kernel implementation.
|
* This is based heavily on timers.c from the kernel implementation.
|
||||||
*/
|
*/
|
||||||
|
@ -8,12 +8,14 @@
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
_ "unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:linkname fastrandn runtime.fastrandn
|
||||||
|
func fastrandn(n uint32) uint32
|
||||||
|
|
||||||
// A Timer manages time-based aspects of the WireGuard protocol.
|
// A Timer manages time-based aspects of the WireGuard protocol.
|
||||||
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
|
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
|
||||||
type Timer struct {
|
type Timer struct {
|
||||||
|
@ -71,11 +73,11 @@ func (timer *Timer) IsPending() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersActive() bool {
|
func (peer *Peer) timersActive() bool {
|
||||||
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
|
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredRetransmitHandshake(peer *Peer) {
|
func expiredRetransmitHandshake(peer *Peer) {
|
||||||
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
|
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
|
||||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
||||||
|
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
|
@ -94,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
|
||||||
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
|
peer.timers.handshakeAttempts.Add(1)
|
||||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||||
|
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
|
|
||||||
peer.SendHandshakeInitiation(true)
|
peer.SendHandshakeInitiation(true)
|
||||||
}
|
}
|
||||||
|
@ -110,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
|
||||||
|
|
||||||
func expiredSendKeepalive(peer *Peer) {
|
func expiredSendKeepalive(peer *Peer) {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
if peer.timers.needAnotherKeepalive.Get() {
|
if peer.timers.needAnotherKeepalive.Load() {
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
}
|
}
|
||||||
|
@ -121,13 +119,8 @@ func expiredSendKeepalive(peer *Peer) {
|
||||||
func expiredNewHandshake(peer *Peer) {
|
func expiredNewHandshake(peer *Peer) {
|
||||||
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredZeroKeyMaterial(peer *Peer) {
|
func expiredZeroKeyMaterial(peer *Peer) {
|
||||||
|
@ -136,7 +129,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredPersistentKeepalive(peer *Peer) {
|
func expiredPersistentKeepalive(peer *Peer) {
|
||||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -144,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
|
||||||
/* Should be called after an authenticated data packet is sent. */
|
/* Should be called after an authenticated data packet is sent. */
|
||||||
func (peer *Peer) timersDataSent() {
|
func (peer *Peer) timersDataSent() {
|
||||||
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
||||||
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
|
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
|
||||||
if !peer.timers.sendKeepalive.IsPending() {
|
if !peer.timers.sendKeepalive.IsPending() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
} else {
|
} else {
|
||||||
peer.timers.needAnotherKeepalive.Set(true)
|
peer.timers.needAnotherKeepalive.Store(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -176,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
|
||||||
/* Should be called after a handshake initiation message is sent. */
|
/* Should be called after a handshake initiation message is sent. */
|
||||||
func (peer *Peer) timersHandshakeInitiated() {
|
func (peer *Peer) timersHandshakeInitiated() {
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
|
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.retransmitHandshake.Del()
|
peer.timers.retransmitHandshake.Del()
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
|
peer.lastHandshakeNano.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
||||||
|
@ -199,7 +192,7 @@ func (peer *Peer) timersSessionDerived() {
|
||||||
|
|
||||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||||
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
|
keepalive := peer.persistentKeepaliveInterval.Load()
|
||||||
if keepalive > 0 && peer.timersActive() {
|
if keepalive > 0 && peer.timersActive() {
|
||||||
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||||
}
|
}
|
||||||
|
@ -214,9 +207,9 @@ func (peer *Peer) timersInit() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersStart() {
|
func (peer *Peer) timersStart() {
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersStop() {
|
func (peer *Peer) timersStop() {
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1420
|
const DefaultMTU = 1420
|
||||||
|
@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
|
||||||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
||||||
mtu = MaxContentSize
|
mtu = MaxContentSize
|
||||||
}
|
}
|
||||||
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
|
old := device.tun.mtu.Swap(int32(mtu))
|
||||||
if int(old) != mtu {
|
if int(old) != mtu {
|
||||||
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
||||||
}
|
}
|
||||||
|
|
216
device/uapi.go
216
device/uapi.go
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
@ -12,13 +12,13 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
|
@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
|
||||||
return s.code
|
return s.code
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
|
func ipcErrorf(code int64, msg string, args ...any) *IPCError {
|
||||||
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
var byteBufferPool = &sync.Pool{
|
var byteBufferPool = &sync.Pool{
|
||||||
New: func() interface{} { return new(bytes.Buffer) },
|
New: func() any { return new(bytes.Buffer) },
|
||||||
}
|
}
|
||||||
|
|
||||||
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
|
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
|
||||||
|
@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
buf := byteBufferPool.Get().(*bytes.Buffer)
|
buf := byteBufferPool.Get().(*bytes.Buffer)
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
defer byteBufferPool.Put(buf)
|
defer byteBufferPool.Put(buf)
|
||||||
sendf := func(format string, args ...interface{}) {
|
sendf := func(format string, args ...any) {
|
||||||
fmt.Fprintf(buf, format, args...)
|
fmt.Fprintf(buf, format, args...)
|
||||||
buf.WriteByte('\n')
|
buf.WriteByte('\n')
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
|
||||||
// lock required resources
|
// lock required resources
|
||||||
|
|
||||||
device.net.RLock()
|
device.net.RLock()
|
||||||
|
@ -98,31 +97,61 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
sendf("fwmark=%d", device.net.fwmark)
|
sendf("fwmark=%d", device.net.fwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
// serialize each peer state
|
if device.isAdvancedSecurityOn() {
|
||||||
|
if device.aSecCfg.junkPacketCount != 0 {
|
||||||
|
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.junkPacketMinSize != 0 {
|
||||||
|
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.junkPacketMaxSize != 0 {
|
||||||
|
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.initPacketJunkSize != 0 {
|
||||||
|
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.responsePacketJunkSize != 0 {
|
||||||
|
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.initPacketMagicHeader != 0 {
|
||||||
|
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
||||||
|
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||||
|
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
||||||
|
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.RLock()
|
// Serialize peer state.
|
||||||
defer peer.RUnlock()
|
peer.handshake.mutex.RLock()
|
||||||
|
|
||||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||||
|
peer.handshake.mutex.RUnlock()
|
||||||
sendf("protocol_version=1")
|
sendf("protocol_version=1")
|
||||||
if peer.endpoint != nil {
|
peer.endpoint.Lock()
|
||||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
if peer.endpoint.val != nil {
|
||||||
|
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||||
}
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
nano := peer.lastHandshakeNano.Load()
|
||||||
secs := nano / time.Second.Nanoseconds()
|
secs := nano / time.Second.Nanoseconds()
|
||||||
nano %= time.Second.Nanoseconds()
|
nano %= time.Second.Nanoseconds()
|
||||||
|
|
||||||
sendf("last_handshake_time_sec=%d", secs)
|
sendf("last_handshake_time_sec=%d", secs)
|
||||||
sendf("last_handshake_time_nsec=%d", nano)
|
sendf("last_handshake_time_nsec=%d", nano)
|
||||||
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
|
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
|
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||||
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -151,19 +180,27 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
peer := new(ipcSetPeer)
|
peer := new(ipcSetPeer)
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
|
tempASecCfg := aSecCfgType{}
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if line == "" {
|
if line == "" {
|
||||||
// Blank line means terminate operation.
|
// Blank line means terminate operation.
|
||||||
|
err := device.handlePostConfig(&tempASecCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peer.handlePostConfig()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
parts := strings.Split(line, "=")
|
key, value, ok := strings.Cut(line, "=")
|
||||||
if len(parts) != 2 {
|
if !ok {
|
||||||
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorProtocol,
|
||||||
|
"failed to parse line %q",
|
||||||
|
line,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
key := parts[0]
|
|
||||||
value := parts[1]
|
|
||||||
|
|
||||||
if key == "public_key" {
|
if key == "public_key" {
|
||||||
if deviceConfig {
|
if deviceConfig {
|
||||||
|
@ -180,7 +217,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if deviceConfig {
|
if deviceConfig {
|
||||||
err = device.handleDeviceLine(key, value)
|
err = device.handleDeviceLine(key, value, &tempASecCfg)
|
||||||
} else {
|
} else {
|
||||||
err = device.handlePeerLine(peer, key, value)
|
err = device.handlePeerLine(peer, key, value)
|
||||||
}
|
}
|
||||||
|
@ -188,6 +225,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
err = device.handlePostConfig(&tempASecCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
peer.handlePostConfig()
|
peer.handlePostConfig()
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
|
@ -196,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) handleDeviceLine(key, value string) error {
|
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
|
||||||
switch key {
|
switch key {
|
||||||
case "private_key":
|
case "private_key":
|
||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
|
@ -242,6 +283,83 @@ func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
device.log.Verbosef("UAPI: Removing all peers")
|
device.log.Verbosef("UAPI: Removing all peers")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
|
case "jc":
|
||||||
|
junkPacketCount, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||||
|
tempASecCfg.junkPacketCount = junkPacketCount
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "jmin":
|
||||||
|
junkPacketMinSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||||
|
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "jmax":
|
||||||
|
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||||
|
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "s1":
|
||||||
|
initPacketJunkSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||||
|
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "s2":
|
||||||
|
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||||
|
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "h1":
|
||||||
|
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
||||||
|
}
|
||||||
|
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "h2":
|
||||||
|
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
||||||
|
}
|
||||||
|
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "h3":
|
||||||
|
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
||||||
|
}
|
||||||
|
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
|
case "h4":
|
||||||
|
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
||||||
|
}
|
||||||
|
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||||
}
|
}
|
||||||
|
@ -254,15 +372,29 @@ type ipcSetPeer struct {
|
||||||
*Peer // Peer is the current peer being operated on
|
*Peer // Peer is the current peer being operated on
|
||||||
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
|
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
|
||||||
created bool // new reports whether this is a newly created peer
|
created bool // new reports whether this is a newly created peer
|
||||||
|
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *ipcSetPeer) handlePostConfig() {
|
func (peer *ipcSetPeer) handlePostConfig() {
|
||||||
if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() {
|
if peer.Peer == nil || peer.dummy {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if peer.created {
|
||||||
|
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||||
|
}
|
||||||
|
if peer.device.isUp() {
|
||||||
|
peer.Start()
|
||||||
|
if peer.pkaOn {
|
||||||
|
peer.SendKeepalive()
|
||||||
|
}
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
|
func (device *Device) handlePublicKeyLine(
|
||||||
|
peer *ipcSetPeer,
|
||||||
|
value string,
|
||||||
|
) error {
|
||||||
// Load/create the peer we are configuring.
|
// Load/create the peer we are configuring.
|
||||||
var publicKey NoisePublicKey
|
var publicKey NoisePublicKey
|
||||||
err := publicKey.FromHex(value)
|
err := publicKey.FromHex(value)
|
||||||
|
@ -292,7 +424,10 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
|
func (device *Device) handlePeerLine(
|
||||||
|
peer *ipcSetPeer,
|
||||||
|
key, value string,
|
||||||
|
) error {
|
||||||
switch key {
|
switch key {
|
||||||
case "update_only":
|
case "update_only":
|
||||||
// allow disabling of creation
|
// allow disabling of creation
|
||||||
|
@ -334,9 +469,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.endpoint.Lock()
|
||||||
defer peer.Unlock()
|
defer peer.endpoint.Unlock()
|
||||||
peer.endpoint = endpoint
|
peer.endpoint.val = endpoint
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||||
|
@ -346,17 +481,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||||
|
|
||||||
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
||||||
if old == 0 && secs != 0 {
|
peer.pkaOn = old == 0 && secs != 0
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
|
|
||||||
}
|
|
||||||
if device.isUp() && !peer.dummy {
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "replace_allowed_ips":
|
case "replace_allowed_ips":
|
||||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||||
|
@ -370,16 +498,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||||
|
prefix, err := netip.ParsePrefix(value)
|
||||||
_, network, err := net.ParseCIDR(value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||||
}
|
}
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
ones, _ := network.Mask.Size()
|
device.allowedips.Insert(prefix, peer.Peer)
|
||||||
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
|
|
||||||
|
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
|
|
51
format_test.go
Normal file
51
format_test.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"go/format"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatting(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to walk %s: %v", path, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if d.IsDir() || filepath.Ext(path) != ".go" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(path string) {
|
||||||
|
defer wg.Done()
|
||||||
|
src, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to read %s: %v", path, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
|
||||||
|
}
|
||||||
|
formatted, err := format.Source(src)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to format %s: %v", path, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(src, formatted) {
|
||||||
|
t.Errorf("unformatted code: %s", path)
|
||||||
|
}
|
||||||
|
}(path)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
}
|
18
go.mod
18
go.mod
|
@ -1,9 +1,17 @@
|
||||||
module golang.zx2c4.com/wireguard
|
module github.com/amnezia-vpn/amneziawg-go
|
||||||
|
|
||||||
go 1.16
|
go 1.24
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
github.com/tevino/abool/v2 v2.1.0
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
|
golang.org/x/crypto v0.36.0
|
||||||
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169
|
golang.org/x/net v0.37.0
|
||||||
|
golang.org/x/sys v0.31.0
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/btree v1.1.3 // indirect
|
||||||
|
golang.org/x/time v0.9.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
36
go.sum
36
go.sum
|
@ -1,16 +1,20 @@
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
|
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||||
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 h1:fpeMGRM6A+XFcw4RPCO8s8hH7ppgrGR22pSIjwM7YUI=
|
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||||
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
// +build windows
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
//go:build windows
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
package namedpipe
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
@ -22,8 +21,10 @@ import (
|
||||||
|
|
||||||
type timeoutChan chan struct{}
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
var ioInitOnce sync.Once
|
var (
|
||||||
var ioCompletionPort windows.Handle
|
ioInitOnce sync.Once
|
||||||
|
ioCompletionPort windows.Handle
|
||||||
|
)
|
||||||
|
|
||||||
// ioResult contains the result of an asynchronous IO operation
|
// ioResult contains the result of an asynchronous IO operation
|
||||||
type ioResult struct {
|
type ioResult struct {
|
||||||
|
@ -52,7 +53,7 @@ type file struct {
|
||||||
handle windows.Handle
|
handle windows.Handle
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
wgLock sync.RWMutex
|
wgLock sync.RWMutex
|
||||||
closing uint32 // used as atomic boolean
|
closing atomic.Bool
|
||||||
socket bool
|
socket bool
|
||||||
readDeadline deadlineHandler
|
readDeadline deadlineHandler
|
||||||
writeDeadline deadlineHandler
|
writeDeadline deadlineHandler
|
||||||
|
@ -63,7 +64,7 @@ type deadlineHandler struct {
|
||||||
channel timeoutChan
|
channel timeoutChan
|
||||||
channelLock sync.RWMutex
|
channelLock sync.RWMutex
|
||||||
timer *time.Timer
|
timer *time.Timer
|
||||||
timedout uint32 // used as atomic boolean
|
timedout atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeFile makes a new file from an existing file handle
|
// makeFile makes a new file from an existing file handle
|
||||||
|
@ -87,7 +88,7 @@ func makeFile(h windows.Handle) (*file, error) {
|
||||||
func (f *file) closeHandle() {
|
func (f *file) closeHandle() {
|
||||||
f.wgLock.Lock()
|
f.wgLock.Lock()
|
||||||
// Atomically set that we are closing, releasing the resources only once.
|
// Atomically set that we are closing, releasing the resources only once.
|
||||||
if atomic.SwapUint32(&f.closing, 1) == 0 {
|
if f.closing.Swap(true) == false {
|
||||||
f.wgLock.Unlock()
|
f.wgLock.Unlock()
|
||||||
// cancel all IO and wait for it to complete
|
// cancel all IO and wait for it to complete
|
||||||
windows.CancelIoEx(f.handle, nil)
|
windows.CancelIoEx(f.handle, nil)
|
||||||
|
@ -110,7 +111,7 @@ func (f *file) Close() error {
|
||||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||||
func (f *file) prepareIo() (*ioOperation, error) {
|
func (f *file) prepareIo() (*ioOperation, error) {
|
||||||
f.wgLock.RLock()
|
f.wgLock.RLock()
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
f.wgLock.RUnlock()
|
f.wgLock.RUnlock()
|
||||||
return nil, os.ErrClosed
|
return nil, os.ErrClosed
|
||||||
}
|
}
|
||||||
|
@ -142,7 +143,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
||||||
return int(bytes), err
|
return int(bytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
windows.CancelIoEx(f.handle, &c.o)
|
windows.CancelIoEx(f.handle, &c.o)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +159,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
||||||
case r = <-c.ch:
|
case r = <-c.ch:
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == windows.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
err = os.ErrClosed
|
err = os.ErrClosed
|
||||||
}
|
}
|
||||||
} else if err != nil && f.socket {
|
} else if err != nil && f.socket {
|
||||||
|
@ -190,7 +191,7 @@ func (f *file) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
defer f.wg.Done()
|
defer f.wg.Done()
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
|
if f.readDeadline.timedout.Load() {
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,7 +218,7 @@ func (f *file) Write(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
defer f.wg.Done()
|
defer f.wg.Done()
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
|
if f.writeDeadline.timedout.Load() {
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,7 +255,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
||||||
}
|
}
|
||||||
d.timer = nil
|
d.timer = nil
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&d.timedout, 0)
|
d.timedout.Store(false)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-d.channel:
|
case <-d.channel:
|
||||||
|
@ -269,7 +270,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
timeoutIO := func() {
|
timeoutIO := func() {
|
||||||
atomic.StoreUint32(&d.timedout, 1)
|
d.timedout.Store(true)
|
||||||
close(d.channel)
|
close(d.channel)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
// +build windows
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
//go:build windows
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Package winpipe implements a net.Conn and net.Listener around Windows named pipes.
|
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
||||||
package winpipe
|
package namedpipe
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -15,6 +14,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ type pipe struct {
|
||||||
|
|
||||||
type messageBytePipe struct {
|
type messageBytePipe struct {
|
||||||
pipe
|
pipe
|
||||||
writeClosed bool
|
writeClosed atomic.Bool
|
||||||
readEOF bool
|
readEOF bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,25 +50,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
|
||||||
|
|
||||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
func (f *messageBytePipe) CloseWrite() error {
|
func (f *messageBytePipe) CloseWrite() error {
|
||||||
if f.writeClosed {
|
if !f.writeClosed.CompareAndSwap(false, true) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
err := f.file.Flush()
|
err := f.file.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
f.writeClosed.Store(false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = f.file.Write(nil)
|
_, err = f.file.Write(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
f.writeClosed.Store(false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f.writeClosed = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
// they are used to implement CloseWrite.
|
// they are used to implement CloseWrite.
|
||||||
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
||||||
if f.writeClosed {
|
if f.writeClosed.Load() {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
|
@ -142,30 +143,24 @@ type DialConfig struct {
|
||||||
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the specified named pipe by path, timing out if the connection
|
// DialTimeout connects to the specified named pipe by path, timing out if the
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
// connection takes longer than the specified duration. If timeout is zero, then
|
||||||
// a default timeout of 2 seconds.
|
// we use a default timeout of 2 seconds.
|
||||||
func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) {
|
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
var absTimeout time.Time
|
if timeout == 0 {
|
||||||
if timeout != nil {
|
timeout = time.Second * 2
|
||||||
absTimeout = time.Now().Add(*timeout)
|
|
||||||
} else {
|
|
||||||
absTimeout = time.Now().Add(2 * time.Second)
|
|
||||||
}
|
}
|
||||||
|
absTimeout := time.Now().Add(timeout)
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
conn, err := DialContext(ctx, path, config)
|
conn, err := config.DialContext(ctx, path)
|
||||||
if err == context.DeadlineExceeded {
|
if err == context.DeadlineExceeded {
|
||||||
return nil, os.ErrDeadlineExceeded
|
return nil, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext attempts to connect to the specified named pipe by path
|
// DialContext attempts to connect to the specified named pipe by path.
|
||||||
// cancellation or timeout.
|
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
|
|
||||||
if config == nil {
|
|
||||||
config = &DialConfig{}
|
|
||||||
}
|
|
||||||
var err error
|
var err error
|
||||||
var h windows.Handle
|
var h windows.Handle
|
||||||
h, err = tryDialPipe(ctx, &path)
|
h, err = tryDialPipe(ctx, &path)
|
||||||
|
@ -213,6 +208,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
|
||||||
return &pipe{file: f, path: path}, nil
|
return &pipe{file: f, path: path}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultDialer DialConfig
|
||||||
|
|
||||||
|
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
|
||||||
|
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialTimeout(path, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext calls DialConfig.DialContext using an empty configuration.
|
||||||
|
func DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialContext(ctx, path)
|
||||||
|
}
|
||||||
|
|
||||||
type acceptResponse struct {
|
type acceptResponse struct {
|
||||||
f *file
|
f *file
|
||||||
err error
|
err error
|
||||||
|
@ -222,12 +229,12 @@ type pipeListener struct {
|
||||||
firstHandle windows.Handle
|
firstHandle windows.Handle
|
||||||
path string
|
path string
|
||||||
config ListenConfig
|
config ListenConfig
|
||||||
acceptCh chan (chan acceptResponse)
|
acceptCh chan chan acceptResponse
|
||||||
closeCh chan int
|
closeCh chan int
|
||||||
doneCh chan int
|
doneCh chan int
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) {
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
@ -247,7 +254,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
||||||
oa.ObjectName = &ntPath
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
// The security descriptor is only needed for the first pipe.
|
||||||
if first {
|
if isFirstPipe {
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
oa.SecurityDescriptor = sd
|
oa.SecurityDescriptor = sd
|
||||||
} else {
|
} else {
|
||||||
|
@ -257,7 +264,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
||||||
sd, err := windows.NewSecurityDescriptor()
|
sd, err = windows.NewSecurityDescriptor()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -275,11 +282,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
||||||
|
|
||||||
disposition := uint32(windows.FILE_OPEN)
|
disposition := uint32(windows.FILE_OPEN)
|
||||||
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
if first {
|
if isFirstPipe {
|
||||||
disposition = windows.FILE_CREATE
|
disposition = windows.FILE_CREATE
|
||||||
// By not asking for read or write access, the named pipe file system
|
// By not asking for read or write access, the named pipe file system
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
// client connections until the next call with first == false.
|
// client connections until the next call with isFirstPipe == false.
|
||||||
access = windows.SYNCHRONIZE
|
access = windows.SYNCHRONIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,10 +402,7 @@ type ListenConfig struct {
|
||||||
|
|
||||||
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
||||||
// The pipe must not already exist.
|
// The pipe must not already exist.
|
||||||
func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
|
||||||
if c == nil {
|
|
||||||
c = &ListenConfig{}
|
|
||||||
}
|
|
||||||
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -407,12 +411,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
||||||
firstHandle: h,
|
firstHandle: h,
|
||||||
path: path,
|
path: path,
|
||||||
config: *c,
|
config: *c,
|
||||||
acceptCh: make(chan (chan acceptResponse)),
|
acceptCh: make(chan chan acceptResponse),
|
||||||
closeCh: make(chan int),
|
closeCh: make(chan int),
|
||||||
doneCh: make(chan int),
|
doneCh: make(chan int),
|
||||||
}
|
}
|
||||||
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
||||||
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
|
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
||||||
|
@ -425,6 +429,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultListener ListenConfig
|
||||||
|
|
||||||
|
// Listen calls ListenConfig.Listen using an empty configuration.
|
||||||
|
func Listen(path string) (net.Listener, error) {
|
||||||
|
return defaultListener.Listen(path)
|
||||||
|
}
|
||||||
|
|
||||||
func connectPipe(p *file) error {
|
func connectPipe(p *file) error {
|
||||||
c, err := p.prepareIo()
|
c, err := p.prepareIo()
|
||||||
if err != nil {
|
if err != nil {
|
|
@ -1,12 +1,11 @@
|
||||||
// +build windows
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
//go:build windows
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe_test
|
package namedpipe_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -21,8 +20,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func randomPipePath() string {
|
func randomPipePath() string {
|
||||||
|
@ -30,7 +29,7 @@ func randomPipePath() string {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return `\\.\PIPE\go-winpipe-test-` + guid.String()
|
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPingPong(t *testing.T) {
|
func TestPingPong(t *testing.T) {
|
||||||
|
@ -39,7 +38,7 @@ func TestPingPong(t *testing.T) {
|
||||||
pong = 24
|
pong = 24
|
||||||
)
|
)
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
listener, err := winpipe.Listen(pipePath, nil)
|
listener, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to listen on pipe: %v", err)
|
t.Fatalf("unable to listen on pipe: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -64,11 +63,12 @@ func TestPingPong(t *testing.T) {
|
||||||
t.Fatalf("unable to write pong to pipe: %v", err)
|
t.Fatalf("unable to write pong to pipe: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to dial pipe: %v", err)
|
t.Fatalf("unable to dial pipe: %v", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
client.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
var data [1]byte
|
var data [1]byte
|
||||||
data[0] = ping
|
data[0] = ping
|
||||||
_, err = client.Write(data[:])
|
_, err = client.Write(data[:])
|
||||||
|
@ -85,7 +85,7 @@ func TestPingPong(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialUnknownFailsImmediately(t *testing.T) {
|
func TestDialUnknownFailsImmediately(t *testing.T) {
|
||||||
_, err := winpipe.Dial(randomPipePath(), nil, nil)
|
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
|
||||||
if !errors.Is(err, syscall.ENOENT) {
|
if !errors.Is(err, syscall.ENOENT) {
|
||||||
t.Fatalf("expected ENOENT got %v", err)
|
t.Fatalf("expected ENOENT got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -93,13 +93,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
|
||||||
|
|
||||||
func TestDialListenerTimesOut(t *testing.T) {
|
func TestDialListenerTimesOut(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
d := 10 * time.Millisecond
|
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
|
||||||
_, err = winpipe.Dial(pipePath, &d, nil)
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != os.ErrDeadlineExceeded {
|
if err != os.ErrDeadlineExceeded {
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -107,14 +109,17 @@ func TestDialListenerTimesOut(t *testing.T) {
|
||||||
|
|
||||||
func TestDialContextListenerTimesOut(t *testing.T) {
|
func TestDialContextListenerTimesOut(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
d := 10 * time.Millisecond
|
d := 10 * time.Millisecond
|
||||||
ctx, _ := context.WithTimeout(context.Background(), d)
|
ctx, _ := context.WithTimeout(context.Background(), d)
|
||||||
_, err = winpipe.DialContext(ctx, pipePath, nil)
|
pipe, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != context.DeadlineExceeded {
|
if err != context.DeadlineExceeded {
|
||||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -123,14 +128,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
|
||||||
func TestDialListenerGetsCancelled(t *testing.T) {
|
func TestDialListenerGetsCancelled(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ch := make(chan error)
|
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
ch := make(chan error)
|
||||||
go func(ctx context.Context, ch chan error) {
|
go func(ctx context.Context, ch chan error) {
|
||||||
_, err := winpipe.DialContext(ctx, pipePath, nil)
|
_, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
ch <- err
|
ch <- err
|
||||||
}(ctx, ch)
|
}(ctx, ch)
|
||||||
time.Sleep(time.Millisecond * 30)
|
time.Sleep(time.Millisecond * 30)
|
||||||
|
@ -147,23 +152,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
|
||||||
}
|
}
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
sd, _ := windows.SecurityDescriptorFromString("D:")
|
sd, _ := windows.SecurityDescriptorFromString("D:")
|
||||||
c := winpipe.ListenConfig{
|
l, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: sd,
|
SecurityDescriptor: sd,
|
||||||
}
|
}).Listen(pipePath)
|
||||||
l, err := winpipe.Listen(pipePath, &c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
_, err = winpipe.Dial(pipePath, nil, nil)
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
|
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, cfg)
|
if cfg == nil {
|
||||||
|
cfg = &namedpipe.ListenConfig{}
|
||||||
|
}
|
||||||
|
l, err := cfg.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -179,7 +189,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
|
||||||
ch <- response{c, err}
|
ch <- response{c, err}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -236,7 +246,7 @@ func server(l net.Listener, ch chan int) {
|
||||||
|
|
||||||
func TestFullListenDialReadWrite(t *testing.T) {
|
func TestFullListenDialReadWrite(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -245,7 +255,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
|
||||||
ch := make(chan int)
|
ch := make(chan int)
|
||||||
go server(l, ch)
|
go server(l, ch)
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -275,7 +285,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
|
||||||
|
|
||||||
func TestCloseAbortsListen(t *testing.T) {
|
func TestCloseAbortsListen(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -328,7 +338,7 @@ func TestCloseServerEOFClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseWriteEOF(t *testing.T) {
|
func TestCloseWriteEOF(t *testing.T) {
|
||||||
cfg := &winpipe.ListenConfig{
|
cfg := &namedpipe.ListenConfig{
|
||||||
MessageMode: true,
|
MessageMode: true,
|
||||||
}
|
}
|
||||||
c, s, err := getConnection(cfg)
|
c, s, err := getConnection(cfg)
|
||||||
|
@ -356,7 +366,7 @@ func TestCloseWriteEOF(t *testing.T) {
|
||||||
|
|
||||||
func TestAcceptAfterCloseFails(t *testing.T) {
|
func TestAcceptAfterCloseFails(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -369,12 +379,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
|
||||||
|
|
||||||
func TestDialTimesOutByDefault(t *testing.T) {
|
func TestDialTimesOutByDefault(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
_, err = winpipe.Dial(pipePath, nil, nil)
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != os.ErrDeadlineExceeded {
|
if err != os.ErrDeadlineExceeded {
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -382,7 +395,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
|
||||||
|
|
||||||
func TestTimeoutPendingRead(t *testing.T) {
|
func TestTimeoutPendingRead(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -400,7 +413,7 @@ func TestTimeoutPendingRead(t *testing.T) {
|
||||||
close(serverDone)
|
close(serverDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -430,7 +443,7 @@ func TestTimeoutPendingRead(t *testing.T) {
|
||||||
|
|
||||||
func TestTimeoutPendingWrite(t *testing.T) {
|
func TestTimeoutPendingWrite(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -448,7 +461,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
|
||||||
close(serverDone)
|
close(serverDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -480,13 +493,12 @@ type CloseWriter interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEchoWithMessaging(t *testing.T) {
|
func TestEchoWithMessaging(t *testing.T) {
|
||||||
c := winpipe.ListenConfig{
|
pipePath := randomPipePath()
|
||||||
|
l, err := (&namedpipe.ListenConfig{
|
||||||
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
||||||
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
||||||
OutputBufferSize: 65536,
|
OutputBufferSize: 65536,
|
||||||
}
|
}).Listen(pipePath)
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := winpipe.Listen(pipePath, &c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) {
|
||||||
clientDone := make(chan bool)
|
clientDone := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
// server echo
|
// server echo
|
||||||
conn, e := l.Accept()
|
conn, err := l.Accept()
|
||||||
if e != nil {
|
if err != nil {
|
||||||
t.Fatal(e)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
||||||
io.Copy(conn, conn)
|
_, err = io.Copy(conn, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
conn.(CloseWriter).CloseWrite()
|
conn.(CloseWriter).CloseWrite()
|
||||||
close(listenerDone)
|
close(listenerDone)
|
||||||
}()
|
}()
|
||||||
timeout := 1 * time.Second
|
client, err := namedpipe.DialTimeout(pipePath, time.Second)
|
||||||
client, err := winpipe.Dial(pipePath, &timeout, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) {
|
||||||
if e != nil {
|
if e != nil {
|
||||||
t.Fatal(e)
|
t.Fatal(e)
|
||||||
}
|
}
|
||||||
if n != 2 {
|
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
|
||||||
t.Fatalf("expected 2 bytes, got %v", n)
|
t.Fatalf("expected 2 bytes, got %v", n)
|
||||||
}
|
}
|
||||||
close(clientDone)
|
close(clientDone)
|
||||||
|
@ -545,7 +559,7 @@ func TestEchoWithMessaging(t *testing.T) {
|
||||||
|
|
||||||
func TestConnectRace(t *testing.T) {
|
func TestConnectRace(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -580,7 +594,7 @@ func TestMessageReadMode(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer wg.Wait()
|
defer wg.Wait()
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true})
|
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) {
|
||||||
s.Close()
|
s.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -643,13 +657,13 @@ func TestListenConnectRace(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
s, err := winpipe.Listen(pipePath, nil)
|
s, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(i, err)
|
t.Error(i, err)
|
||||||
} else {
|
} else {
|
|
@ -1,8 +1,8 @@
|
||||||
// +build darwin freebsd openbsd
|
//go:build darwin || freebsd || openbsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
@ -54,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// wrap file in listener
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.FileListener(file)
|
listener, err := net.FileListener(file)
|
||||||
|
@ -104,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if kerr != nil || n != 1 {
|
if (kerr != nil || n != 1) && kerr != unix.EINTR {
|
||||||
if kerr != nil {
|
if kerr != nil {
|
||||||
l.connErr <- kerr
|
l.connErr <- kerr
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
@ -9,8 +9,8 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
|
@ -51,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// wrap file in listener
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.FileListener(file)
|
listener, err := net.FileListener(file)
|
||||||
|
@ -97,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
go func(l *UAPIListener) {
|
go func(l *UAPIListener) {
|
||||||
var buff [0]byte
|
var buf [0]byte
|
||||||
for {
|
for {
|
||||||
defer uapi.inotifyRWCancel.Close()
|
defer uapi.inotifyRWCancel.Close()
|
||||||
// start with lstat to avoid race condition
|
// start with lstat to avoid race condition
|
||||||
|
@ -105,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err := uapi.inotifyRWCancel.Read(buff[:])
|
_, err := uapi.inotifyRWCancel.Read(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
// +build linux darwin freebsd openbsd
|
//go:build linux || darwin || freebsd || openbsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
@ -26,14 +26,14 @@ const (
|
||||||
|
|
||||||
// socketDirectory is variable because it is modified by a linker
|
// socketDirectory is variable because it is modified by a linker
|
||||||
// flag in wireguard-android.
|
// flag in wireguard-android.
|
||||||
var socketDirectory = "/var/run/wireguard"
|
var socketDirectory = "/var/run/amneziawg"
|
||||||
|
|
||||||
func sockPath(iface string) string {
|
func sockPath(iface string) string {
|
||||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
if err := os.MkdirAll(socketDirectory, 0755); err != nil {
|
if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ func UAPIOpen(name string) (*os.File, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUmask := unix.Umask(0077)
|
oldUmask := unix.Umask(0o077)
|
||||||
defer unix.Umask(oldUmask)
|
defer unix.Umask(oldUmask)
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
|
15
ipc/uapi_wasm.go
Normal file
15
ipc/uapi_wasm.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ipc
|
||||||
|
|
||||||
|
// Made up sentinel error codes for {js,wasip1}/wasm.
|
||||||
|
const (
|
||||||
|
IpcErrorIO = 1
|
||||||
|
IpcErrorInvalid = 2
|
||||||
|
IpcErrorPortInUse = 3
|
||||||
|
IpcErrorUnknown = 4
|
||||||
|
IpcErrorProtocol = 5
|
||||||
|
)
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
@ -8,9 +8,8 @@ package ipc
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: replace these with actual standard windows error numbers from the win package
|
// TODO: replace these with actual standard windows error numbers from the win package
|
||||||
|
@ -54,18 +53,16 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var err error
|
var err error
|
||||||
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
|
||||||
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winpipe.ListenConfig{
|
listener, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}
|
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
|
||||||
listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
49
main.go
49
main.go
|
@ -1,8 +1,8 @@
|
||||||
// +build !windows
|
//go:build !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -13,12 +13,12 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -46,20 +46,20 @@ func warning() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
|
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
|
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
|
||||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
|
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
|
||||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||||
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
||||||
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
|
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
|
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
||||||
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
|
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ func main() {
|
||||||
|
|
||||||
// open TUN device (or use supplied fd)
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
tun, err := func() (tun.Device, error) {
|
tdev, err := func() (tun.Device, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
if tunFdStr == "" {
|
if tunFdStr == "" {
|
||||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||||
|
@ -124,7 +124,7 @@ func main() {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = syscall.SetNonblock(int(fd), true)
|
err = unix.SetNonblock(int(fd), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ func main() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tdev.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
interfaceName = realInterfaceName
|
interfaceName = realInterfaceName
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ func main() {
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to create TUN device: %v", err)
|
logger.Errorf("Failed to create TUN device: %v", err)
|
||||||
|
@ -169,7 +169,6 @@ func main() {
|
||||||
|
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("UAPI listen error: %v", err)
|
logger.Errorf("UAPI listen error: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
|
@ -197,7 +196,7 @@ func main() {
|
||||||
files[0], // stdin
|
files[0], // stdin
|
||||||
files[1], // stdout
|
files[1], // stdout
|
||||||
files[2], // stderr
|
files[2], // stderr
|
||||||
tun.File(),
|
tdev.File(),
|
||||||
fileUAPI,
|
fileUAPI,
|
||||||
},
|
},
|
||||||
Dir: ".",
|
Dir: ".",
|
||||||
|
@ -223,7 +222,7 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||||
|
|
||||||
logger.Verbosef("Device started")
|
logger.Verbosef("Device started")
|
||||||
|
|
||||||
|
@ -251,7 +250,7 @@ func main() {
|
||||||
|
|
||||||
// wait for program to terminate
|
// wait for program to terminate
|
||||||
|
|
||||||
signal.Notify(term, syscall.SIGTERM)
|
signal.Notify(term, unix.SIGTERM)
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -9,13 +9,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -29,13 +30,13 @@ func main() {
|
||||||
}
|
}
|
||||||
interfaceName := os.Args[1]
|
interfaceName := os.Args[1]
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
|
||||||
|
|
||||||
logger := device.NewLogger(
|
logger := device.NewLogger(
|
||||||
device.LogLevelVerbose,
|
device.LogLevelVerbose,
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName, 0)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -81,7 +82,7 @@ func main() {
|
||||||
|
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
signal.Notify(term, os.Kill)
|
signal.Notify(term, os.Kill)
|
||||||
signal.Notify(term, syscall.SIGTERM)
|
signal.Notify(term, windows.SIGTERM)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-term:
|
case <-term:
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -30,8 +30,7 @@ type Ratelimiter struct {
|
||||||
timeNow func() time.Time
|
timeNow func() time.Time
|
||||||
|
|
||||||
stopReset chan struct{} // send to reset, close to stop
|
stopReset chan struct{} // send to reset, close to stop
|
||||||
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
table map[netip.Addr]*RatelimiterEntry
|
||||||
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Close() {
|
func (rate *Ratelimiter) Close() {
|
||||||
|
@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
rate.stopReset = make(chan struct{})
|
rate.stopReset = make(chan struct{})
|
||||||
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
rate.table = make(map[netip.Addr]*RatelimiterEntry)
|
||||||
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
|
||||||
|
|
||||||
stopReset := rate.stopReset // store in case Init is called again.
|
stopReset := rate.stopReset // store in case Init is called again.
|
||||||
|
|
||||||
|
@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
|
||||||
rate.mu.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.mu.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv4 {
|
for key, entry := range rate.table {
|
||||||
entry.mu.Lock()
|
entry.mu.Lock()
|
||||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
delete(rate.tableIPv4, key)
|
delete(rate.table, key)
|
||||||
}
|
}
|
||||||
entry.mu.Unlock()
|
entry.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv6 {
|
return len(rate.table) == 0
|
||||||
entry.mu.Lock()
|
|
||||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv6, key)
|
|
||||||
}
|
|
||||||
entry.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
|
||||||
var entry *RatelimiterEntry
|
var entry *RatelimiterEntry
|
||||||
var keyIPv4 [net.IPv4len]byte
|
|
||||||
var keyIPv6 [net.IPv6len]byte
|
|
||||||
|
|
||||||
// lookup entry
|
// lookup entry
|
||||||
|
|
||||||
IPv4 := ip.To4()
|
|
||||||
IPv6 := ip.To16()
|
|
||||||
|
|
||||||
rate.mu.RLock()
|
rate.mu.RLock()
|
||||||
|
entry = rate.table[ip]
|
||||||
if IPv4 != nil {
|
|
||||||
copy(keyIPv4[:], IPv4)
|
|
||||||
entry = rate.tableIPv4[keyIPv4]
|
|
||||||
} else {
|
|
||||||
copy(keyIPv6[:], IPv6)
|
|
||||||
entry = rate.tableIPv6[keyIPv6]
|
|
||||||
}
|
|
||||||
|
|
||||||
rate.mu.RUnlock()
|
rate.mu.RUnlock()
|
||||||
|
|
||||||
// make new entry if not found
|
// make new entry if not found
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = new(RatelimiterEntry)
|
entry = new(RatelimiterEntry)
|
||||||
entry.tokens = maxTokens - packetCost
|
entry.tokens = maxTokens - packetCost
|
||||||
entry.lastTime = rate.timeNow()
|
entry.lastTime = rate.timeNow()
|
||||||
rate.mu.Lock()
|
rate.mu.Lock()
|
||||||
if IPv4 != nil {
|
rate.table[ip] = entry
|
||||||
rate.tableIPv4[keyIPv4] = entry
|
if len(rate.table) == 1 {
|
||||||
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
rate.stopReset <- struct{}{}
|
||||||
rate.stopReset <- struct{}{}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rate.tableIPv6[keyIPv6] = entry
|
|
||||||
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
|
|
||||||
rate.stopReset <- struct{}{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
rate.mu.Unlock()
|
rate.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tokens to entry
|
// add tokens to entry
|
||||||
|
|
||||||
entry.mu.Lock()
|
entry.mu.Lock()
|
||||||
now := rate.timeNow()
|
now := rate.timeNow()
|
||||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||||
|
@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// subtract cost of packet
|
// subtract cost of packet
|
||||||
|
|
||||||
if entry.tokens > packetCost {
|
if entry.tokens > packetCost {
|
||||||
entry.tokens -= packetCost
|
entry.tokens -= packetCost
|
||||||
entry.mu.Unlock()
|
entry.mu.Unlock()
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -71,21 +71,21 @@ func TestRatelimiter(t *testing.T) {
|
||||||
text: "packet following 2 packet burst",
|
text: "packet following 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
ips := []net.IP{
|
ips := []netip.Addr{
|
||||||
net.ParseIP("127.0.0.1"),
|
netip.MustParseAddr("127.0.0.1"),
|
||||||
net.ParseIP("192.168.1.1"),
|
netip.MustParseAddr("192.168.1.1"),
|
||||||
net.ParseIP("172.167.2.3"),
|
netip.MustParseAddr("172.167.2.3"),
|
||||||
net.ParseIP("97.231.252.215"),
|
netip.MustParseAddr("97.231.252.215"),
|
||||||
net.ParseIP("248.97.91.167"),
|
netip.MustParseAddr("248.97.91.167"),
|
||||||
net.ParseIP("188.208.233.47"),
|
netip.MustParseAddr("188.208.233.47"),
|
||||||
net.ParseIP("104.2.183.179"),
|
netip.MustParseAddr("104.2.183.179"),
|
||||||
net.ParseIP("72.129.46.120"),
|
netip.MustParseAddr("72.129.46.120"),
|
||||||
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
||||||
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
||||||
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
||||||
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
||||||
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
||||||
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||||
|
@ -34,7 +34,7 @@ func (f *Filter) Reset() {
|
||||||
|
|
||||||
// ValidateCounter checks if the counter should be accepted.
|
// ValidateCounter checks if the counter should be accepted.
|
||||||
// Overlimit counters (>= limit) are always rejected.
|
// Overlimit counters (>= limit) are always rejected.
|
||||||
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool {
|
func (f *Filter) ValidateCounter(counter, limit uint64) bool {
|
||||||
if counter >= limit {
|
if counter >= limit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package replay
|
package replay
|
||||||
|
|
|
@ -1,24 +0,0 @@
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package rwcancel
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
type fdSet struct {
|
|
||||||
unix.FdSet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fdset *fdSet) set(i int) {
|
|
||||||
bits := 32 << (^uint(0) >> 63)
|
|
||||||
fdset.Bits[i/bits] |= 1 << uint(i%bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fdset *fdSet) check(i int) bool {
|
|
||||||
bits := 32 << (^uint(0) >> 63)
|
|
||||||
return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
|
|
||||||
}
|
|
|
@ -1,8 +1,8 @@
|
||||||
// +build !windows
|
//go:build !windows && !wasm
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package rwcancel implements cancelable read/write operations on
|
// Package rwcancel implements cancelable read/write operations on
|
||||||
|
@ -17,13 +17,6 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func max(a, b int) int {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
type RWCancel struct {
|
type RWCancel struct {
|
||||||
fd int
|
fd int
|
||||||
closingReader *os.File
|
closingReader *os.File
|
||||||
|
@ -50,13 +43,12 @@ func RetryAfterError(err error) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) ReadyRead() bool {
|
func (rw *RWCancel) ReadyRead() bool {
|
||||||
closeFd := int(rw.closingReader.Fd())
|
closeFd := int32(rw.closingReader.Fd())
|
||||||
fdset := fdSet{}
|
|
||||||
fdset.set(rw.fd)
|
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
|
||||||
fdset.set(closeFd)
|
|
||||||
var err error
|
var err error
|
||||||
for {
|
for {
|
||||||
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
_, err = unix.Poll(pollFds, -1)
|
||||||
if err == nil || !RetryAfterError(err) {
|
if err == nil || !RetryAfterError(err) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -64,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if fdset.check(closeFd) {
|
if pollFds[1].Revents != 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return fdset.check(rw.fd)
|
return pollFds[0].Revents != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) ReadyWrite() bool {
|
func (rw *RWCancel) ReadyWrite() bool {
|
||||||
closeFd := int(rw.closingReader.Fd())
|
closeFd := int32(rw.closingReader.Fd())
|
||||||
fdset := fdSet{}
|
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
|
||||||
fdset.set(rw.fd)
|
|
||||||
fdset.set(closeFd)
|
|
||||||
var err error
|
var err error
|
||||||
for {
|
for {
|
||||||
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
_, err = unix.Poll(pollFds, -1)
|
||||||
if err == nil || !RetryAfterError(err) {
|
if err == nil || !RetryAfterError(err) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -85,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if fdset.check(closeFd) {
|
|
||||||
|
if pollFds[1].Revents != 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return fdset.check(rw.fd)
|
return pollFds[0].Revents != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
||||||
|
@ -98,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
if !rw.ReadyRead() {
|
if !rw.ReadyRead() {
|
||||||
return 0, errors.New("fd closed")
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
if !rw.ReadyWrite() {
|
if !rw.ReadyWrite() {
|
||||||
return 0, errors.New("fd closed")
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
//go:build windows || wasm
|
||||||
|
|
||||||
// SPDX-License-Identifier: MIT
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
package rwcancel
|
package rwcancel
|
||||||
|
|
||||||
type RWCancel struct {
|
type RWCancel struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (*RWCancel) Cancel() {}
|
func (*RWCancel) Cancel() {}
|
|
@ -1,15 +0,0 @@
|
||||||
// +build !linux,!windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package rwcancel
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
|
|
||||||
_, err := unix.Select(nfd, r, w, e, timeout)
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package rwcancel
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) {
|
|
||||||
_, err = unix.Select(nfd, r, w, e, timeout)
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tai64n
|
package tai64n
|
||||||
|
@ -11,9 +11,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const TimestampSize = 12
|
const (
|
||||||
const base = uint64(0x400000000000000a)
|
TimestampSize = 12
|
||||||
const whitenerMask = uint32(0x1000000 - 1)
|
base = uint64(0x400000000000000a)
|
||||||
|
whitenerMask = uint32(0x1000000 - 1)
|
||||||
|
)
|
||||||
|
|
||||||
type Timestamp [TimestampSize]byte
|
type Timestamp [TimestampSize]byte
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tai64n
|
package tai64n
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package tun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -18,15 +18,15 @@ func checkAlignment(t *testing.T, name string, offset uintptr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerAlignment checks that atomically-accessed fields are
|
// TestRateJugglerAlignment checks that atomically-accessed fields are
|
||||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||||
//
|
//
|
||||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||||
// hard segfault at runtime.
|
// hard segfault at runtime.
|
||||||
func TestPeerAlignment(t *testing.T) {
|
func TestRateJugglerAlignment(t *testing.T) {
|
||||||
var p Peer
|
var r rateJuggler
|
||||||
|
|
||||||
typ := reflect.TypeOf(&p).Elem()
|
typ := reflect.TypeOf(&r).Elem()
|
||||||
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
@ -38,20 +38,21 @@ func TestPeerAlignment(t *testing.T) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
|
checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
|
||||||
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
|
checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
|
||||||
|
checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDeviceAlignment checks that atomically-accessed fields are
|
// TestNativeTunAlignment checks that atomically-accessed fields are
|
||||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||||
//
|
//
|
||||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||||
// hard segfault at runtime.
|
// hard segfault at runtime.
|
||||||
func TestDeviceAlignment(t *testing.T) {
|
func TestNativeTunAlignment(t *testing.T) {
|
||||||
var d Device
|
var tun NativeTun
|
||||||
|
|
||||||
typ := reflect.TypeOf(&d).Elem()
|
typ := reflect.TypeOf(&tun).Elem()
|
||||||
t.Logf("Device type size: %d, with fields:", typ.Size())
|
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
||||||
|
@ -61,5 +62,6 @@ func TestDeviceAlignment(t *testing.T) {
|
||||||
field.Type.Align(),
|
field.Type.Align(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
|
|
||||||
|
checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
|
||||||
}
|
}
|
118
tun/checksum.go
Normal file
118
tun/checksum.go
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||||
|
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
|
||||||
|
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||||
|
ac := initial
|
||||||
|
|
||||||
|
for len(b) >= 128 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
|
||||||
|
b = b[128:]
|
||||||
|
}
|
||||||
|
if len(b) >= 64 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
if len(b) >= 32 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b))
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint16(b))
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) == 1 {
|
||||||
|
ac += uint64(b[0]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
return ac
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksum(b []byte, initial uint64) uint16 {
|
||||||
|
ac := checksumNoFold(b, initial)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
ac = (ac >> 16) + (ac & 0xffff)
|
||||||
|
return uint16(ac)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||||
|
sum := checksumNoFold(srcAddr, 0)
|
||||||
|
sum = checksumNoFold(dstAddr, sum)
|
||||||
|
sum = checksumNoFold([]byte{0, protocol}, sum)
|
||||||
|
tmp := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||||
|
return checksumNoFold(tmp, sum)
|
||||||
|
}
|
35
tun/checksum_test.go
Normal file
35
tun/checksum_test.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkChecksum(b *testing.B) {
|
||||||
|
lengths := []int{
|
||||||
|
64,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1500,
|
||||||
|
2048,
|
||||||
|
4096,
|
||||||
|
8192,
|
||||||
|
9000,
|
||||||
|
9001,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, length := range lengths {
|
||||||
|
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
|
||||||
|
buf := make([]byte, length)
|
||||||
|
rng := rand.New(rand.NewSource(1))
|
||||||
|
rng.Read(buf)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
12
tun/errors.go
Normal file
12
tun/errors.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrTooManySegments is returned by Device.Read() when segmentation
|
||||||
|
// overflows the length of supplied buffers. This error should not cause
|
||||||
|
// reads to cease.
|
||||||
|
ErrTooManySegments = errors.New("too many segments")
|
||||||
|
)
|
|
@ -1,8 +1,8 @@
|
||||||
// +build ignore
|
//go:build ignore
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -10,27 +10,27 @@ package main
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
tun, tnet, err := netstack.CreateNetTUN(
|
tun, tnet, err := netstack.CreateNetTUN(
|
||||||
[]net.IP{net.ParseIP("192.168.4.29")},
|
[]netip.Addr{netip.MustParseAddr("192.168.4.28")},
|
||||||
[]net.IP{net.ParseIP("8.8.8.8")},
|
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
|
||||||
1420)
|
1420)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
}
|
}
|
||||||
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
||||||
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
|
err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
|
||||||
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
|
public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
|
||||||
endpoint=163.172.161.0:12912
|
|
||||||
allowed_ip=0.0.0.0/0
|
allowed_ip=0.0.0.0/0
|
||||||
|
endpoint=127.0.0.1:58120
|
||||||
`)
|
`)
|
||||||
err = dev.Up()
|
err = dev.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -42,7 +42,7 @@ allowed_ip=0.0.0.0/0
|
||||||
DialContext: tnet.DialContext,
|
DialContext: tnet.DialContext,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
resp, err := client.Get("https://www.zx2c4.com/ip")
|
resp, err := client.Get("http://192.168.4.29/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
// +build ignore
|
//go:build ignore
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -12,26 +12,27 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
tun, tnet, err := netstack.CreateNetTUN(
|
tun, tnet, err := netstack.CreateNetTUN(
|
||||||
[]net.IP{net.ParseIP("192.168.4.29")},
|
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
|
||||||
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
|
[]netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
|
||||||
1420,
|
1420,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
}
|
}
|
||||||
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
||||||
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
|
dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
|
||||||
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
|
listen_port=58120
|
||||||
endpoint=163.172.161.0:12912
|
public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
|
||||||
allowed_ip=0.0.0.0/0
|
allowed_ip=192.168.4.28/32
|
||||||
persistent_keepalive_interval=25
|
persistent_keepalive_interval=25
|
||||||
`)
|
`)
|
||||||
dev.Up()
|
dev.Up()
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue