mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-09-24 10:25:03 +02:00
Compare commits
No commits in common. "master" and "0.0.20220316" have entirely different histories.
master
...
0.0.202203
127 changed files with 3096 additions and 8557 deletions
41
.github/workflows/build-if-tag.yml
vendored
41
.github/workflows/build-if-tag.yml
vendored
|
@ -1,41 +0,0 @@
|
|||
name: build-if-tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||
|
||||
env:
|
||||
APP: amneziawg-go
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
name: build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Setup metadata
|
||||
uses: docker/metadata-action@v5
|
||||
id: metadata
|
||||
with:
|
||||
images: amneziavpn/${{ env.APP }}
|
||||
tags: type=semver,pattern={{version}}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
push: true
|
||||
tags: ${{ steps.metadata.outputs.tags }}
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1 @@
|
|||
amneziawg-go
|
||||
wireguard-go
|
||||
|
|
18
Dockerfile
18
Dockerfile
|
@ -1,18 +0,0 @@
|
|||
FROM golang:1.24.4 as awg
|
||||
COPY . /awg
|
||||
WORKDIR /awg
|
||||
RUN go mod download && \
|
||||
go mod verify && \
|
||||
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG AWGTOOLS_RELEASE="1.0.20250901"
|
||||
|
||||
RUN apk --no-cache add iproute2 iptables bash && \
|
||||
cd /usr/bin/ && \
|
||||
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
||||
unzip -j alpine-3.19-amneziawg-tools.zip && \
|
||||
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
|
||||
ln -s /usr/bin/awg /usr/bin/wg && \
|
||||
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
|
||||
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go
|
12
Makefile
12
Makefile
|
@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
|
|||
|
||||
generate-version-and-build:
|
||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||
echo "$$ver" > version.go && \
|
||||
git update-index --assume-unchanged version.go || true
|
||||
@$(MAKE) amneziawg-go
|
||||
@$(MAKE) wireguard-go
|
||||
|
||||
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
go build -v -o "$@"
|
||||
|
||||
install: amneziawg-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
||||
install: wireguard-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
clean:
|
||||
rm -f amneziawg-go
|
||||
rm -f wireguard-go
|
||||
|
||||
.PHONY: all clean test install generate-version-and-build
|
||||
|
|
60
README.md
60
README.md
|
@ -1,27 +1,24 @@
|
|||
# Go Implementation of AmneziaWG
|
||||
# Go Implementation of [WireGuard](https://www.wireguard.com/)
|
||||
|
||||
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
|
||||
|
||||
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
|
||||
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
|
||||
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
|
||||
This is an implementation of WireGuard in Go.
|
||||
|
||||
## Usage
|
||||
|
||||
Simply run:
|
||||
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
|
||||
|
||||
```
|
||||
$ amneziawg-go wg0
|
||||
$ wireguard-go wg0
|
||||
```
|
||||
|
||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
|
||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
|
||||
|
||||
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
|
||||
To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
|
||||
|
||||
```
|
||||
$ amneziawg-go -f wg0
|
||||
$ wireguard-go -f wg0
|
||||
```
|
||||
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||
|
||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||
|
||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||
|
||||
|
@ -29,25 +26,52 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
|||
|
||||
### Linux
|
||||
|
||||
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
|
||||
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
||||
|
||||
### macOS
|
||||
|
||||
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
||||
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
|
||||
|
||||
### Windows
|
||||
|
||||
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
|
||||
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
|
||||
|
||||
### FreeBSD
|
||||
|
||||
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
|
||||
|
||||
### OpenBSD
|
||||
|
||||
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
||||
|
||||
## Building
|
||||
|
||||
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||
This requires an installation of [go](https://golang.org) ≥ 1.18.
|
||||
|
||||
```
|
||||
$ git clone https://github.com/amnezia-vpn/amneziawg-go
|
||||
$ cd amneziawg-go
|
||||
$ git clone https://git.zx2c4.com/wireguard-go
|
||||
$ cd wireguard-go
|
||||
$ make
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Software, and to permit persons to whom the Software is furnished to do
|
||||
so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
562
conn/bind_linux.go
Normal file
562
conn/bind_linux.go
Normal file
|
@ -0,0 +1,562 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type ipv4Source struct {
|
||||
Src [4]byte
|
||||
Ifindex int32
|
||||
}
|
||||
|
||||
type ipv6Source struct {
|
||||
src [16]byte
|
||||
// ifindex belongs in dst.ZoneId
|
||||
}
|
||||
|
||||
type LinuxSocketEndpoint struct {
|
||||
mu sync.Mutex
|
||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||
src [unsafe.Sizeof(ipv6Source{})]byte
|
||||
isV6 bool
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
|
||||
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
||||
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
|
||||
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
|
||||
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
|
||||
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
|
||||
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
||||
}
|
||||
|
||||
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
|
||||
type LinuxSocketBind struct {
|
||||
// mu guards sock4 and sock6 and the associated fds.
|
||||
// As long as someone holds mu (read or write), the associated fds are valid.
|
||||
mu sync.RWMutex
|
||||
sock4 int
|
||||
sock6 int
|
||||
}
|
||||
|
||||
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
|
||||
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
|
||||
|
||||
var (
|
||||
_ Endpoint = (*LinuxSocketEndpoint)(nil)
|
||||
_ Bind = (*LinuxSocketBind)(nil)
|
||||
)
|
||||
|
||||
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
var end LinuxSocketEndpoint
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.Addr().Is4() {
|
||||
dst := end.dst4()
|
||||
end.isV6 = false
|
||||
dst.Port = int(e.Port())
|
||||
dst.Addr = e.Addr().As4()
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
if e.Addr().Is6() {
|
||||
zone, err := zoneToUint32(e.Addr().Zone())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dst := end.dst6()
|
||||
end.isV6 = true
|
||||
dst.Port = int(e.Port())
|
||||
dst.ZoneId = zone
|
||||
dst.Addr = e.Addr().As16()
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var newPort uint16
|
||||
var tries int
|
||||
|
||||
if bind.sock4 != -1 || bind.sock6 != -1 {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
originalPort := port
|
||||
|
||||
again:
|
||||
port = originalPort
|
||||
var sock4, sock6 int
|
||||
// Attempt ipv6 bind, update port if successful.
|
||||
sock6, newPort, err = create6(port)
|
||||
if err != nil {
|
||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
port = newPort
|
||||
}
|
||||
|
||||
// Attempt ipv4 bind, update port if successful.
|
||||
sock4, newPort, err = create4(port)
|
||||
if err != nil {
|
||||
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
unix.Close(sock6)
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
unix.Close(sock6)
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
port = newPort
|
||||
}
|
||||
|
||||
var fns []ReceiveFunc
|
||||
if sock4 != -1 {
|
||||
bind.sock4 = sock4
|
||||
fns = append(fns, bind.receiveIPv4)
|
||||
}
|
||||
if sock6 != -1 {
|
||||
bind.sock6 = sock6
|
||||
fns = append(fns, bind.receiveIPv6)
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
return fns, port, nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
|
||||
if bind.sock6 != -1 {
|
||||
err := unix.SetsockoptInt(
|
||||
bind.sock6,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if bind.sock4 != -1 {
|
||||
err := unix.SetsockoptInt(
|
||||
bind.sock4,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Close() error {
|
||||
// Take a readlock to shut down the sockets...
|
||||
bind.mu.RLock()
|
||||
if bind.sock6 != -1 {
|
||||
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
||||
}
|
||||
if bind.sock4 != -1 {
|
||||
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
||||
}
|
||||
bind.mu.RUnlock()
|
||||
// ...and a write lock to close the fd.
|
||||
// This ensures that no one else is using the fd.
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
var err1, err2 error
|
||||
if bind.sock6 != -1 {
|
||||
err1 = unix.Close(bind.sock6)
|
||||
bind.sock6 = -1
|
||||
}
|
||||
if bind.sock4 != -1 {
|
||||
err2 = unix.Close(bind.sock4)
|
||||
bind.sock4 = -1
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock4 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive4(bind.sock4, buf, &end)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock6 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive6(bind.sock6, buf, &end)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
||||
nend, ok := end.(*LinuxSocketEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if !nend.isV6 {
|
||||
if bind.sock4 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send4(bind.sock4, nend, buff)
|
||||
} else {
|
||||
if bind.sock6 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send6(bind.sock6, nend, buff)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
|
||||
if !end.isV6 {
|
||||
return netip.AddrFrom4(end.src4().Src)
|
||||
} else {
|
||||
return netip.AddrFrom16(end.src6().src)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
|
||||
if !end.isV6 {
|
||||
return netip.AddrFrom4(end.dst4().Addr)
|
||||
} else {
|
||||
return netip.AddrFrom16(end.dst6().Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
|
||||
if !end.isV6 {
|
||||
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
||||
} else {
|
||||
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) SrcToString() string {
|
||||
return end.SrcIP().String()
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstToString() string {
|
||||
var port int
|
||||
if !end.isV6 {
|
||||
port = end.dst4().Port
|
||||
} else {
|
||||
port = end.dst6().Port
|
||||
}
|
||||
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) ClearDst() {
|
||||
for i := range end.dst {
|
||||
end.dst[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) ClearSrc() {
|
||||
for i := range end.src {
|
||||
end.src[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func zoneToUint32(zone string) (uint32, error) {
|
||||
if zone == "" {
|
||||
return 0, nil
|
||||
}
|
||||
if intr, err := net.InterfaceByName(zone); err == nil {
|
||||
return uint32(intr.Index), nil
|
||||
}
|
||||
n, err := strconv.ParseUint(zone, 10, 32)
|
||||
return uint32(n), err
|
||||
}
|
||||
|
||||
func create4(port uint16) (int, uint16, error) {
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
addr := unix.SockaddrInet4{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
if err := func() error {
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IP,
|
||||
unix.IP_PKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
sa, err := unix.Getsockname(fd)
|
||||
if err == nil {
|
||||
addr.Port = sa.(*unix.SockaddrInet4).Port
|
||||
}
|
||||
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func create6(port uint16) (int, uint16, error) {
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET6,
|
||||
unix.SOCK_DGRAM,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
addr := unix.SockaddrInet6{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
if err := func() error {
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_RECVPKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_V6ONLY,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
sa, err := unix.Getsockname(fd)
|
||||
if err == nil {
|
||||
addr.Port = sa.(*unix.SockaddrInet6).Port
|
||||
}
|
||||
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
||||
// construct message header
|
||||
|
||||
cmsg := struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
}{
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IP,
|
||||
Type: unix.IP_PKTINFO,
|
||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet4Pktinfo{
|
||||
Spec_dst: end.src4().Src,
|
||||
Ifindex: end.src4().Ifindex,
|
||||
},
|
||||
}
|
||||
|
||||
end.mu.Lock()
|
||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||
end.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear src and retry
|
||||
|
||||
if err == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||
end.mu.Lock()
|
||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||
end.mu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
||||
// construct message header
|
||||
|
||||
cmsg := struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo
|
||||
}{
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IPV6,
|
||||
Type: unix.IPV6_PKTINFO,
|
||||
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet6Pktinfo{
|
||||
Addr: end.src6().src,
|
||||
Ifindex: end.dst6().ZoneId,
|
||||
},
|
||||
}
|
||||
|
||||
if cmsg.pktinfo.Addr == [16]byte{} {
|
||||
cmsg.pktinfo.Ifindex = 0
|
||||
}
|
||||
|
||||
end.mu.Lock()
|
||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||
end.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear src and retry
|
||||
|
||||
if err == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||
end.mu.Lock()
|
||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||
end.mu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
||||
// construct message header
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
}
|
||||
|
||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
end.isV6 = false
|
||||
|
||||
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
||||
*end.dst4() = *newDst4
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||
end.src4().Src = cmsg.pktinfo.Spec_dst
|
||||
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
||||
// construct message header
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo
|
||||
}
|
||||
|
||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
end.isV6 = true
|
||||
|
||||
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
||||
*end.dst6() = *newDst6
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||
end.src6().src = cmsg.pktinfo.Addr
|
||||
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
514
conn/bind_std.go
514
conn/bind_std.go
|
@ -1,126 +1,80 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
)
|
||||
|
||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||
// StdNetBind is meant to be a temporary solution on platforms for which
|
||||
// the sticky socket / source caching behavior has not yet been implemented.
|
||||
// It uses the Go's net package to implement networking.
|
||||
// See LinuxSocketBind for a proper implementation on the Linux platform.
|
||||
type StdNetBind struct {
|
||||
mu sync.Mutex // protects all fields except as specified
|
||||
mu sync.Mutex // protects following fields
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||
ipv4TxOffload bool
|
||||
ipv4RxOffload bool
|
||||
ipv6TxOffload bool
|
||||
ipv6RxOffload bool
|
||||
|
||||
// these two fields are not guarded by mu
|
||||
udpAddrPool sync.Pool
|
||||
msgsPool sync.Pool
|
||||
|
||||
blackhole4 bool
|
||||
blackhole6 bool
|
||||
}
|
||||
|
||||
func NewStdNetBind() Bind {
|
||||
return &StdNetBind{
|
||||
udpAddrPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
},
|
||||
},
|
||||
func NewStdNetBind() Bind { return &StdNetBind{} }
|
||||
|
||||
msgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||
// both aliases for x/net/internal/socket.Message.
|
||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StdNetEndpoint struct {
|
||||
// AddrPort is the endpoint destination.
|
||||
netip.AddrPort
|
||||
// src is the current sticky source address and interface index, if
|
||||
// supported. Typically this is a PKTINFO structure from/for control
|
||||
// messages, see unix.PKTINFO for an example.
|
||||
src []byte
|
||||
}
|
||||
type StdNetEndpoint net.UDPAddr
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
_ Endpoint = &StdNetEndpoint{}
|
||||
_ Endpoint = (*StdNetEndpoint)(nil)
|
||||
)
|
||||
|
||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StdNetEndpoint{
|
||||
AddrPort: e,
|
||||
}, nil
|
||||
return (*StdNetEndpoint)(&net.UDPAddr{
|
||||
IP: e.Addr().AsSlice(),
|
||||
Port: int(e.Port()),
|
||||
Zone: e.Addr().Zone(),
|
||||
}), err
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) ClearSrc() {
|
||||
if e.src != nil {
|
||||
// Truncate src, no need to reallocate.
|
||||
e.src = e.src[:0]
|
||||
}
|
||||
}
|
||||
func (*StdNetEndpoint) ClearSrc() {}
|
||||
|
||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
|
||||
return a
|
||||
}
|
||||
|
||||
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{} // not supported
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
addr := (*net.UDPAddr)(e)
|
||||
out := addr.IP.To4()
|
||||
if out == nil {
|
||||
out = addr.IP
|
||||
}
|
||||
out = append(out, byte(addr.Port&0xff))
|
||||
out = append(out, byte((addr.Port>>8)&0xff))
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
return (*net.UDPAddr)(e).String()
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
@ -134,17 +88,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||
return conn, uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var tries int
|
||||
|
||||
if s.ipv4 != nil || s.ipv6 != nil {
|
||||
if bind.ipv4 != nil || bind.ipv6 != nil {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
|
@ -152,207 +106,92 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|||
// If uport is 0, we can retry on failure.
|
||||
again:
|
||||
port := int(uport)
|
||||
var v4conn, v6conn *net.UDPConn
|
||||
var v4pc *ipv4.PacketConn
|
||||
var v6pc *ipv6.PacketConn
|
||||
var ipv4, ipv6 *net.UDPConn
|
||||
|
||||
v4conn, port, err = listenNet("udp4", port)
|
||||
ipv4, port, err = listenNet("udp4", port)
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Listen on the same port as we're using for ipv4.
|
||||
v6conn, port, err = listenNet("udp6", port)
|
||||
ipv6, port, err = listenNet("udp6", port)
|
||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
v4conn.Close()
|
||||
ipv4.Close()
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
v4conn.Close()
|
||||
ipv4.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
var fns []ReceiveFunc
|
||||
if v4conn != nil {
|
||||
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
v4pc = ipv4.NewPacketConn(v4conn)
|
||||
s.ipv4PC = v4pc
|
||||
if ipv4 != nil {
|
||||
fns = append(fns, bind.makeReceiveIPv4(ipv4))
|
||||
bind.ipv4 = ipv4
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||
s.ipv4 = v4conn
|
||||
}
|
||||
if v6conn != nil {
|
||||
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
v6pc = ipv6.NewPacketConn(v6conn)
|
||||
s.ipv6PC = v6pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||
s.ipv6 = v6conn
|
||||
if ipv6 != nil {
|
||||
fns = append(fns, bind.makeReceiveIPv6(ipv6))
|
||||
bind.ipv6 = ipv6
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
s.msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
var (
|
||||
// If compilation fails here these are no longer the same underlying type.
|
||||
_ ipv6.Message = ipv4.Message{}
|
||||
)
|
||||
|
||||
type batchReader interface {
|
||||
ReadBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
type batchWriter interface {
|
||||
WriteBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
func (s *StdNetBind) receiveIP(
|
||||
br batchReader,
|
||||
conn *net.UDPConn,
|
||||
rxOffload bool,
|
||||
bufs [][]byte,
|
||||
sizes []int,
|
||||
eps []Endpoint,
|
||||
) (n int, err error) {
|
||||
msgs := s.getMessages()
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer s.putMessages(msgs)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (s *StdNetBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
return IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
func (bind *StdNetBind) Close() error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if s.ipv4 != nil {
|
||||
err1 = s.ipv4.Close()
|
||||
s.ipv4 = nil
|
||||
s.ipv4PC = nil
|
||||
if bind.ipv4 != nil {
|
||||
err1 = bind.ipv4.Close()
|
||||
bind.ipv4 = nil
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
err2 = s.ipv6.Close()
|
||||
s.ipv6 = nil
|
||||
s.ipv6PC = nil
|
||||
if bind.ipv6 != nil {
|
||||
err2 = bind.ipv6.Close()
|
||||
bind.ipv6 = nil
|
||||
}
|
||||
s.blackhole4 = false
|
||||
s.blackhole6 = false
|
||||
s.ipv4TxOffload = false
|
||||
s.ipv4RxOffload = false
|
||||
s.ipv6TxOffload = false
|
||||
s.ipv6RxOffload = false
|
||||
bind.blackhole4 = false
|
||||
bind.blackhole6 = false
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
type ErrUDPGSODisabled struct {
|
||||
onLaddr string
|
||||
RetryErr error
|
||||
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDP(buff)
|
||||
if endpoint != nil {
|
||||
endpoint.IP = endpoint.IP.To4()
|
||||
}
|
||||
return n, (*StdNetEndpoint)(endpoint), err
|
||||
}
|
||||
}
|
||||
|
||||
func (e ErrUDPGSODisabled) Error() string {
|
||||
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
|
||||
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDP(buff)
|
||||
return n, (*StdNetEndpoint)(endpoint), err
|
||||
}
|
||||
}
|
||||
|
||||
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||
return e.RetryErr
|
||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
var err error
|
||||
nend, ok := endpoint.(*StdNetEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
s.mu.Lock()
|
||||
blackhole := s.blackhole4
|
||||
conn := s.ipv4
|
||||
offload := s.ipv4TxOffload
|
||||
br := batchWriter(s.ipv4PC)
|
||||
is6 := false
|
||||
if endpoint.DstIP().Is6() {
|
||||
blackhole = s.blackhole6
|
||||
conn = s.ipv6
|
||||
br = s.ipv6PC
|
||||
is6 = true
|
||||
offload = s.ipv6TxOffload
|
||||
bind.mu.Lock()
|
||||
blackhole := bind.blackhole4
|
||||
conn := bind.ipv4
|
||||
if nend.IP.To4() == nil {
|
||||
blackhole = bind.blackhole6
|
||||
conn = bind.ipv6
|
||||
}
|
||||
s.mu.Unlock()
|
||||
bind.mu.Unlock()
|
||||
|
||||
if blackhole {
|
||||
return nil
|
||||
|
@ -360,185 +199,6 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
msgs := s.getMessages()
|
||||
defer s.putMessages(msgs)
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
defer s.udpAddrPool.Put(ua)
|
||||
if is6 {
|
||||
as16 := endpoint.DstIP().As16()
|
||||
copy(ua.IP, as16[:])
|
||||
ua.IP = ua.IP[:16]
|
||||
} else {
|
||||
as4 := endpoint.DstIP().As4()
|
||||
copy(ua.IP, as4[:])
|
||||
ua.IP = ua.IP[:4]
|
||||
}
|
||||
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||
var (
|
||||
retried bool
|
||||
err error
|
||||
)
|
||||
retry:
|
||||
if offload {
|
||||
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||
err = s.send(conn, br, (*msgs)[:n])
|
||||
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||
offload = false
|
||||
s.mu.Lock()
|
||||
if is6 {
|
||||
s.ipv6TxOffload = false
|
||||
} else {
|
||||
s.ipv4TxOffload = false
|
||||
}
|
||||
s.mu.Unlock()
|
||||
retried = true
|
||||
goto retry
|
||||
}
|
||||
} else {
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Addr = ua
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||
}
|
||||
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||
}
|
||||
if retried {
|
||||
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||
}
|
||||
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
for {
|
||||
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||
if err != nil || n == len(msgs[start:]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for _, msg := range msgs {
|
||||
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const (
|
||||
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
||||
// layer4 headers. IPv6 does not need to account for itself as the payload
|
||||
// length field is self excluding.
|
||||
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||
|
||||
// This is a hard limit imposed by the kernel.
|
||||
udpSegmentMaxDatagrams = 64
|
||||
)
|
||||
|
||||
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||
|
||||
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||
var (
|
||||
base = -1 // index of msg we are currently coalescing into
|
||||
gsoSize int // segmentation size of msgs[base]
|
||||
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||
)
|
||||
maxPayloadLen := maxIPv4PayloadLen
|
||||
if ep.DstIP().Is6() {
|
||||
maxPayloadLen = maxIPv6PayloadLen
|
||||
}
|
||||
for i, buf := range bufs {
|
||||
if i > 0 {
|
||||
msgLen := len(buf)
|
||||
baseLenBefore := len(msgs[base].Buffers[0])
|
||||
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||
msgLen <= gsoSize &&
|
||||
msgLen <= freeBaseCap &&
|
||||
dgramCnt < udpSegmentMaxDatagrams &&
|
||||
!endBatch {
|
||||
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||
if i == len(bufs)-1 {
|
||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
dgramCnt++
|
||||
if msgLen < gsoSize {
|
||||
// A smaller than gsoSize packet on the tail is legal, but
|
||||
// it must end the batch.
|
||||
endBatch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dgramCnt > 1 {
|
||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
// Reset prior to incrementing base since we are preparing to start a
|
||||
// new potential batch.
|
||||
endBatch = false
|
||||
base++
|
||||
gsoSize = len(buf)
|
||||
setSrcControl(&msgs[base].OOB, ep)
|
||||
msgs[base].Buffers[0] = buf
|
||||
msgs[base].Addr = addr
|
||||
dgramCnt = 1
|
||||
}
|
||||
return base + 1
|
||||
}
|
||||
|
||||
type getGSOFunc func(control []byte) (int, error)
|
||||
|
||||
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||
for i := firstMsgAt; i < len(msgs); i++ {
|
||||
msg := &msgs[i]
|
||||
if msg.N == 0 {
|
||||
return n, err
|
||||
}
|
||||
var (
|
||||
gsoSize int
|
||||
start int
|
||||
end = msg.N
|
||||
numToSplit = 1
|
||||
)
|
||||
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if gsoSize > 0 {
|
||||
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||
end = gsoSize
|
||||
}
|
||||
for j := 0; j < numToSplit; j++ {
|
||||
if n > i {
|
||||
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||
}
|
||||
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||
msgs[n].N = copied
|
||||
msgs[n].Addr = msg.Addr
|
||||
start = end
|
||||
end += gsoSize
|
||||
if end > msg.N {
|
||||
end = msg.N
|
||||
}
|
||||
n++
|
||||
}
|
||||
if i != n-1 {
|
||||
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||
// of splitting, so we only zero the source msg len when it is not
|
||||
// the destination of the last split operation above.
|
||||
msg.N = 0
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
|
|
@ -1,250 +0,0 @@
|
|||
package conn
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||
bind := NewStdNetBind().(*StdNetBind)
|
||||
fns, _, err := bind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bind.Close()
|
||||
bufs := make([][]byte, 1)
|
||||
bufs[0] = make([]byte, 1)
|
||||
sizes := make([]int, 1)
|
||||
eps := make([]Endpoint, 1)
|
||||
for _, fn := range fns {
|
||||
// The ReceiveFuncs must not access conn-related fields on StdNetBind
|
||||
// unguarded. Close() nils the conn-related fields resulting in a panic
|
||||
// if they violate the mutex.
|
||||
fn(bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||
}
|
||||
|
||||
func Test_coalesceMessages(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
buffs [][]byte
|
||||
wantLens []int
|
||||
wantGSO []int
|
||||
}{
|
||||
{
|
||||
name: "one message no coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{1},
|
||||
wantGSO: []int{0},
|
||||
},
|
||||
{
|
||||
name: "two messages equal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 2),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{2},
|
||||
wantGSO: []int{1},
|
||||
},
|
||||
{
|
||||
name: "two messages unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{3},
|
||||
wantGSO: []int{2},
|
||||
},
|
||||
{
|
||||
name: "three messages second unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{3, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
{
|
||||
name: "three messages limited cap coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 4),
|
||||
make([]byte, 2, 2),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{4, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1").To4(),
|
||||
Port: 1,
|
||||
}
|
||||
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make([][]byte, 1)
|
||||
msgs[i].OOB = make([]byte, 0, 2)
|
||||
}
|
||||
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
|
||||
if got != len(tt.wantLens) {
|
||||
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||
}
|
||||
for i := 0; i < got; i++ {
|
||||
if msgs[i].Addr != addr {
|
||||
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||
}
|
||||
gotLen := len(msgs[i].Buffers[0])
|
||||
if gotLen != tt.wantLens[i] {
|
||||
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||
}
|
||||
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
|
||||
if err != nil {
|
||||
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||
}
|
||||
if gotGSO != tt.wantGSO[i] {
|
||||
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mockGetGSOSize(control []byte) (int, error) {
|
||||
if len(control) < 2 {
|
||||
return 0, nil
|
||||
}
|
||||
return int(binary.LittleEndian.Uint16(control)), nil
|
||||
}
|
||||
|
||||
func Test_splitCoalescedMessages(t *testing.T) {
|
||||
newMsg := func(n, gso int) ipv6.Message {
|
||||
msg := ipv6.Message{
|
||||
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
||||
N: n,
|
||||
OOB: make([]byte, 2),
|
||||
}
|
||||
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||
if gso > 0 {
|
||||
msg.NN = 2
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
msgs []ipv6.Message
|
||||
firstMsgAt int
|
||||
wantNumEval int
|
||||
wantMsgLens []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "second last split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(3, 1),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 3,
|
||||
wantMsgLens: []int{1, 1, 1, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 1,
|
||||
wantMsgLens: []int{1, 0, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last no split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(1, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 2,
|
||||
wantMsgLens: []int{1, 1, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(3, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(2, 1),
|
||||
newMsg(2, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split overflow",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(4, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
||||
if err != nil && !tt.wantErr {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got != tt.wantNumEval {
|
||||
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||
}
|
||||
for i, msg := range tt.msgs {
|
||||
if msg.N != tt.wantMsgLens[i] {
|
||||
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
|
||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -74,7 +74,7 @@ type afWinRingBind struct {
|
|||
type WinRingBind struct {
|
||||
v4, v6 afWinRingBind
|
||||
mu sync.RWMutex
|
||||
isOpen atomic.Uint32 // 0, 1, or 2
|
||||
isOpen uint32
|
||||
}
|
||||
|
||||
func NewDefaultBind() Bind { return NewWinRingBind() }
|
||||
|
@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
|
|||
func (e *WinRingEndpoint) DstToString() string {
|
||||
switch e.family {
|
||||
case windows.AF_INET:
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||
case windows.AF_INET6:
|
||||
var zone string
|
||||
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
||||
|
@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
|
|||
}
|
||||
|
||||
func (bind *WinRingBind) closeAndZero() {
|
||||
bind.isOpen.Store(0)
|
||||
atomic.StoreUint32(&bind.isOpen, 0)
|
||||
bind.v4.CloseAndZero()
|
||||
bind.v6.CloseAndZero()
|
||||
}
|
||||
|
@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
|||
bind.closeAndZero()
|
||||
}
|
||||
}()
|
||||
if bind.isOpen.Load() != 0 {
|
||||
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
var sa windows.Sockaddr
|
||||
|
@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
|||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
bind.isOpen.Store(1)
|
||||
atomic.StoreUint32(&bind.isOpen, 1)
|
||||
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) Close() error {
|
||||
bind.mu.RLock()
|
||||
if bind.isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
||||
bind.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
bind.isOpen.Store(2)
|
||||
atomic.StoreUint32(&bind.isOpen, 2)
|
||||
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
||||
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
||||
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
||||
|
@ -321,13 +321,6 @@ func (bind *WinRingBind) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (bind *WinRingBind) BatchSize() int {
|
||||
// TODO: implement batching in and out of the ring
|
||||
return 1
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -352,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
|
|||
//go:linkname procyield runtime.procyield
|
||||
func procyield(cycles uint32)
|
||||
|
||||
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
|
||||
if isOpen.Load() != 1 {
|
||||
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
bind.rx.mu.Lock()
|
||||
|
@ -366,7 +359,7 @@ retry:
|
|||
count = 0
|
||||
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
||||
if tries > 0 {
|
||||
if isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
procyield(1)
|
||||
|
@ -385,7 +378,7 @@ retry:
|
|||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||
|
@ -402,7 +395,7 @@ retry:
|
|||
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
||||
// attacker bandwidth, just like the rest of the receive path.
|
||||
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
||||
if isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
goto retry
|
||||
|
@ -416,26 +409,20 @@ retry:
|
|||
return n, &ep, nil
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
return bind.v4.Receive(buf, &bind.isOpen)
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
return bind.v6.Receive(buf, &bind.isOpen)
|
||||
}
|
||||
|
||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||
if isOpen.Load() != 1 {
|
||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if len(buf) > bytesPerPacket {
|
||||
|
@ -457,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(isOpen) != 1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
||||
|
@ -486,38 +473,32 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
|
|||
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
||||
nend, ok := endpoint.(*WinRingEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
for _, buf := range bufs {
|
||||
switch nend.family {
|
||||
case windows.AF_INET:
|
||||
if bind.v4.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
return bind.v4.Send(buf, nend, &bind.isOpen)
|
||||
case windows.AF_INET6:
|
||||
if bind.v6.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return bind.v6.Send(buf, nend, &bind.isOpen)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv4.SyscallConn()
|
||||
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -530,14 +511,14 @@ func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole boo
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.blackhole4 = blackhole
|
||||
bind.blackhole4 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv6.SyscallConn()
|
||||
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -550,14 +531,14 @@ func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole boo
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.blackhole6 = blackhole
|
||||
bind.blackhole6 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
||||
|
@ -571,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
|||
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.isOpen.Load() != 1 {
|
||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bindtest
|
||||
|
@ -12,7 +12,7 @@ import (
|
|||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type ChannelBind struct {
|
||||
|
@ -89,26 +89,20 @@ func (c *ChannelBind) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||
|
||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
return func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return 0, net.ErrClosed
|
||||
return 0, nil, net.ErrClosed
|
||||
case rx := <-ch:
|
||||
copied := copy(bufs[0], rx)
|
||||
sizes[0] = copied
|
||||
eps[0] = c.target6
|
||||
return 1, nil
|
||||
return copy(b, rx), c.target6, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
for _, b := range bufs {
|
||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
|
@ -123,7 +117,6 @@ func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|||
return os.ErrInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
sysconn, err := s.ipv4.SyscallConn()
|
||||
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
@ -19,8 +19,8 @@ func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
sysconn, err := s.ipv6.SyscallConn()
|
||||
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
|
26
conn/conn.go
26
conn/conn.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package conn implements WireGuard's network connections.
|
||||
|
@ -15,17 +15,10 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||
// into packets. On a successful read it returns the number of elements of
|
||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||
// and eps slice with a length greater than or equal to the length of packets.
|
||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||
// A ReceiveFunc receives a single inbound packet from the network.
|
||||
// It writes the data into b. n is the length of the packet.
|
||||
// ep is the remote endpoint.
|
||||
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
|
@ -45,16 +38,11 @@ type Bind interface {
|
|||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// Send writes one or more packets in bufs to address ep. The length of
|
||||
// bufs must not exceed BatchSize().
|
||||
Send(bufs [][]byte, ep Endpoint) error
|
||||
// Send writes a packet b to address ep.
|
||||
Send(b []byte, ep Endpoint) error
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
ParseEndpoint(s string) (Endpoint, error)
|
||||
|
||||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrettyName(t *testing.T) {
|
||||
var (
|
||||
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
|
||||
)
|
||||
|
||||
const want = "TestPrettyName"
|
||||
|
||||
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
|
||||
if got := recvFunc.PrettyName(); got != want {
|
||||
t.Errorf("PrettyName() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||
// the max supported by a default configuration of macOS. Some platforms will
|
||||
// silently clamp the value to other maximums, such as linux clamping to
|
||||
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||
// around this limitation)
|
||||
const socketBufferSize = 7 << 20
|
||||
|
||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||
// It is used to apply platform specific configuration to the socket prior to
|
||||
// bind.
|
||||
type controlFn func(network, address string, c syscall.RawConn) error
|
||||
|
||||
// controlFns is a list of functions that are called from the listen config
|
||||
// that can apply socket options.
|
||||
var controlFns = []controlFn{}
|
||||
|
||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||
// information OOB configuration for sticky sockets.
|
||||
func listenConfig() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
for _, fn := range controlFns {
|
||||
if err := fn(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
|
@ -1,109 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
|
||||
func kernelVersion() (major, minor int) {
|
||||
var uname unix.Utsname
|
||||
if err := unix.Uname(&uname); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
values [2]int
|
||||
value, vi int
|
||||
)
|
||||
for _, c := range uname.Release {
|
||||
if '0' <= c && c <= '9' {
|
||||
value = (value * 10) + int(c-'0')
|
||||
} else {
|
||||
// Note that we're assuming N.N.N here.
|
||||
// If we see anything else, we are likely to mis-parse it.
|
||||
values[vi] = value
|
||||
vi++
|
||||
if vi >= len(values) {
|
||||
break
|
||||
}
|
||||
value = 0
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], values[1]
|
||||
}
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||
// fail silently - the result of failure is lower performance on very fast
|
||||
// links or high latency links.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
// Set up to *mem_max
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||
})
|
||||
},
|
||||
|
||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
switch network {
|
||||
case "udp4":
|
||||
if runtime.GOOS != "android" {
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||
})
|
||||
}
|
||||
case "udp6":
|
||||
c.Control(func(fd uintptr) {
|
||||
if runtime.GOOS != "android" {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
default:
|
||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||
}
|
||||
return err
|
||||
},
|
||||
|
||||
// Attempt to enable UDP_GRO
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
// Kernels below 5.12 are missing 98184612aca0 ("net:
|
||||
// udp: Add support for getsockopt(..., ..., UDP_GRO,
|
||||
// ..., ...);"), which means we can't read this back
|
||||
// later. We could pipe the return value through to
|
||||
// the rest of the code, but UDP_GRO is kind of buggy
|
||||
// anyway, so just gate this here.
|
||||
major, minor := kernelVersion()
|
||||
if major < 5 || (major == 5 && minor < 12) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
})
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
//go:build !windows && !linux && !wasm
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||
})
|
||||
},
|
||||
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
if network == "udp6" {
|
||||
c.Control(func(fd uintptr) {
|
||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||
})
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
|
||||
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
//go:build !windows
|
||||
//go:build !linux && !windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func errShouldDisableUDPGSO(_ error) bool {
|
||||
return false
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func errShouldDisableUDPGSO(err error) bool {
|
||||
var serr *os.SyscallError
|
||||
if errors.As(err, &serr) {
|
||||
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||
// See:
|
||||
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
|
||||
// It occurs when the peer mtu + wg headers is greater than path mtu.
|
||||
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
return
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
rc, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
txOffload = errSyscall == nil
|
||||
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
|
||||
// use setsockopt workaround
|
||||
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
rxOffload = errSyscall == nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
return txOffload, rxOffload
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||
// offloading control data.
|
||||
const gsoControlSize = 0
|
|
@ -1,65 +0,0 @@
|
|||
//go:build linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
sizeOfGSOData = 2
|
||||
)
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||
}
|
||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||
var gso uint16
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||
return int(gso), nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||
// data in control untouched.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
existingLen := len(*control)
|
||||
avail := cap(*control) - existingLen
|
||||
space := unix.CmsgSpace(sizeOfGSOData)
|
||||
if avail < space {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:cap(*control)]
|
||||
gsoControl := (*control)[existingLen:]
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||
*control = (*control)[:existingLen+space]
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||
// offloading control data.
|
||||
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -26,13 +26,13 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||
var operr error
|
||||
if fwmarkIoctl == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.ipv4 != nil {
|
||||
fd, err := s.ipv4.SyscallConn()
|
||||
if bind.ipv4 != nil {
|
||||
fd, err := bind.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -46,8 +46,8 @@ func (s *StdNetBind) SetMark(mark uint32) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
if s.ipv6 != nil {
|
||||
fd, err := s.ipv6.SyscallConn()
|
||||
if bind.ipv6 != nil {
|
||||
fd, err := bind.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
//go:build !linux || android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||
// ports and require testing.
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||
// offloading control data.
|
||||
const stickyControlSize = 0
|
||||
|
||||
const StdNetSupportsStickySockets = false
|
|
@ -1,112 +0,0 @@
|
|||
//go:build linux && !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
switch len(e.src) {
|
||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return netip.AddrFrom4(info.Spec_dst)
|
||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
// TODO: set zone. in order to do so we need to check if the address is
|
||||
// link local, and if it is perform a syscall to turn the ifindex into a
|
||||
// zone string because netip uses string zones.
|
||||
return netip.AddrFrom16(info.Addr)
|
||||
}
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
switch len(e.src) {
|
||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return info.Ifindex
|
||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return int32(info.Ifindex)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return e.SrcIP().String()
|
||||
}
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
ep.ClearSrc()
|
||||
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem []byte = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IP &&
|
||||
hdr.Type == unix.IP_PKTINFO {
|
||||
|
||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
}
|
||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||
|
||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||
copy(ep.src, hdrBuf)
|
||||
copy(ep.src[unix.CmsgLen(0):], data)
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||
hdr.Type == unix.IPV6_PKTINFO {
|
||||
|
||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||
}
|
||||
|
||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||
|
||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||
copy(ep.src, hdrBuf)
|
||||
copy(ep.src[unix.CmsgLen(0):], data)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||
// that ep is a default value.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
if cap(*control) < len(ep.src) {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:0]
|
||||
*control = append(*control, ep.src...)
|
||||
}
|
||||
|
||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||
// offloading control data.
|
||||
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||
|
||||
const StdNetSupportsStickySockets = true
|
|
@ -1,266 +0,0 @@
|
|||
//go:build linux && !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
|
||||
var buf []byte
|
||||
if addr.Is4() {
|
||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
hdr := unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IP,
|
||||
Type: unix.IP_PKTINFO,
|
||||
}
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||
|
||||
info := unix.Inet4Pktinfo{
|
||||
Ifindex: ifidx,
|
||||
Spec_dst: addr.As4(),
|
||||
}
|
||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
|
||||
} else {
|
||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||
hdr := unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IPV6,
|
||||
Type: unix.IPV6_PKTINFO,
|
||||
}
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||
|
||||
info := unix.Inet6Pktinfo{
|
||||
Ifindex: uint32(ifidx),
|
||||
Addr: addr.As16(),
|
||||
}
|
||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
|
||||
}
|
||||
|
||||
ep.src = buf
|
||||
}
|
||||
|
||||
func Test_setSrcControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||
}
|
||||
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
||||
|
||||
control := make([]byte, stickyControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IP {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IP_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||
}
|
||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||
|
||||
control := make([]byte, stickyControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IPV6_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Addr != ep.SrcIP().As16() {
|
||||
t.Errorf("unexpected address: %v", info.Addr)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = 1
|
||||
hdr.Type = 2
|
||||
hdr.Len = 3
|
||||
|
||||
setSrcControl(&control, &StdNetEndpoint{})
|
||||
|
||||
if len(control) != 0 {
|
||||
t.Errorf("unexpected control: %v", control)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getSrcFromControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||
var control []byte
|
||||
ep := &StdNetEndpoint{}
|
||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||
|
||||
getSrcFromControl(control, ep)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 0 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("Multiple", func(t *testing.T) {
|
||||
zeroControl := make([]byte, unix.CmsgSpace(0))
|
||||
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
|
||||
zeroHdr.SetLen(unix.CmsgLen(0))
|
||||
|
||||
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
combined := make([]byte, 0)
|
||||
combined = append(combined, zeroControl...)
|
||||
combined = append(combined, control...)
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(combined, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_listenConfig(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IP_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IPV6_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winrio
|
||||
|
|
65
device/alignment_test.go
Normal file
65
device/alignment_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func checkAlignment(t *testing.T, name string, offset uintptr) {
|
||||
t.Helper()
|
||||
if offset%8 != 0 {
|
||||
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerAlignment checks that atomically-accessed fields are
|
||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||
//
|
||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||
// hard segfault at runtime.
|
||||
func TestPeerAlignment(t *testing.T) {
|
||||
var p Peer
|
||||
|
||||
typ := reflect.TypeOf(&p).Elem()
|
||||
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
||||
field.Name,
|
||||
field.Offset,
|
||||
field.Type.Size(),
|
||||
field.Type.Align(),
|
||||
)
|
||||
}
|
||||
|
||||
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
|
||||
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
|
||||
}
|
||||
|
||||
// TestDeviceAlignment checks that atomically-accessed fields are
|
||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||
//
|
||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||
// hard segfault at runtime.
|
||||
func TestDeviceAlignment(t *testing.T) {
|
||||
var d Device
|
||||
|
||||
typ := reflect.TypeOf(&d).Elem()
|
||||
t.Logf("Device type size: %d, with fields:", typ.Size())
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
||||
field.Name,
|
||||
field.Offset,
|
||||
field.Type.Size(),
|
||||
field.Type.Align(),
|
||||
)
|
||||
}
|
||||
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -223,11 +223,19 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
|
|||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) remove() {
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
node := elem.Value.(*trieEntry)
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
return
|
||||
continue
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
|
@ -240,12 +248,12 @@ func (node *trieEntry) remove() {
|
|||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
continue
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
continue
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
|
@ -255,37 +263,6 @@ func (node *trieEntry) remove() {
|
|||
node.zeroizePointers()
|
||||
parent.zeroizePointers()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
var node *trieEntry
|
||||
var exact bool
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else {
|
||||
panic(errors.New("removing unknown address type"))
|
||||
}
|
||||
if !exact || node == nil || peer != node.peer {
|
||||
return
|
||||
}
|
||||
node.remove()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
elem.Value.(*trieEntry).remove()
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
var peers []*Peer
|
||||
var allowedIPs AllowedIPs
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rand.Seed(1)
|
||||
|
||||
for n := 0; n < NumberOfPeers; n++ {
|
||||
peers = append(peers, &Peer{})
|
||||
|
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
|
|||
|
||||
for n := 0; n < NumberOfAddresses; n++ {
|
||||
var addr4 [4]byte
|
||||
rng.Read(addr4[:])
|
||||
rand.Read(addr4[:])
|
||||
cidr := uint8(rand.Intn(32) + 1)
|
||||
index := rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||
|
||||
var addr6 [16]byte
|
||||
rng.Read(addr6[:])
|
||||
rand.Read(addr6[:])
|
||||
cidr = uint8(rand.Intn(128) + 1)
|
||||
index = rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||
|
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
for p = 0; ; p++ {
|
||||
for n := 0; n < NumberOfTests; n++ {
|
||||
var addr4 [4]byte
|
||||
rng.Read(addr4[:])
|
||||
rand.Read(addr4[:])
|
||||
peer1 := slow4.Lookup(addr4[:])
|
||||
peer2 := allowedIPs.Lookup(addr4[:])
|
||||
if peer1 != peer2 {
|
||||
|
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
}
|
||||
|
||||
var addr6 [16]byte
|
||||
rng.Read(addr6[:])
|
||||
rand.Read(addr6[:])
|
||||
peer1 = slow6.Lookup(addr6[:])
|
||||
peer2 = allowedIPs.Lookup(addr6[:])
|
||||
if peer1 != peer2 {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
||||
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||
var trie *trieEntry
|
||||
var peers []*Peer
|
||||
root := parentIndirection{&trie, 2}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rand.Seed(1)
|
||||
|
||||
const AddressLength = 4
|
||||
|
||||
|
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
|||
|
||||
for n := 0; n < addressNumber; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rng.Read(addr[:])
|
||||
cidr := uint8(rng.Uint32() % (AddressLength * 8))
|
||||
index := rng.Int() % peerNumber
|
||||
rand.Read(addr[:])
|
||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||
index := rand.Int() % peerNumber
|
||||
root.insert(addr[:], cidr, peers[index])
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rng.Read(addr[:])
|
||||
rand.Read(addr[:])
|
||||
trie.lookup(addr[:])
|
||||
}
|
||||
}
|
||||
|
@ -101,10 +101,6 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||
if p != peer {
|
||||
|
@ -180,21 +176,6 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.RemoveByPeer(a)
|
||||
|
||||
assertNEQ(a, 192, 168, 0, 1)
|
||||
|
||||
insert(a, 1, 0, 0, 0, 32)
|
||||
insert(a, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 1, 0, 0, 0)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 32)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(nil, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(b, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 24)
|
||||
assertNEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 1, 0, 0, 0, 32)
|
||||
assertNEQ(a, 1, 0, 0, 0)
|
||||
}
|
||||
|
||||
/* Test ported from kernel implementation:
|
||||
|
@ -230,15 +211,6 @@ func TestTrieIPv6(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
|
@ -251,18 +223,6 @@ func TestTrieIPv6(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
p := allowedIPs.Lookup(addr)
|
||||
if p == peer {
|
||||
t.Error("Assert NEQ failed")
|
||||
}
|
||||
}
|
||||
|
||||
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||
insert(e, 0, 0, 0, 0, 0)
|
||||
|
@ -284,21 +244,4 @@ func TestTrieIPv6(t *testing.T) {
|
|||
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||
|
||||
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
}
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
type Cfg struct {
|
||||
IsSet bool
|
||||
JunkPacketCount int
|
||||
JunkPacketMinSize int
|
||||
JunkPacketMaxSize int
|
||||
InitHeaderJunkSize int
|
||||
ResponseHeaderJunkSize int
|
||||
CookieReplyHeaderJunkSize int
|
||||
TransportHeaderJunkSize int
|
||||
|
||||
MagicHeaders MagicHeaders
|
||||
}
|
||||
|
||||
type Protocol struct {
|
||||
IsOn abool.AtomicBool
|
||||
// TODO: revision the need of the mutex
|
||||
Mux sync.RWMutex
|
||||
Cfg Cfg
|
||||
JunkCreator JunkCreator
|
||||
|
||||
HandshakeHandler SpecialHandshakeHandler
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
|
||||
if junkSize == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, junkSize+extraSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("append junk: %w", err)
|
||||
}
|
||||
|
||||
return writer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
|
||||
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
|
||||
return magicHeader.Min, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no header for value: %d", msgType)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
|
||||
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package internal
|
||||
|
||||
type mockGenerator struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func NewMockGenerator(size int) mockGenerator {
|
||||
return mockGenerator{size: size}
|
||||
}
|
||||
|
||||
func (m mockGenerator) Generate() []byte {
|
||||
return make([]byte, m.size)
|
||||
}
|
||||
|
||||
func (m mockGenerator) Size() int {
|
||||
return m.size
|
||||
}
|
||||
|
||||
func (m mockGenerator) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
type mockByteGenerator struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewMockByteGenerator(data []byte) mockByteGenerator {
|
||||
return mockByteGenerator{data: data}
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Generate() []byte {
|
||||
return bg.data
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Size() int {
|
||||
return len(bg.data)
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JunkCreator struct {
|
||||
cfg Cfg
|
||||
randomGenerator PRNG[int]
|
||||
}
|
||||
|
||||
// TODO: refactor param to only pass the junk related params
|
||||
func NewJunkCreator(cfg Cfg) JunkCreator {
|
||||
return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
|
||||
if jc.cfg.JunkPacketCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for range jc.cfg.JunkPacketCount {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk := jc.randomJunkWithSize(packetSize)
|
||||
*junks = append(*junks, junk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) randomPacketSize() int {
|
||||
return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk := jc.randomJunkWithSize(size)
|
||||
_, err := writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
|
||||
return jc.randomGenerator.ReadSize(size)
|
||||
}
|
|
@ -1,97 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setUpJunkCreator() JunkCreator {
|
||||
mh, _ := NewMagicHeaders(
|
||||
[]MagicHeader{
|
||||
NewMagicHeaderSameValue(123456),
|
||||
NewMagicHeaderSameValue(67543),
|
||||
NewMagicHeaderSameValue(32345),
|
||||
NewMagicHeaderSameValue(123123),
|
||||
},
|
||||
)
|
||||
|
||||
jc := NewJunkCreator(Cfg{
|
||||
IsSet: true,
|
||||
JunkPacketCount: 5,
|
||||
JunkPacketMinSize: 500,
|
||||
JunkPacketMaxSize: 1000,
|
||||
InitHeaderJunkSize: 30,
|
||||
ResponseHeaderJunkSize: 40,
|
||||
MagicHeaders: mh,
|
||||
})
|
||||
|
||||
return jc
|
||||
}
|
||||
|
||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
got := make([][]byte, 0, jc.cfg.JunkPacketCount)
|
||||
jc.CreateJunkPackets(&got)
|
||||
seen := make(map[string]bool)
|
||||
for _, junk := range got {
|
||||
key := string(junk)
|
||||
if seen[key] {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||
got,
|
||||
junk,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
r1 := jc.randomJunkWithSize(10)
|
||||
r2 := jc.randomJunkWithSize(10)
|
||||
fmt.Printf("%v\n%v\n", r1, r2)
|
||||
if bytes.Equal(r1, r2) {
|
||||
t.Errorf("same junks")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
for range [30]struct{}{} {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
|
||||
got > jc.cfg.JunkPacketMaxSize {
|
||||
t.Errorf(
|
||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||
got,
|
||||
jc.cfg.JunkPacketMinSize,
|
||||
jc.cfg.JunkPacketMaxSize,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := jc.AppendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Error("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
|
@ -1,97 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MagicHeader struct {
|
||||
Min uint32
|
||||
Max uint32
|
||||
}
|
||||
|
||||
func NewMagicHeaderSameValue(value uint32) MagicHeader {
|
||||
return MagicHeader{Min: value, Max: value}
|
||||
}
|
||||
|
||||
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
|
||||
if min > max {
|
||||
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
|
||||
}
|
||||
|
||||
return MagicHeader{Min: min, Max: max}, nil
|
||||
}
|
||||
|
||||
func ParseMagicHeader(key, value string) (MagicHeader, error) {
|
||||
hyphenIdx := strings.Index(value, "-")
|
||||
if hyphenIdx == -1 {
|
||||
// if there is no hyphen, we treat it as single magic header value
|
||||
magicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
|
||||
}
|
||||
|
||||
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
|
||||
}
|
||||
|
||||
minStr := value[:hyphenIdx]
|
||||
maxStr := value[hyphenIdx+1:]
|
||||
if len(minStr) == 0 || len(maxStr) == 0 {
|
||||
return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
|
||||
}
|
||||
|
||||
min, err := strconv.ParseUint(minStr, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, minStr, err)
|
||||
}
|
||||
|
||||
max, err := strconv.ParseUint(maxStr, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, maxStr, err)
|
||||
}
|
||||
|
||||
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, minStr, maxStr, err)
|
||||
}
|
||||
|
||||
return magicHeader, nil
|
||||
}
|
||||
|
||||
type MagicHeaders struct {
|
||||
Values []MagicHeader
|
||||
randomGenerator RandomNumberGenerator[uint32]
|
||||
}
|
||||
|
||||
func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
|
||||
if len(headerValues) != 4 {
|
||||
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
|
||||
}
|
||||
|
||||
sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
|
||||
return cmp.Compare(lhs.Min, rhs.Min)
|
||||
})
|
||||
|
||||
for i := range 3 {
|
||||
if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
|
||||
return MagicHeaders{}, fmt.Errorf(
|
||||
"magic headers shouldn't overlap; %v > %v",
|
||||
sortedMagicHeaders[i].Max,
|
||||
sortedMagicHeaders[i+1].Min,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
|
||||
}
|
||||
|
||||
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
|
||||
if defaultMsgType == 0 || defaultMsgType > 4 {
|
||||
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
|
||||
}
|
||||
|
||||
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
|
||||
}
|
|
@ -1,488 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMagicHeaderSameValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value uint32
|
||||
expected MagicHeader
|
||||
}{
|
||||
{
|
||||
name: "zero value",
|
||||
value: 0,
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "small value",
|
||||
value: 1,
|
||||
expected: MagicHeader{Min: 1, Max: 1},
|
||||
},
|
||||
{
|
||||
name: "large value",
|
||||
value: 4294967295, // max uint32
|
||||
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "medium value",
|
||||
value: 1000,
|
||||
expected: MagicHeader{Min: 1000, Max: 1000},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := NewMagicHeaderSameValue(tt.value)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMagicHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
min uint32
|
||||
max uint32
|
||||
expected MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid range",
|
||||
min: 1,
|
||||
max: 10,
|
||||
expected: MagicHeader{Min: 1, Max: 10},
|
||||
},
|
||||
{
|
||||
name: "equal values",
|
||||
min: 5,
|
||||
max: 5,
|
||||
expected: MagicHeader{Min: 5, Max: 5},
|
||||
},
|
||||
{
|
||||
name: "zero range",
|
||||
min: 0,
|
||||
max: 0,
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "max uint32 range",
|
||||
min: 4294967294,
|
||||
max: 4294967295,
|
||||
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "min greater than max",
|
||||
min: 10,
|
||||
max: 5,
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "min (10) cannot be greater than max (5)",
|
||||
},
|
||||
{
|
||||
name: "large min greater than max",
|
||||
min: 4294967295,
|
||||
max: 1,
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "min (4294967295) cannot be greater than max (1)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := NewMagicHeader(tt.min, tt.max)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeader{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMagicHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
expected MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "single value",
|
||||
key: "header1",
|
||||
value: "100",
|
||||
expected: MagicHeader{Min: 100, Max: 100},
|
||||
},
|
||||
{
|
||||
name: "valid range",
|
||||
key: "header2",
|
||||
value: "10-20",
|
||||
expected: MagicHeader{Min: 10, Max: 20},
|
||||
},
|
||||
{
|
||||
name: "zero single value",
|
||||
key: "header3",
|
||||
value: "0",
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "zero range",
|
||||
key: "header4",
|
||||
value: "0-0",
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "max uint32 single",
|
||||
key: "header5",
|
||||
value: "4294967295",
|
||||
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "max uint32 range",
|
||||
key: "header6",
|
||||
value: "4294967294-4294967295",
|
||||
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "invalid single value - not number",
|
||||
key: "header7",
|
||||
value: "abc",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header7; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid single value - negative",
|
||||
key: "header8",
|
||||
value: "-5",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header8; value: -5;",
|
||||
},
|
||||
{
|
||||
name: "invalid single value - too large",
|
||||
key: "header9",
|
||||
value: "4294967296",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header9; value: 4294967296;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - min not number",
|
||||
key: "header10",
|
||||
value: "abc-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse min key: header10; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - max not number",
|
||||
key: "header11",
|
||||
value: "10-abc",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse max key: header11; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - min greater than max",
|
||||
key: "header12",
|
||||
value: "20-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "new magicHeader key: header12; value: 20-10;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - too many parts",
|
||||
key: "header13",
|
||||
value: "10-20-30",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header13; value: 10-20-30;",
|
||||
},
|
||||
{
|
||||
name: "empty value",
|
||||
key: "header14",
|
||||
value: "",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header14; value: ;",
|
||||
},
|
||||
{
|
||||
name: "hyphen only",
|
||||
key: "header15",
|
||||
value: "-",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header15; value: -;",
|
||||
},
|
||||
{
|
||||
name: "empty min",
|
||||
key: "header16",
|
||||
value: "-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header16; value: -10;",
|
||||
},
|
||||
{
|
||||
name: "empty max",
|
||||
key: "header17",
|
||||
value: "10-",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header17; value: 10-;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := ParseMagicHeader(tt.key, tt.value)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeader{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMagicHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
magicHeaders []MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid non-overlapping headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid adjacent headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 1},
|
||||
{Min: 2, Max: 2},
|
||||
{Min: 3, Max: 3},
|
||||
{Min: 4, Max: 4},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid zero-based headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 0, Max: 0},
|
||||
{Min: 1, Max: 1},
|
||||
{Min: 2, Max: 2},
|
||||
{Min: 3, Max: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid large value headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 4294967290, Max: 4294967291},
|
||||
{Min: 4294967292, Max: 4294967293},
|
||||
{Min: 4294967294, Max: 4294967294},
|
||||
{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "too few headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "too many headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
{Min: 41, Max: 50},
|
||||
},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
magicHeaders: []MagicHeader{},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 15},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-first",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-second",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 15, Max: 25},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-third",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 15, Max: 25},
|
||||
{Min: 30, Max: 35},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "identical ranges",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := NewMagicHeaders(tt.magicHeaders)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeaders{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.magicHeaders, result.Values)
|
||||
require.NotNil(t, result.randomGenerator)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock PRNG for testing
|
||||
type mockPRNG struct {
|
||||
returnValue uint32
|
||||
}
|
||||
|
||||
func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
|
||||
return m.returnValue
|
||||
}
|
||||
|
||||
func (m *mockPRNG) Get() uint64 {
|
||||
return 0
|
||||
}
|
||||
func (m *mockPRNG) ReadSize(size int) []byte {
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func TestMagicHeaders_Get(t *testing.T) {
|
||||
// Create test headers
|
||||
headers := []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
defaultMsgType uint32
|
||||
mockValue uint32
|
||||
expectedValue uint32
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid type 1",
|
||||
defaultMsgType: 1,
|
||||
mockValue: 5,
|
||||
expectedValue: 5,
|
||||
},
|
||||
{
|
||||
name: "valid type 2",
|
||||
defaultMsgType: 2,
|
||||
mockValue: 15,
|
||||
expectedValue: 15,
|
||||
},
|
||||
{
|
||||
name: "valid type 3",
|
||||
defaultMsgType: 3,
|
||||
mockValue: 25,
|
||||
expectedValue: 25,
|
||||
},
|
||||
{
|
||||
name: "valid type 4",
|
||||
defaultMsgType: 4,
|
||||
mockValue: 35,
|
||||
expectedValue: 35,
|
||||
},
|
||||
{
|
||||
name: "invalid type 0",
|
||||
defaultMsgType: 0,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 0",
|
||||
},
|
||||
{
|
||||
name: "invalid type 5",
|
||||
defaultMsgType: 5,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 5",
|
||||
},
|
||||
{
|
||||
name: "invalid type max uint32",
|
||||
defaultMsgType: 4294967295,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 4294967295",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a new instance with mock PRNG for each test
|
||||
testMagicHeaders := MagicHeaders{
|
||||
Values: headers,
|
||||
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
|
||||
}
|
||||
|
||||
result, err := testMagicHeaders.Get(tt.defaultMsgType)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, uint32(0), result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
v2 "math/rand/v2"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type RandomNumberGenerator[T constraints.Integer] interface {
|
||||
RandomSizeInRange(min, max T) T
|
||||
Get() uint64
|
||||
ReadSize(size int) []byte
|
||||
}
|
||||
|
||||
type PRNG[T constraints.Integer] struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
}
|
||||
|
||||
func NewPRNG[T constraints.Integer]() PRNG[T] {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = crand.Read(buf)
|
||||
|
||||
return PRNG[T]{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
}
|
||||
}
|
||||
|
||||
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
|
||||
if min > max {
|
||||
panic("min must be less than max")
|
||||
}
|
||||
|
||||
if min == max {
|
||||
return min
|
||||
}
|
||||
|
||||
return T(p.Get()%uint64(max-min)) + min
|
||||
}
|
||||
|
||||
func (p PRNG[T]) Get() uint64 {
|
||||
return p.cha8Rand.Uint64()
|
||||
}
|
||||
|
||||
func (p PRNG[T]) ReadSize(size int) []byte {
|
||||
// TODO: use a memory pool to allocate
|
||||
buf := make([]byte, size)
|
||||
_, _ = p.cha8Rand.Read(buf)
|
||||
return buf
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// TODO: atomic?/ and better way to use this
|
||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||
|
||||
// TODO
|
||||
var WaitResponse = struct {
|
||||
Channel chan struct{}
|
||||
ShouldWait *abool.AtomicBool
|
||||
}{
|
||||
make(chan struct{}, 1),
|
||||
abool.New(),
|
||||
}
|
||||
|
||||
type SpecialHandshakeHandler struct {
|
||||
SpecialJunk TagJunkPacketGenerators
|
||||
|
||||
IsSet bool
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
||||
return handler.SpecialJunk.Validate()
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||
if !handler.SpecialJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return handler.SpecialJunk.GeneratePackets()
|
||||
}
|
|
@ -1,229 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v2 "math/rand/v2"
|
||||
// "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Generator interface {
|
||||
Generate() []byte
|
||||
Size() int
|
||||
}
|
||||
|
||||
type newGenerator func(string) (Generator, error)
|
||||
|
||||
type BytesGenerator struct {
|
||||
value []byte
|
||||
size int
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Generate() []byte {
|
||||
return bg.value
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Size() int {
|
||||
return bg.size
|
||||
}
|
||||
|
||||
func newBytesGenerator(param string) (Generator, error) {
|
||||
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
|
||||
if !hasPrefix {
|
||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||
}
|
||||
|
||||
hex, err := hexToBytes(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hexToBytes: %w", err)
|
||||
}
|
||||
|
||||
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
||||
}
|
||||
|
||||
func hexToBytes(hexStr string) ([]byte, error) {
|
||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||
|
||||
// Ensure even length (pad with leading zero if needed)
|
||||
if len(hexStr)%2 != 0 {
|
||||
hexStr = "0" + hexStr
|
||||
}
|
||||
|
||||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type randomGeneratorBase struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func newRandomGeneratorBase(param string) (*randomGeneratorBase, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse int: %w", err)
|
||||
}
|
||||
|
||||
if size > 1000 {
|
||||
return nil, fmt.Errorf("size must be less than 1000")
|
||||
}
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crand read: %w", err)
|
||||
}
|
||||
|
||||
return &randomGeneratorBase{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
type RandomBytesGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomBytesGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random bytes generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomBytesGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomBytesGenerator) Generate() []byte {
|
||||
return rpg.generate()
|
||||
}
|
||||
|
||||
const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
type RandomASCIIGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomASCIIGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random ascii generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomASCIIGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomASCIIGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = alphanumericChars[b%byte(len(alphanumericChars))]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type RandomDigitGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomDigitGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random digit generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomDigitGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomDigitGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = '0' + (b % 10) // Convert to digit character
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newTimestampGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &TimestampGenerator{}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
// TODO: better way to handle counter tag
|
||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newPacketCounterGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &PacketCounterGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitResponseGenerator struct {
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Generate() []byte {
|
||||
WaitResponse.ShouldWait.Set()
|
||||
<-WaitResponse.Channel
|
||||
WaitResponse.ShouldWait.UnSet()
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &WaitResponseGenerator{}, nil
|
||||
}
|
|
@ -1,321 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "wrong start",
|
||||
args: args{
|
||||
param: "123456",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with X",
|
||||
args: args{
|
||||
param: "0X12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with x",
|
||||
args: args{
|
||||
param: "0x12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "valid hex",
|
||||
args: args{
|
||||
param: "0xf6ab3267fa",
|
||||
},
|
||||
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
{
|
||||
name: "valid hex with odd length",
|
||||
args: args{
|
||||
param: "0xfab3267fa",
|
||||
},
|
||||
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
gotValues := got.Generate()
|
||||
require.Equal(t, tt.want, gotValues)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomBytesGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomASCIIGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomASCIIGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomDigitGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomDigitGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketCounterGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid empty param",
|
||||
param: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid non-empty param",
|
||||
param: "anything",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen, err := newPacketCounterGenerator(tc.param)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8, gen.Size())
|
||||
|
||||
// Reset counter to known value for test
|
||||
initialCount := uint64(42)
|
||||
PacketCounter.Store(initialCount)
|
||||
|
||||
output := gen.Generate()
|
||||
require.Equal(t, 8, len(output))
|
||||
|
||||
// Verify counter value in output
|
||||
counterValue := binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount, counterValue)
|
||||
|
||||
// Increment counter and verify change
|
||||
PacketCounter.Add(1)
|
||||
output = gen.Generate()
|
||||
counterValue = binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount+1, counterValue)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TagJunkPacketGenerator struct {
|
||||
name string
|
||||
tagValue string
|
||||
|
||||
packetSize int
|
||||
generators []Generator
|
||||
}
|
||||
|
||||
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{
|
||||
name: name,
|
||||
tagValue: tagValue,
|
||||
generators: make([]Generator, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||
tg.generators = append(tg.generators, generator)
|
||||
tg.packetSize += generator.Size()
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
|
||||
packet := make([]byte, 0, tg.packetSize)
|
||||
for _, generator := range tg.generators {
|
||||
packet = append(packet, generator.Generate()...)
|
||||
}
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) Name() string {
|
||||
return tg.name
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
||||
if len(tg.name) != 2 {
|
||||
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(tg.name[1:2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("name 2 char should be an int %w", err)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||
return IpcFields{
|
||||
Key: tg.name,
|
||||
Value: tg.tagValue,
|
||||
}
|
||||
}
|
|
@ -1,210 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTagJunkGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
genName string
|
||||
size int
|
||||
expected TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "Create new generator with empty name",
|
||||
genName: "",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with valid name",
|
||||
genName: "T1",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T1",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with non-zero size",
|
||||
genName: "T2",
|
||||
size: 5,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 5),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
|
||||
require.Equal(t, tc.expected.name, result.name)
|
||||
require.Equal(t, tc.expected.packetSize, result.packetSize)
|
||||
require.Equal(t, cap(result.generators), len(tc.expected.generators))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState TagJunkPacketGenerator
|
||||
mockSize int
|
||||
expectedLength int
|
||||
expectedSize int
|
||||
}{
|
||||
{
|
||||
name: "Append to empty generator",
|
||||
initialState: newTagJunkPacketGenerator("T1", "", 0),
|
||||
mockSize: 5,
|
||||
expectedLength: 1,
|
||||
expectedSize: 5,
|
||||
},
|
||||
{
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 10,
|
||||
generators: make([]Generator, 2),
|
||||
},
|
||||
mockSize: 7,
|
||||
expectedLength: 3, // 2 existing + 1 new
|
||||
expectedSize: 17, // 10 + 7
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.initialState
|
||||
mockGen := internal.NewMockGenerator(tc.mockSize)
|
||||
|
||||
tg.append(mockGen)
|
||||
|
||||
require.Equal(t, tc.expectedLength, len(tg.generators))
|
||||
require.Equal(t, tc.expectedSize, tg.packetSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorGenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock generators for testing
|
||||
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
|
||||
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupGenerator func() TagJunkPacketGenerator
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "Generate with empty generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
return newTagJunkPacketGenerator("T1", "", 0)
|
||||
},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Generate with single generator",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T2", "", 0)
|
||||
tg.append(mockGen1)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02},
|
||||
},
|
||||
{
|
||||
name: "Generate with multiple generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T3", "", 0)
|
||||
tg.append(mockGen1)
|
||||
tg.append(mockGen2)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.setupGenerator()
|
||||
result := tg.generatePacket()
|
||||
|
||||
require.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorNameIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
generatorName string
|
||||
expectedIndex int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid name with digit",
|
||||
generatorName: "T5",
|
||||
expectedIndex: 5,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too short",
|
||||
generatorName: "T",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too long",
|
||||
generatorName: "T55",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - non-digit second character",
|
||||
generatorName: "TX",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := TagJunkPacketGenerator{name: tc.generatorName}
|
||||
index, err := tg.nameIndex()
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectedIndex, index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package awg
|
||||
|
||||
import "fmt"
|
||||
|
||||
type TagJunkPacketGenerators struct {
|
||||
tagGenerators []TagJunkPacketGenerator
|
||||
length int
|
||||
DefaultJunkCount int // Jc
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||
generator TagJunkPacketGenerator,
|
||||
) {
|
||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||
generators.length++
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) IsDefined() bool {
|
||||
return len(generators.tagGenerators) > 0
|
||||
}
|
||||
|
||||
// validate that packets were defined consecutively
|
||||
func (generators *TagJunkPacketGenerators) Validate() error {
|
||||
seen := make([]bool, len(generators.tagGenerators))
|
||||
for _, generator := range generators.tagGenerators {
|
||||
index, err := generator.nameIndex()
|
||||
if index > len(generators.tagGenerators) {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("name index: %w", err)
|
||||
} else {
|
||||
seen[index-1] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, found := range seen {
|
||||
if !found {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
||||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range generators.tagGenerators {
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
PacketCounter.Inc()
|
||||
}
|
||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||
for _, generator := range tg.tagGenerators {
|
||||
rv = append(rv, generator.IpcGetFields())
|
||||
}
|
||||
|
||||
return rv
|
||||
}
|
|
@ -1,149 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generator TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "append single generator",
|
||||
generator: newTagJunkPacketGenerator("t1", "", 10),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
|
||||
// Initial length should be 0
|
||||
require.Equal(t, 0, generators.length)
|
||||
require.Empty(t, generators.tagGenerators)
|
||||
|
||||
// After append, length should be 1 and generator should be added
|
||||
generators.AppendGenerator(tt.generator)
|
||||
require.Equal(t, 1, generators.length)
|
||||
require.Len(t, generators.tagGenerators, 1)
|
||||
require.Equal(t, tt.generator, generators.tagGenerators[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generators []TagJunkPacketGenerator
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "bad start",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "non-consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t2", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
newTagJunkPacketGenerator("t5", "", 10),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nameIndex error",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("error", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "name must be 2 character long",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
for _, gen := range tt.generators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
err := generators.Validate()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errMsg)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
||||
mockByte1 := []byte{0x01, 0x02}
|
||||
mockByte2 := []byte{0x03, 0x04, 0x05}
|
||||
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
||||
mockGen2 := internal.NewMockByteGenerator(mockByte2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupGenerator func() []TagJunkPacketGenerator
|
||||
expected [][]byte
|
||||
}{
|
||||
{
|
||||
name: "generate with no default junk",
|
||||
setupGenerator: func() []TagJunkPacketGenerator {
|
||||
tg1 := newTagJunkPacketGenerator("t1", "", 0)
|
||||
tg1.append(mockGen1)
|
||||
tg1.append(mockGen2)
|
||||
tg2 := newTagJunkPacketGenerator("t2", "", 0)
|
||||
tg2.append(mockGen2)
|
||||
tg2.append(mockGen1)
|
||||
|
||||
return []TagJunkPacketGenerator{tg1, tg2}
|
||||
},
|
||||
expected: [][]byte{
|
||||
append(mockByte1, mockByte2...),
|
||||
append(mockByte2, mockByte1...),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
tagGenerators := tt.setupGenerator()
|
||||
for _, gen := range tagGenerators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
result := generators.GeneratePackets()
|
||||
require.Equal(t, result, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IpcFields struct{ Key, Value string }
|
||||
|
||||
type EnumTag string
|
||||
|
||||
const (
|
||||
BytesEnumTag EnumTag = "b"
|
||||
CounterEnumTag EnumTag = "c"
|
||||
TimestampEnumTag EnumTag = "t"
|
||||
RandomBytesEnumTag EnumTag = "r"
|
||||
RandomASCIIEnumTag EnumTag = "rc"
|
||||
RandomDigitEnumTag EnumTag = "rd"
|
||||
)
|
||||
|
||||
var generatorCreator = map[EnumTag]newGenerator{
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomBytesGenerator,
|
||||
RandomASCIIEnumTag: newRandomASCIIGenerator,
|
||||
RandomDigitEnumTag: newRandomDigitGenerator,
|
||||
}
|
||||
|
||||
// helper map to determine enumTags are unique
|
||||
var uniqueTags = map[EnumTag]bool{
|
||||
CounterEnumTag: false,
|
||||
TimestampEnumTag: false,
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name EnumTag
|
||||
Param string
|
||||
}
|
||||
|
||||
func parseTag(input string) (Tag, error) {
|
||||
// Regular expression to match <tagname optional_param>
|
||||
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
|
||||
|
||||
match := re.FindStringSubmatch(input)
|
||||
tag := Tag{
|
||||
Name: EnumTag(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
func ParseTagJunkGenerator(name, input string) (TagJunkPacketGenerator, error) {
|
||||
inputSlice := strings.Split(input, "<")
|
||||
if len(inputSlice) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
|
||||
}
|
||||
|
||||
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
|
||||
maps.Copy(uniqueTagCheck, uniqueTags)
|
||||
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"empty tag in input: %s",
|
||||
inputSlice,
|
||||
)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
||||
tag, _ := parseTag(inputParam)
|
||||
creator, ok := generatorCreator[tag.Name]
|
||||
if !ok {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
|
||||
}
|
||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||
if present {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"tag %s needs to be unique",
|
||||
tag.Name,
|
||||
)
|
||||
}
|
||||
uniqueTagCheck[tag.Name] = true
|
||||
}
|
||||
generator, err := creator(tag.Param)
|
||||
if err != nil {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
|
||||
}
|
||||
|
||||
// TODO: handle counter tag
|
||||
// if tag.Name == CounterEnumTag {
|
||||
// packetCounter, ok := generator.(*PacketCounterGenerator)
|
||||
// if !ok {
|
||||
// log.Fatalf("packet counter generator expected, got %T", generator)
|
||||
// }
|
||||
// PacketCounter = packetCounter.counter
|
||||
// }
|
||||
|
||||
rv.append(generator)
|
||||
}
|
||||
|
||||
return rv, nil
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
input string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid name",
|
||||
args: args{name: "apple", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{name: "i1", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra >",
|
||||
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra <",
|
||||
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "empty <>",
|
||||
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "invalid tag",
|
||||
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
|
||||
wantErr: fmt.Errorf("invalid tag"),
|
||||
},
|
||||
{
|
||||
name: "counter uniqueness violation",
|
||||
args: args{name: "i1", input: "<c><c>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "timestamp uniqueness violation",
|
||||
args: args{name: "i1", input: "<t><t>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseTagJunkGenerator(tt.args.name, tt.args.input)
|
||||
|
||||
// TODO: ErrorAs doesn't work as you think
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -8,7 +8,7 @@ package device
|
|||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type DummyDatagram struct {
|
||||
|
@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in6
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
copy(buf, datagram.msg)
|
||||
copy(buff, datagram.msg)
|
||||
return len(datagram.msg), datagram.endpoint, nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in4
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
copy(buf, datagram.msg)
|
||||
copy(buff, datagram.msg)
|
||||
return len(datagram.msg), datagram.endpoint, nil
|
||||
}
|
||||
|
||||
|
@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
|
||||
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -19,13 +19,13 @@ import (
|
|||
// call wg.Done to remove the initial reference.
|
||||
// When the refcount hits 0, the queue's channel is closed.
|
||||
type outboundQueue struct {
|
||||
c chan *QueueOutboundElementsContainer
|
||||
c chan *QueueOutboundElement
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newOutboundQueue() *outboundQueue {
|
||||
q := &outboundQueue{
|
||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
||||
}
|
||||
q.wg.Add(1)
|
||||
go func() {
|
||||
|
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
|
|||
|
||||
// A inboundQueue is similar to an outboundQueue; see those docs.
|
||||
type inboundQueue struct {
|
||||
c chan *QueueInboundElementsContainer
|
||||
c chan *QueueInboundElement
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newInboundQueue() *inboundQueue {
|
||||
q := &inboundQueue{
|
||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||
}
|
||||
q.wg.Add(1)
|
||||
go func() {
|
||||
|
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
|
|||
}
|
||||
|
||||
type autodrainingInboundQueue struct {
|
||||
c chan *QueueInboundElementsContainer
|
||||
c chan *QueueInboundElement
|
||||
}
|
||||
|
||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
|
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
|
|||
// some other means, such as sending a sentinel nil values.
|
||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||
q := &autodrainingInboundQueue{
|
||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||
return q
|
||||
|
@ -90,13 +90,10 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
|||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elemsContainer := <-q.c:
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -104,7 +101,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
|||
}
|
||||
|
||||
type autodrainingOutboundQueue struct {
|
||||
c chan *QueueOutboundElementsContainer
|
||||
c chan *QueueOutboundElement
|
||||
}
|
||||
|
||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
|
@ -114,7 +111,7 @@ type autodrainingOutboundQueue struct {
|
|||
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||
q := &autodrainingOutboundQueue{
|
||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||
return q
|
||||
|
@ -123,13 +120,10 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
|||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elemsContainer := <-q.c:
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -118,7 +118,6 @@ func (st *CookieChecker) CreateReply(
|
|||
msg []byte,
|
||||
recv uint32,
|
||||
src []byte,
|
||||
msgType uint32,
|
||||
) (*MessageCookieReply, error) {
|
||||
st.RLock()
|
||||
|
||||
|
@ -154,7 +153,7 @@ func (st *CookieChecker) CreateReply(
|
|||
smac1 := smac2 - blake2s.Size128
|
||||
|
||||
reply := new(MessageCookieReply)
|
||||
reply.Type = msgType
|
||||
reply.Type = MessageCookieReplyType
|
||||
reply.Receiver = recv
|
||||
|
||||
_, err := rand.Read(reply.Nonce[:])
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
|
|||
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
|
||||
}
|
||||
generator.AddMacs(msg)
|
||||
reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
|
||||
reply, err := checker.CreateReply(msg, 1377, src)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create cookie reply:", err)
|
||||
}
|
||||
|
|
476
device/device.go
476
device/device.go
|
@ -1,62 +1,22 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type Version uint8
|
||||
|
||||
const (
|
||||
VersionDefault Version = iota
|
||||
VersionAwg
|
||||
VersionAwgSpecialHandshake
|
||||
)
|
||||
|
||||
// TODO:
|
||||
type AtomicVersion struct {
|
||||
value atomic.Uint32
|
||||
}
|
||||
|
||||
func NewAtomicVersion(v Version) *AtomicVersion {
|
||||
av := &AtomicVersion{}
|
||||
av.Store(v)
|
||||
return av
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Load() Version {
|
||||
return Version(av.value.Load())
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Store(v Version) {
|
||||
av.value.Store(uint32(v))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
|
||||
return av.value.CompareAndSwap(uint32(old), uint32(new))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Swap(new Version) Version {
|
||||
return Version(av.value.Swap(uint32(new)))
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
state struct {
|
||||
// state holds the device's state. It is accessed atomically.
|
||||
|
@ -70,7 +30,7 @@ type Device struct {
|
|||
// will become the actual state; Up can fail.
|
||||
// The device can also change state multiple times between time of check and time of use.
|
||||
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
||||
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
||||
state uint32 // actually a deviceState, but typed uint32 for convenience
|
||||
// stopping blocks until all inputs to Device have been closed.
|
||||
stopping sync.WaitGroup
|
||||
// mu protects state changes.
|
||||
|
@ -98,8 +58,9 @@ type Device struct {
|
|||
keyMap map[NoisePublicKey]*Peer
|
||||
}
|
||||
|
||||
// Keep this 8-byte aligned
|
||||
rate struct {
|
||||
underLoadUntil atomic.Int64
|
||||
underLoadUntil int64
|
||||
limiter ratelimiter.Ratelimiter
|
||||
}
|
||||
|
||||
|
@ -108,8 +69,6 @@ type Device struct {
|
|||
cookieChecker CookieChecker
|
||||
|
||||
pool struct {
|
||||
inboundElementsContainer *WaitPool
|
||||
outboundElementsContainer *WaitPool
|
||||
messageBuffers *WaitPool
|
||||
inboundElements *WaitPool
|
||||
outboundElements *WaitPool
|
||||
|
@ -123,15 +82,12 @@ type Device struct {
|
|||
|
||||
tun struct {
|
||||
device tun.Device
|
||||
mtu atomic.Int32
|
||||
mtu int32
|
||||
}
|
||||
|
||||
ipcMutex sync.RWMutex
|
||||
closed chan struct{}
|
||||
log *Logger
|
||||
|
||||
version Version
|
||||
awg awg.Protocol
|
||||
}
|
||||
|
||||
// deviceState represents the state of a Device.
|
||||
|
@ -141,6 +97,7 @@ type Device struct {
|
|||
// down -----+
|
||||
// ↑↓ ↓
|
||||
// up -> closed
|
||||
//
|
||||
type deviceState uint32
|
||||
|
||||
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
||||
|
@ -153,7 +110,7 @@ const (
|
|||
// deviceState returns device.state.state as a deviceState
|
||||
// See those docs for how to interpret this value.
|
||||
func (device *Device) deviceState() deviceState {
|
||||
return deviceState(device.state.state.Load())
|
||||
return deviceState(atomic.LoadUint32(&device.state.state))
|
||||
}
|
||||
|
||||
// isClosed reports whether the device is closed (or is closing).
|
||||
|
@ -192,21 +149,20 @@ func (device *Device) changeState(want deviceState) (err error) {
|
|||
case old:
|
||||
return nil
|
||||
case deviceStateUp:
|
||||
device.state.state.Store(uint32(deviceStateUp))
|
||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
|
||||
err = device.upLocked()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
fallthrough // up failed; bring the device all the way back down
|
||||
case deviceStateDown:
|
||||
device.state.state.Store(uint32(deviceStateDown))
|
||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
|
||||
errDown := device.downLocked()
|
||||
if err == nil {
|
||||
err = errDown
|
||||
}
|
||||
}
|
||||
device.log.Verbosef(
|
||||
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
||||
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -226,7 +182,7 @@ func (device *Device) upLocked() error {
|
|||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Start()
|
||||
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
}
|
||||
|
@ -263,11 +219,11 @@ func (device *Device) IsUnderLoad() bool {
|
|||
now := time.Now()
|
||||
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
||||
if underLoad {
|
||||
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
||||
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
|
||||
return true
|
||||
}
|
||||
// check if recently under load
|
||||
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
||||
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
|
||||
}
|
||||
|
||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||
|
@ -311,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||
for _, peer := range device.peers.keyMap {
|
||||
handshake := &peer.handshake
|
||||
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||
expiredPeers = append(expiredPeers, peer)
|
||||
}
|
||||
|
||||
|
@ -327,7 +283,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||
|
||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
device := new(Device)
|
||||
device.state.state.Store(uint32(deviceStateDown))
|
||||
device.state.state = uint32(deviceStateDown)
|
||||
device.closed = make(chan struct{})
|
||||
device.log = logger
|
||||
device.net.bind = bind
|
||||
|
@ -337,11 +293,10 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
|||
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
device.tun.mtu.Store(int32(mtu))
|
||||
device.tun.mtu = int32(mtu)
|
||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||
device.rate.limiter.Init()
|
||||
device.indexTable.Init()
|
||||
|
||||
device.PopulatePools()
|
||||
|
||||
// create queues
|
||||
|
@ -369,19 +324,6 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
|||
return device
|
||||
}
|
||||
|
||||
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||
// the bind batch size and the tun batch size. The batch size reported by device
|
||||
// is the size used to construct memory pools, and is the allowed batch size for
|
||||
// the lifetime of the device.
|
||||
func (device *Device) BatchSize() int {
|
||||
size := device.net.bind.BatchSize()
|
||||
dSize := device.tun.device.BatchSize()
|
||||
if size < dSize {
|
||||
size = dSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||
device.peers.RLock()
|
||||
defer device.peers.RUnlock()
|
||||
|
@ -414,12 +356,10 @@ func (device *Device) RemoveAllPeers() {
|
|||
func (device *Device) Close() {
|
||||
device.state.Lock()
|
||||
defer device.state.Unlock()
|
||||
device.ipcMutex.Lock()
|
||||
defer device.ipcMutex.Unlock()
|
||||
if device.isClosed() {
|
||||
return
|
||||
}
|
||||
device.state.state.Store(uint32(deviceStateClosed))
|
||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
|
||||
device.log.Verbosef("Device closing")
|
||||
|
||||
device.tun.device.Close()
|
||||
|
@ -439,8 +379,6 @@ func (device *Device) Close() {
|
|||
|
||||
device.rate.limiter.Close()
|
||||
|
||||
device.resetProtocol()
|
||||
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
}
|
||||
|
@ -507,7 +445,11 @@ func (device *Device) BindSetMark(mark uint32) error {
|
|||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
|
@ -532,13 +474,11 @@ func (device *Device) BindUpdate() error {
|
|||
var err error
|
||||
var recvFns []conn.ReceiveFunc
|
||||
netc := &device.net
|
||||
|
||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||
if err != nil {
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||
if err != nil {
|
||||
netc.bind.Close()
|
||||
|
@ -557,7 +497,11 @@ func (device *Device) BindUpdate() error {
|
|||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
|
@ -565,9 +509,8 @@ func (device *Device) BindUpdate() error {
|
|||
device.net.stopping.Add(len(recvFns))
|
||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||
batchSize := netc.bind.BatchSize()
|
||||
for _, fn := range recvFns {
|
||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||
go device.RoutineReceiveIncoming(fn)
|
||||
}
|
||||
|
||||
device.log.Verbosef("UDP bind has been updated")
|
||||
|
@ -580,358 +523,3 @@ func (device *Device) BindClose() error {
|
|||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) isAWG() bool {
|
||||
return device.version >= VersionAwg
|
||||
}
|
||||
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||
if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
isAwgOn := false
|
||||
device.awg.Mux.Lock()
|
||||
if tempAwg.Cfg.JunkPacketCount < 0 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketCount should be non negative",
|
||||
),
|
||||
)
|
||||
}
|
||||
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
|
||||
if tempAwg.Cfg.JunkPacketCount != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
|
||||
if tempAwg.Cfg.JunkPacketMinSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
if device.awg.Cfg.JunkPacketCount > 0 &&
|
||||
tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
|
||||
|
||||
tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
|
||||
device.awg.Cfg.JunkPacketMinSize = 0
|
||||
device.awg.Cfg.JunkPacketMaxSize = 1
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||
tempAwg.Cfg.JunkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
))
|
||||
} else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d",
|
||||
tempAwg.Cfg.JunkPacketMaxSize,
|
||||
tempAwg.Cfg.JunkPacketMinSize,
|
||||
))
|
||||
} else {
|
||||
device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
magicHeaders := make([]awg.MagicHeader, 4)
|
||||
|
||||
if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"magic headers should have 4 values; got: %d",
|
||||
len(tempAwg.Cfg.MagicHeaders.Values),
|
||||
)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
|
||||
isAwgOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
|
||||
|
||||
MessageInitiationType = magicHeaders[0].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
|
||||
MessageResponseType = magicHeaders[1].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
|
||||
MessageCookieReplyType = magicHeaders[2].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
|
||||
MessageTransportType = magicHeaders[3].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
|
||||
}
|
||||
|
||||
var err error
|
||||
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
|
||||
if err != nil {
|
||||
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
|
||||
}
|
||||
|
||||
isSameHeaderMap := map[uint32]struct{}{
|
||||
MessageInitiationType: {},
|
||||
MessageResponseType: {},
|
||||
MessageCookieReplyType: {},
|
||||
MessageTransportType: {},
|
||||
}
|
||||
|
||||
// size will be different if same values
|
||||
if len(isSameHeaderMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
|
||||
|
||||
if newInitSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.InitHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
|
||||
|
||||
if newResponseSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.ResponseHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
|
||||
|
||||
if newCookieSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.CookieReplyHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
|
||||
|
||||
if newTransportSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.TransportHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
isSameSizeMap := map[int]struct{}{
|
||||
newInitSize: {},
|
||||
newResponseSize: {},
|
||||
newCookieSize: {},
|
||||
newTransportSize: {},
|
||||
}
|
||||
|
||||
if len(isSameSizeMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
newCookieSize,
|
||||
newTransportSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
|
||||
MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
|
||||
MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
|
||||
MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
|
||||
}
|
||||
|
||||
packetSizeToMsgType = map[int]uint32{
|
||||
newInitSize: MessageInitiationType,
|
||||
newResponseSize: MessageResponseType,
|
||||
newCookieSize: MessageCookieReplyType,
|
||||
newTransportSize: MessageTransportType,
|
||||
}
|
||||
}
|
||||
|
||||
device.awg.IsOn.SetTo(isAwgOn)
|
||||
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
|
||||
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||
} else {
|
||||
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
|
||||
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
|
||||
device.version = VersionAwgSpecialHandshake
|
||||
}
|
||||
} else {
|
||||
device.version = VersionAwg
|
||||
}
|
||||
|
||||
device.awg.Mux.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
|
||||
// TODO:
|
||||
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||
// awg.WaitResponse.Channel <- struct{}{}
|
||||
// }
|
||||
|
||||
expectedMsgType, isKnownSize := packetSizeToMsgType[size]
|
||||
if !isKnownSize {
|
||||
msgType, err := device.handleTransport(size, packet, buffer)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("handle transport: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
junkSize := msgTypeToJunkSize[expectedMsgType]
|
||||
|
||||
// transport size can align with other header types;
|
||||
// making sure we have the right actualMsgType
|
||||
actualMsgType, err := device.getMsgType(packet, junkSize)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get msg type: %w", err)
|
||||
}
|
||||
|
||||
if actualMsgType == expectedMsgType {
|
||||
*packet = (*packet)[junkSize:]
|
||||
return actualMsgType, nil
|
||||
}
|
||||
|
||||
device.log.Verbosef("awg: transport packet lined up with another msg type")
|
||||
|
||||
msgType, err := device.handleTransport(size, packet, buffer)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("handle transport: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
|
||||
msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
|
||||
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get magic header min: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
|
||||
junkSize := device.awg.Cfg.TransportHeaderJunkSize
|
||||
|
||||
msgType, err := device.getMsgType(packet, junkSize)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get msg type: %w", err)
|
||||
}
|
||||
|
||||
if msgType != MessageTransportType {
|
||||
// probably a junk packet
|
||||
return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
|
||||
}
|
||||
|
||||
if junkSize > 0 {
|
||||
// remove junk from buffer by shifting the packet
|
||||
// this buffer is also used for decryption, so it needs to be corrected
|
||||
copy((*buffer)[:size], (*packet)[junkSize:])
|
||||
size -= junkSize
|
||||
// need to reinitialize packet as well
|
||||
(*packet) = (*packet)[:size]
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
|
|
@ -1,32 +1,27 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
||||
|
@ -53,7 +48,7 @@ func uapiCfg(cfg ...string) string {
|
|||
|
||||
// genConfigs generates a pair of configs that connect to each other.
|
||||
// The configs use distinct, probably-usable ports.
|
||||
func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
||||
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||
var key1, key2 NoisePrivateKey
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
|
@ -65,8 +60,7 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
args0 := append([]string(nil), cfg...)
|
||||
args0 = append(args0, []string{
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -74,16 +68,12 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.2/32",
|
||||
}...)
|
||||
cfgs[0] = uapiCfg(args0...)
|
||||
|
||||
)
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
|
||||
args1 := append([]string(nil), cfg...)
|
||||
args1 = append(args1, []string{
|
||||
cfgs[1] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -91,9 +81,7 @@ func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.1/32",
|
||||
}...)
|
||||
|
||||
cfgs[1] = uapiCfg(args1...)
|
||||
)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
|
@ -125,21 +113,16 @@ func (d SendDirection) String() string {
|
|||
return "pong"
|
||||
}
|
||||
|
||||
func (pair *testPair) Send(
|
||||
tb testing.TB,
|
||||
ping SendDirection,
|
||||
done chan struct{},
|
||||
) {
|
||||
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
|
||||
tb.Helper()
|
||||
p0, p1 := pair[0], pair[1]
|
||||
if !ping {
|
||||
// pong is the new ping
|
||||
p0, p1 = p1, p0
|
||||
}
|
||||
|
||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||
p1.tun.Outbound <- msg
|
||||
timer := time.NewTimer(6 * time.Second)
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
defer timer.Stop()
|
||||
var err error
|
||||
select {
|
||||
|
@ -164,14 +147,8 @@ func (pair *testPair) Send(
|
|||
}
|
||||
|
||||
// genTestPair creates a testPair.
|
||||
func genTestPair(
|
||||
tb testing.TB,
|
||||
realSocket bool,
|
||||
extraCfg ...string,
|
||||
) (pair testPair) {
|
||||
var cfg, endpointCfg [2]string
|
||||
cfg, endpointCfg = genConfigs(tb, extraCfg...)
|
||||
|
||||
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||
cfg, endpointCfg := genConfigs(tb)
|
||||
var binds [2]conn.Bind
|
||||
if realSocket {
|
||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||
|
@ -224,74 +201,6 @@ func TestTwoDevicePing(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
||||
func TestAWGDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
|
||||
pair := genTestPair(t, true,
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
"s1", "15",
|
||||
"s2", "18",
|
||||
"s3", "20",
|
||||
"s4", "25",
|
||||
"h1", "123456-123500",
|
||||
"h2", "67543-67550",
|
||||
"h3", "123123-123200",
|
||||
"h4", "32345-32350",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
pair.Send(t, Pong, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// Needs to be stopped with Ctrl-C
|
||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||
t.Skip("This test is intended to be run manually, not as part of the test suite.")
|
||||
|
||||
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
isRunning := atomic.NewBool(true)
|
||||
go func() {
|
||||
<-signalContext.Done()
|
||||
fmt.Println("Waiting to finish")
|
||||
isRunning.Store(false)
|
||||
}()
|
||||
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10>",
|
||||
"i2", "<b 0xf6ab3267fa><c><b 0xf6ab><t><rc 10>",
|
||||
"i3", "<b 0xf6ab3267fa><c><b 0xf6ab><t><rd 10>",
|
||||
"i4", "<b 0xf6ab3267fa><r 100>",
|
||||
// "jc", "1",
|
||||
// "jmin", "500",
|
||||
// "jmax", "1000",
|
||||
// "s1", "30",
|
||||
// "s2", "40",
|
||||
// "h1", "123456",
|
||||
// "h2", "67543",
|
||||
// "h4", "32345",
|
||||
// "h3", "123123",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Ping, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Pong, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpDown(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
const itrials = 50
|
||||
|
@ -398,17 +307,6 @@ func TestConcurrencySafety(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
||||
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
||||
const iters = 10
|
||||
for i := 0; i < iters; i++ {
|
||||
for _, peer := range pair {
|
||||
peer.dev.BindUpdate()
|
||||
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
close(done)
|
||||
}
|
||||
|
||||
|
@ -435,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||
|
||||
// Measure how long it takes to receive b.N packets,
|
||||
// starting when we receive the first packet.
|
||||
var recv atomic.Uint64
|
||||
var recv uint64
|
||||
var elapsed time.Duration
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
@ -444,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||
var start time.Time
|
||||
for {
|
||||
<-pair[0].tun.Inbound
|
||||
new := recv.Add(1)
|
||||
new := atomic.AddUint64(&recv, 1)
|
||||
if new == 1 {
|
||||
start = time.Now()
|
||||
}
|
||||
|
@ -460,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
||||
pingc := pair[1].tun.Outbound
|
||||
var sent uint64
|
||||
for recv.Load() != uint64(b.N) {
|
||||
for atomic.LoadUint64(&recv) != uint64(b.N) {
|
||||
sent++
|
||||
pingc <- ping
|
||||
}
|
||||
|
@ -507,73 +405,3 @@ func goroutineLeakCheck(t *testing.T) {
|
|||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
||||
})
|
||||
}
|
||||
|
||||
type fakeBindSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (b *fakeBindSized) Open(
|
||||
port uint16,
|
||||
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (b *fakeBindSized) Close() error { return nil }
|
||||
|
||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||
|
||||
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
||||
|
||||
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||
|
||||
type fakeTUNDeviceSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||
|
||||
func TestBatchSize(t *testing.T) {
|
||||
d := Device{}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 1, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -10,8 +10,9 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/replay"
|
||||
"golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
|
||||
/* Due to limitations in Go and /x/crypto there is currently
|
||||
|
@ -22,7 +23,7 @@ import (
|
|||
*/
|
||||
|
||||
type Keypair struct {
|
||||
sendNonce atomic.Uint64
|
||||
sendNonce uint64 // accessed atomically
|
||||
send cipher.AEAD
|
||||
receive cipher.AEAD
|
||||
replayFilter replay.Filter
|
||||
|
@ -36,7 +37,15 @@ type Keypairs struct {
|
|||
sync.RWMutex
|
||||
current *Keypair
|
||||
previous *Keypair
|
||||
next atomic.Pointer[Keypair]
|
||||
next *Keypair
|
||||
}
|
||||
|
||||
func (kp *Keypairs) storeNext(next *Keypair) {
|
||||
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) loadNext() *Keypair {
|
||||
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) Current() *Keypair {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
41
device/misc.go
Normal file
41
device/misc.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
/* Atomic Boolean */
|
||||
|
||||
const (
|
||||
AtomicFalse = int32(iota)
|
||||
AtomicTrue
|
||||
)
|
||||
|
||||
type AtomicBool struct {
|
||||
int32
|
||||
}
|
||||
|
||||
func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.int32) == AtomicTrue
|
||||
}
|
||||
|
||||
func (a *AtomicBool) Swap(val bool) bool {
|
||||
flag := AtomicFalse
|
||||
if val {
|
||||
flag = AtomicTrue
|
||||
}
|
||||
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
|
||||
}
|
||||
|
||||
func (a *AtomicBool) Set(val bool) {
|
||||
flag := AtomicFalse
|
||||
if val {
|
||||
flag = AtomicTrue
|
||||
}
|
||||
atomic.StoreInt32(&a.int32, flag)
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
|||
device.net.brokenRoaming = true
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||
peer.endpoint.Unlock()
|
||||
peer.Lock()
|
||||
peer.disableRoaming = peer.endpoint != nil
|
||||
peer.Unlock()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -9,7 +9,6 @@ import (
|
|||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"hash"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
|
@ -95,14 +94,9 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
|||
return
|
||||
}
|
||||
|
||||
var errInvalidPublicKey = errors.New("invalid public key")
|
||||
|
||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
|
||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||
curve25519.ScalarMult(&ss, ask, apk)
|
||||
if isZero(ss[:]) {
|
||||
return ss, errInvalidPublicKey
|
||||
}
|
||||
return ss, nil
|
||||
return ss
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -15,7 +15,7 @@ import (
|
|||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/tai64n"
|
||||
"golang.zx2c4.com/wireguard/tai64n"
|
||||
)
|
||||
|
||||
type handshakeState int
|
||||
|
@ -53,17 +53,10 @@ const (
|
|||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
MessageInitiationType = 1
|
||||
MessageResponseType = 2
|
||||
MessageCookieReplyType = 3
|
||||
MessageTransportType = 4
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -82,11 +75,6 @@ const (
|
|||
MessageTransportOffsetContent = 16
|
||||
)
|
||||
|
||||
var (
|
||||
packetSizeToMsgType map[int]uint32
|
||||
msgTypeToJunkSize map[uint32]int
|
||||
)
|
||||
|
||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||
* by marshalling the messages in little-endian byteorder
|
||||
* we can treat these as a 32-bit unsigned int (for now)
|
||||
|
@ -187,6 +175,8 @@ func init() {
|
|||
}
|
||||
|
||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||
errZeroECDHResult := errors.New("ECDH returned all zeros")
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
||||
|
@ -205,20 +195,8 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
msgType := DefaultMessageInitiationType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageInitiationType)
|
||||
if err != nil {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil, fmt.Errorf("get message type: %w", err)
|
||||
}
|
||||
|
||||
device.awg.Mux.RUnlock()
|
||||
}
|
||||
|
||||
msg := MessageInitiation{
|
||||
Type: msgType,
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
}
|
||||
|
||||
|
@ -226,9 +204,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
||||
// encrypt static key
|
||||
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||
if isZero(ss[:]) {
|
||||
return nil, errZeroECDHResult
|
||||
}
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
KDF2(
|
||||
|
@ -243,7 +221,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
// encrypt timestamp
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
return nil, errInvalidPublicKey
|
||||
return nil, errZeroECDHResult
|
||||
}
|
||||
KDF2(
|
||||
&handshake.chainKey,
|
||||
|
@ -274,13 +252,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
@ -290,10 +264,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||
|
||||
// decrypt static key
|
||||
var err error
|
||||
var peerPK NoisePublicKey
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||
if err != nil {
|
||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||
if isZero(ss[:]) {
|
||||
return nil
|
||||
}
|
||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||
|
@ -307,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
// lookup peer
|
||||
|
||||
peer := device.LookupPeer(peerPK)
|
||||
if peer == nil || !peer.isRunning.Load() {
|
||||
if peer == nil || !peer.isRunning.Get() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -395,19 +370,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
var msg MessageResponse
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
|
||||
if err != nil {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil, fmt.Errorf("get message type: %w", err)
|
||||
}
|
||||
|
||||
device.awg.Mux.RUnlock()
|
||||
} else {
|
||||
msg.Type = DefaultMessageResponseType
|
||||
}
|
||||
|
||||
msg.Type = MessageResponseType
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
|
@ -421,16 +384,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
handshake.mixHash(msg.Ephemeral[:])
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
|
||||
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func() {
|
||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||
handshake.mixKey(ss[:])
|
||||
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||
handshake.mixKey(ss[:])
|
||||
}()
|
||||
|
||||
// add preshared key
|
||||
|
||||
|
@ -447,9 +406,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
|
||||
handshake.mixHash(tau[:])
|
||||
|
||||
func() {
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||
handshake.mixHash(msg.Empty[:])
|
||||
}()
|
||||
|
||||
handshake.state = handshakeResponseCreated
|
||||
|
||||
|
@ -457,13 +418,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
if msg.Type != MessageResponseType {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
@ -498,19 +455,17 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
||||
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
||||
|
||||
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
func() {
|
||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||
mixKey(&chainKey, &chainKey, ss[:])
|
||||
setZero(ss[:])
|
||||
}()
|
||||
|
||||
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
func() {
|
||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||
mixKey(&chainKey, &chainKey, ss[:])
|
||||
setZero(ss[:])
|
||||
}()
|
||||
|
||||
// add preshared key (psk)
|
||||
|
||||
|
@ -528,7 +483,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||
// authenticate transcript
|
||||
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -626,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
defer keypairs.Unlock()
|
||||
|
||||
previous := keypairs.previous
|
||||
next := keypairs.next.Load()
|
||||
next := keypairs.loadNext()
|
||||
current := keypairs.current
|
||||
|
||||
if isInitiator {
|
||||
if next != nil {
|
||||
keypairs.next.Store(nil)
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.previous = next
|
||||
device.DeleteKeypair(current)
|
||||
} else {
|
||||
|
@ -640,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
device.DeleteKeypair(previous)
|
||||
keypairs.current = keypair
|
||||
} else {
|
||||
keypairs.next.Store(keypair)
|
||||
keypairs.storeNext(keypair)
|
||||
device.DeleteKeypair(next)
|
||||
keypairs.previous = nil
|
||||
device.DeleteKeypair(previous)
|
||||
|
@ -652,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||
keypairs := &peer.keypairs
|
||||
|
||||
if keypairs.next.Load() != receivedKeypair {
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
keypairs.Lock()
|
||||
defer keypairs.Unlock()
|
||||
if keypairs.next.Load() != receivedKeypair {
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
old := keypairs.previous
|
||||
keypairs.previous = keypairs.current
|
||||
peer.device.DeleteKeypair(old)
|
||||
keypairs.current = keypairs.next.Load()
|
||||
keypairs.next.Store(nil)
|
||||
keypairs.current = keypairs.loadNext()
|
||||
keypairs.storeNext(nil)
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -10,8 +10,8 @@ import (
|
|||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
func TestCurveWrappers(t *testing.T) {
|
||||
|
@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
|
|||
pk1 := sk1.publicKey()
|
||||
pk2 := sk2.publicKey()
|
||||
|
||||
ss1, err1 := sk1.sharedSecret(pk2)
|
||||
ss2, err2 := sk2.sharedSecret(pk1)
|
||||
ss1 := sk1.sharedSecret(pk2)
|
||||
ss2 := sk2.sharedSecret(pk1)
|
||||
|
||||
if ss1 != ss2 || err1 != nil || err2 != nil {
|
||||
if ss1 != ss2 {
|
||||
t.Fatal("Failed to compute shared secet")
|
||||
}
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||
t.Fatal("failed to derive keypair for peer 2", err)
|
||||
}
|
||||
|
||||
key1 := peer1.keypairs.next.Load()
|
||||
key1 := peer1.keypairs.loadNext()
|
||||
key2 := peer2.keypairs.current
|
||||
|
||||
// encrypting / decryption test
|
||||
|
|
124
device/peer.go
124
device/peer.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -12,36 +12,40 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
isRunning atomic.Bool
|
||||
isRunning AtomicBool
|
||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
||||
keypairs Keypairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
endpoint conn.Endpoint
|
||||
stopping sync.WaitGroup // routines pending stop
|
||||
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||
rxBytes atomic.Uint64 // bytes received from peer
|
||||
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||
|
||||
endpoint struct {
|
||||
sync.Mutex
|
||||
val conn.Endpoint
|
||||
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||
disableRoaming bool
|
||||
// These fields are accessed with atomic operations, which must be
|
||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
||||
// allocated struct will be 64-bit aligned. So we place
|
||||
// atomically-accessed fields up front, so that they can share in
|
||||
// this alignment before smaller fields throw it off.
|
||||
stats struct {
|
||||
txBytes uint64 // bytes send to peer (endpoint)
|
||||
rxBytes uint64 // bytes received from peer
|
||||
lastHandshakeNano int64 // nano seconds since epoch
|
||||
}
|
||||
|
||||
disableRoaming bool
|
||||
|
||||
timers struct {
|
||||
retransmitHandshake *Timer
|
||||
sendKeepalive *Timer
|
||||
newHandshake *Timer
|
||||
zeroKeyMaterial *Timer
|
||||
persistentKeepalive *Timer
|
||||
handshakeAttempts atomic.Uint32
|
||||
needAnotherKeepalive atomic.Bool
|
||||
sentLastMinuteHandshake atomic.Bool
|
||||
handshakeAttempts uint32
|
||||
needAnotherKeepalive AtomicBool
|
||||
sentLastMinuteHandshake AtomicBool
|
||||
}
|
||||
|
||||
state struct {
|
||||
|
@ -49,14 +53,14 @@ type Peer struct {
|
|||
}
|
||||
|
||||
queue struct {
|
||||
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
||||
staged chan *QueueOutboundElement // staged packets before a handshake is available
|
||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||
}
|
||||
|
||||
cookieGenerator CookieGenerator
|
||||
trieEntries list.List
|
||||
persistentKeepaliveInterval atomic.Uint32
|
||||
persistentKeepaliveInterval uint32 // accessed atomically
|
||||
}
|
||||
|
||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
|
@ -78,12 +82,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
|
||||
// create peer
|
||||
peer := new(Peer)
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
|
||||
peer.cookieGenerator.Init(pk)
|
||||
peer.device = device
|
||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
||||
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
|
||||
|
||||
// map public key
|
||||
_, ok := device.peers.keyMap[pk]
|
||||
|
@ -94,16 +100,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
// pre-compute DH
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||
handshake.remoteStatic = pk
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
// reset endpoint
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.val = nil
|
||||
peer.endpoint.disableRoaming = false
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.Unlock()
|
||||
peer.endpoint = nil
|
||||
|
||||
// init timers
|
||||
peer.timersInit()
|
||||
|
@ -114,17 +116,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
|
||||
err := peer.SendBuffers(buffers)
|
||||
if err == nil {
|
||||
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
||||
|
@ -132,25 +124,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
peer.endpoint.Lock()
|
||||
endpoint := peer.endpoint.val
|
||||
if endpoint == nil {
|
||||
peer.endpoint.Unlock()
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
|
||||
if peer.endpoint == nil {
|
||||
return errors.New("no known endpoint for peer")
|
||||
}
|
||||
if peer.endpoint.clearSrcOnTx {
|
||||
endpoint.ClearSrc()
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||
if err == nil {
|
||||
var totalLen uint64
|
||||
for _, b := range buffers {
|
||||
totalLen += uint64(len(b))
|
||||
}
|
||||
peer.txBytes.Add(totalLen)
|
||||
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -191,7 +174,7 @@ func (peer *Peer) Start() {
|
|||
peer.state.Lock()
|
||||
defer peer.state.Unlock()
|
||||
|
||||
if peer.isRunning.Load() {
|
||||
if peer.isRunning.Get() {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -212,14 +195,10 @@ func (peer *Peer) Start() {
|
|||
|
||||
device.flushInboundQueue(peer.queue.inbound)
|
||||
device.flushOutboundQueue(peer.queue.outbound)
|
||||
go peer.RoutineSequentialSender()
|
||||
go peer.RoutineSequentialReceiver()
|
||||
|
||||
// Use the device batch size, not the bind batch size, as the device size is
|
||||
// the size of the batch pools.
|
||||
batchSize := peer.device.BatchSize()
|
||||
go peer.RoutineSequentialSender(batchSize)
|
||||
go peer.RoutineSequentialReceiver(batchSize)
|
||||
|
||||
peer.isRunning.Store(true)
|
||||
peer.isRunning.Set(true)
|
||||
}
|
||||
|
||||
func (peer *Peer) ZeroAndFlushAll() {
|
||||
|
@ -231,10 +210,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||
keypairs.Lock()
|
||||
device.DeleteKeypair(keypairs.previous)
|
||||
device.DeleteKeypair(keypairs.current)
|
||||
device.DeleteKeypair(keypairs.next.Load())
|
||||
device.DeleteKeypair(keypairs.loadNext())
|
||||
keypairs.previous = nil
|
||||
keypairs.current = nil
|
||||
keypairs.next.Store(nil)
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.Unlock()
|
||||
|
||||
// clear handshake state
|
||||
|
@ -259,10 +238,11 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||
keypairs := &peer.keypairs
|
||||
keypairs.Lock()
|
||||
if keypairs.current != nil {
|
||||
keypairs.current.sendNonce.Store(RejectAfterMessages)
|
||||
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
|
||||
}
|
||||
if next := keypairs.next.Load(); next != nil {
|
||||
next.sendNonce.Store(RejectAfterMessages)
|
||||
if keypairs.next != nil {
|
||||
next := keypairs.loadNext()
|
||||
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
|
||||
}
|
||||
keypairs.Unlock()
|
||||
}
|
||||
|
@ -288,20 +268,10 @@ func (peer *Peer) Stop() {
|
|||
}
|
||||
|
||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.disableRoaming {
|
||||
if peer.disableRoaming {
|
||||
return
|
||||
}
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.val = endpoint
|
||||
}
|
||||
|
||||
func (peer *Peer) markEndpointSrcForClearing() {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.val == nil {
|
||||
return
|
||||
}
|
||||
peer.endpoint.clearSrcOnTx = true
|
||||
peer.Lock()
|
||||
peer.endpoint = endpoint
|
||||
peer.Unlock()
|
||||
}
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type WaitPool struct {
|
||||
pool sync.Pool
|
||||
cond sync.Cond
|
||||
lock sync.Mutex
|
||||
count uint32 // Get calls not yet Put back
|
||||
count uint32
|
||||
max uint32
|
||||
}
|
||||
|
||||
|
@ -26,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
|||
func (p *WaitPool) Get() any {
|
||||
if p.max != 0 {
|
||||
p.lock.Lock()
|
||||
for p.count >= p.max {
|
||||
for atomic.LoadUint32(&p.count) >= p.max {
|
||||
p.cond.Wait()
|
||||
}
|
||||
p.count++
|
||||
atomic.AddUint32(&p.count, 1)
|
||||
p.lock.Unlock()
|
||||
}
|
||||
return p.pool.Get()
|
||||
|
@ -40,21 +41,11 @@ func (p *WaitPool) Put(x any) {
|
|||
if p.max == 0 {
|
||||
return
|
||||
}
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
p.count--
|
||||
atomic.AddUint32(&p.count, ^uint32(0))
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
func (device *Device) PopulatePools() {
|
||||
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||
return &QueueInboundElementsContainer{elems: s}
|
||||
})
|
||||
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||
return &QueueOutboundElementsContainer{elems: s}
|
||||
})
|
||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
return new([MaxMessageSize]byte)
|
||||
})
|
||||
|
@ -66,34 +57,6 @@ func (device *Device) PopulatePools() {
|
|||
})
|
||||
}
|
||||
|
||||
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
||||
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
||||
c.Mutex = sync.Mutex{}
|
||||
return c
|
||||
}
|
||||
|
||||
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
||||
for i := range c.elems {
|
||||
c.elems[i] = nil
|
||||
}
|
||||
c.elems = c.elems[:0]
|
||||
device.pool.inboundElementsContainer.Put(c)
|
||||
}
|
||||
|
||||
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
||||
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
||||
c.Mutex = sync.Mutex{}
|
||||
return c
|
||||
}
|
||||
|
||||
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
||||
for i := range c.elems {
|
||||
c.elems[i] = nil
|
||||
}
|
||||
c.elems = c.elems[:0]
|
||||
device.pool.outboundElementsContainer.Put(c)
|
||||
}
|
||||
|
||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -17,33 +17,29 @@ import (
|
|||
func TestWaitPool(t *testing.T) {
|
||||
t.Skip("Currently disabled")
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
startTrials := int32(100000)
|
||||
trials := int32(100000)
|
||||
if raceEnabled {
|
||||
// This test can be very slow with -race.
|
||||
startTrials /= 10
|
||||
trials /= 10
|
||||
}
|
||||
trials.Store(startTrials)
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
t.Skip("Not enough cores")
|
||||
}
|
||||
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||
wg.Add(workers)
|
||||
var max atomic.Uint32
|
||||
max := uint32(0)
|
||||
updateMax := func() {
|
||||
p.lock.Lock()
|
||||
count := p.count
|
||||
p.lock.Unlock()
|
||||
count := atomic.LoadUint32(&p.count)
|
||||
if count > p.max {
|
||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||
}
|
||||
for {
|
||||
old := max.Load()
|
||||
old := atomic.LoadUint32(&max)
|
||||
if count <= old {
|
||||
break
|
||||
}
|
||||
if max.CompareAndSwap(old, count) {
|
||||
if atomic.CompareAndSwapUint32(&max, old, count) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -51,7 +47,7 @@ func TestWaitPool(t *testing.T) {
|
|||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
for atomic.AddInt32(&trials, -1) > 0 {
|
||||
updateMax()
|
||||
x := p.Get()
|
||||
updateMax()
|
||||
|
@ -63,15 +59,14 @@ func TestWaitPool(t *testing.T) {
|
|||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if max.Load() != p.max {
|
||||
if max != p.max {
|
||||
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWaitPool(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
trials := int32(b.N)
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
|
@ -82,55 +77,7 @@ func BenchmarkWaitPool(b *testing.B) {
|
|||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkSyncPool(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
for atomic.AddInt32(&trials, -1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
|
||||
/* Reduce memory consumption for Android */
|
||||
|
||||
const (
|
||||
QueueStagedSize = conn.IdealBatchSize
|
||||
QueueStagedSize = 128
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||
MaxSegmentSize = 2200
|
||||
PreallocatedBuffersPerPool = 4096
|
||||
)
|
||||
|
|
|
@ -2,15 +2,13 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
|
||||
const (
|
||||
QueueStagedSize = conn.IdealBatchSize
|
||||
QueueStagedSize = 128
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -11,12 +11,13 @@ import (
|
|||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type QueueHandshakeElement struct {
|
||||
|
@ -27,6 +28,7 @@ type QueueHandshakeElement struct {
|
|||
}
|
||||
|
||||
type QueueInboundElement struct {
|
||||
sync.Mutex
|
||||
buffer *[MaxMessageSize]byte
|
||||
packet []byte
|
||||
counter uint64
|
||||
|
@ -34,11 +36,6 @@ type QueueInboundElement struct {
|
|||
endpoint conn.Endpoint
|
||||
}
|
||||
|
||||
type QueueInboundElementsContainer struct {
|
||||
sync.Mutex
|
||||
elems []*QueueInboundElement
|
||||
}
|
||||
|
||||
// clearPointers clears elem fields that contain pointers.
|
||||
// This makes the garbage collector's life easier and
|
||||
// avoids accidentally keeping other objects around unnecessarily.
|
||||
|
@ -55,12 +52,12 @@ func (elem *QueueInboundElement) clearPointers() {
|
|||
* NOTE: Not thread safe, but called by sequential receiver!
|
||||
*/
|
||||
func (peer *Peer) keepKeyFreshReceiving() {
|
||||
if peer.timers.sentLastMinuteHandshake.Load() {
|
||||
if peer.timers.sentLastMinuteHandshake.Get() {
|
||||
return
|
||||
}
|
||||
keypair := peer.keypairs.Current()
|
||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||
peer.timers.sentLastMinuteHandshake.Store(true)
|
||||
peer.timers.sentLastMinuteHandshake.Set(true)
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
}
|
||||
|
@ -70,10 +67,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
|||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(
|
||||
maxBatchSize int,
|
||||
recv conn.ReceiveFunc,
|
||||
) {
|
||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
recvName := recv.PrettyName()
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||
|
@ -86,33 +80,20 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
buffer := device.GetMessageBuffer()
|
||||
|
||||
var (
|
||||
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||
bufs = make([][]byte, maxBatchSize)
|
||||
err error
|
||||
sizes = make([]int, maxBatchSize)
|
||||
count int
|
||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||
size int
|
||||
endpoint conn.Endpoint
|
||||
deathSpiral int
|
||||
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||
)
|
||||
|
||||
for i := range bufsArrs {
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := 0; i < maxBatchSize; i++ {
|
||||
if bufsArrs[i] != nil {
|
||||
device.PutMessageBuffer(bufsArrs[i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
count, err = recv(bufs, sizes, endpoints)
|
||||
size, endpoint, err = recv(buffer[:])
|
||||
|
||||
if err != nil {
|
||||
device.PutMessageBuffer(buffer)
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
@ -123,32 +104,23 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
if deathSpiral < 10 {
|
||||
deathSpiral++
|
||||
time.Sleep(time.Second / 3)
|
||||
buffer = device.GetMessageBuffer()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
packet := bufsArrs[i][:size]
|
||||
var msgType uint32
|
||||
if device.isAWG() {
|
||||
msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
|
||||
|
||||
if err != nil {
|
||||
device.log.Verbosef("awg: process packet: %v", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
packet := buffer[:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
var okay bool
|
||||
|
||||
switch msgType {
|
||||
|
||||
|
@ -183,70 +155,50 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = bufsArrs[i]
|
||||
elem.buffer = buffer
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoints[i]
|
||||
elem.endpoint = endpoint
|
||||
elem.counter = 0
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.Lock()
|
||||
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetInboundElementsContainer()
|
||||
elemsForPeer.Lock()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
// add to decryption queues
|
||||
if peer.isRunning.Get() {
|
||||
peer.queue.inbound.c <- elem
|
||||
device.queue.decryption.c <- elem
|
||||
buffer = device.GetMessageBuffer()
|
||||
} else {
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
if len(packet) != MessageInitiationSize {
|
||||
continue
|
||||
}
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
|
||||
case MessageResponseType:
|
||||
if len(packet) != MessageResponseSize {
|
||||
continue
|
||||
}
|
||||
okay = len(packet) == MessageResponseSize
|
||||
|
||||
case MessageCookieReplyType:
|
||||
if len(packet) != MessageCookieReplySize {
|
||||
continue
|
||||
}
|
||||
okay = len(packet) == MessageCookieReplySize
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
continue
|
||||
}
|
||||
|
||||
if okay {
|
||||
select {
|
||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: bufsArrs[i],
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
endpoint: endpoints[i],
|
||||
endpoint: endpoint,
|
||||
}:
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
buffer = device.GetMessageBuffer()
|
||||
default:
|
||||
}
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
device.queue.decryption.c <- elemsContainer
|
||||
} else {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -256,8 +208,7 @@ func (device *Device) RoutineDecryption(id int) {
|
|||
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||
|
||||
for elemsContainer := range device.queue.decryption.c {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
for elem := range device.queue.decryption.c {
|
||||
// split message into fields
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
|
@ -276,8 +227,7 @@ func (device *Device) RoutineDecryption(id int) {
|
|||
if err != nil {
|
||||
elem.packet = nil
|
||||
}
|
||||
}
|
||||
elemsContainer.Unlock()
|
||||
elem.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -292,8 +242,6 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
switch elem.msgType {
|
||||
|
@ -320,15 +268,10 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
// consume reply
|
||||
|
||||
if peer := entry.peer; peer.isRunning.Load() {
|
||||
device.log.Verbosef(
|
||||
"Receiving cookie response from %s",
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
if peer := entry.peer; peer.isRunning.Get() {
|
||||
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
||||
device.log.Verbosef(
|
||||
"Could not decrypt invalid cookie response",
|
||||
)
|
||||
device.log.Verbosef("Could not decrypt invalid cookie response")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -370,7 +313,9 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
switch elem.msgType {
|
||||
case MessageInitiationType:
|
||||
|
||||
// unmarshal
|
||||
|
||||
var msg MessageInitiation
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||
|
@ -379,10 +324,8 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
goto skip
|
||||
}
|
||||
|
||||
// have to reassign msgType for ranged msgType to work
|
||||
msg.Type = elem.msgType
|
||||
|
||||
// consume initiation
|
||||
|
||||
peer := device.ConsumeMessageInitiation(&msg)
|
||||
if peer == nil {
|
||||
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
||||
|
@ -398,7 +341,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
device.log.Verbosef("%v - Received handshake initiation", peer)
|
||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
|
||||
peer.SendHandshakeResponse()
|
||||
|
||||
|
@ -414,9 +357,6 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
goto skip
|
||||
}
|
||||
|
||||
// have to reassign msgType for ranged msgType to work
|
||||
msg.Type = elem.msgType
|
||||
|
||||
// consume response
|
||||
|
||||
peer := device.ConsumeMessageResponse(&msg)
|
||||
|
@ -429,7 +369,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
device.log.Verbosef("%v - Received handshake response", peer)
|
||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
|
||||
// update timers
|
||||
|
||||
|
@ -450,12 +390,11 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.awg.Mux.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||
func (peer *Peer) RoutineSequentialReceiver() {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
||||
|
@ -463,109 +402,89 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
|||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
||||
|
||||
bufs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elemsContainer := range peer.queue.inbound.c {
|
||||
if elemsContainer == nil {
|
||||
for elem := range peer.queue.inbound.c {
|
||||
if elem == nil {
|
||||
return
|
||||
}
|
||||
elemsContainer.Lock()
|
||||
validTailPacket := -1
|
||||
dataPacketReceived := false
|
||||
rxBytesLen := uint64(0)
|
||||
for i, elem := range elemsContainer.elems {
|
||||
var err error
|
||||
elem.Lock()
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
validTailPacket = i
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
dataPacketReceived = true
|
||||
peer.timersDataReceived()
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
case 6:
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef(
|
||||
"Packet with invalid IP version from %v",
|
||||
peer,
|
||||
)
|
||||
continue
|
||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||
goto skip
|
||||
}
|
||||
|
||||
bufs = append(
|
||||
bufs,
|
||||
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
||||
)
|
||||
}
|
||||
|
||||
peer.rxBytes.Add(rxBytesLen)
|
||||
if validTailPacket >= 0 {
|
||||
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
}
|
||||
if dataPacketReceived {
|
||||
peer.timersDataReceived()
|
||||
}
|
||||
if len(bufs) > 0 {
|
||||
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||
device.log.Errorf("Failed to write packet to TUN device: %v", err)
|
||||
}
|
||||
if len(peer.queue.inbound.c) == 0 {
|
||||
err = device.tun.device.Flush()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("Unable to flush packets: %v", err)
|
||||
}
|
||||
}
|
||||
for _, elem := range elemsContainer.elems {
|
||||
skip:
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
bufs = bufs[:0]
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
|
|
364
device/send.go
364
device/send.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -12,10 +12,9 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
@ -46,6 +45,7 @@ import (
|
|||
*/
|
||||
|
||||
type QueueOutboundElement struct {
|
||||
sync.Mutex
|
||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||
packet []byte // slice of "buffer" (always!)
|
||||
nonce uint64 // nonce for encryption
|
||||
|
@ -53,14 +53,10 @@ type QueueOutboundElement struct {
|
|||
peer *Peer // related peer
|
||||
}
|
||||
|
||||
type QueueOutboundElementsContainer struct {
|
||||
sync.Mutex
|
||||
elems []*QueueOutboundElement
|
||||
}
|
||||
|
||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||
elem := device.GetOutboundElement()
|
||||
elem.buffer = device.GetMessageBuffer()
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.nonce = 0
|
||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||
return elem
|
||||
|
@ -80,17 +76,14 @@ func (elem *QueueOutboundElement) clearPointers() {
|
|||
/* Queues a keepalive if no packets are queued for peer
|
||||
*/
|
||||
func (peer *Peer) SendKeepalive() {
|
||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||
if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
|
||||
elem := peer.device.NewOutboundElement()
|
||||
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||
select {
|
||||
case peer.queue.staged <- elemsContainer:
|
||||
case peer.queue.staged <- elem:
|
||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||
default:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
peer.SendStagedPackets()
|
||||
|
@ -98,7 +91,7 @@ func (peer *Peer) SendKeepalive() {
|
|||
|
||||
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
if !isRetry {
|
||||
peer.timers.handshakeAttempts.Store(0)
|
||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
||||
}
|
||||
|
||||
peer.handshake.mutex.RLock()
|
||||
|
@ -123,56 +116,17 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
var sendBuffer [][]byte
|
||||
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.version >= VersionAwg {
|
||||
var junks [][]byte
|
||||
if peer.device.version == VersionAwgSpecialHandshake {
|
||||
peer.device.awg.Mux.RLock()
|
||||
// set junks depending on packet type
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||
if junks != nil {
|
||||
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||
}
|
||||
peer.device.awg.Mux.RUnlock()
|
||||
} else {
|
||||
junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
|
||||
}
|
||||
peer.device.awg.Mux.RLock()
|
||||
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
|
||||
peer.device.awg.Mux.RUnlock()
|
||||
|
||||
if len(junks) > 0 {
|
||||
err = peer.SendBuffers(junks)
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
var buff [MessageInitiationSize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, msg)
|
||||
packet := writer.Bytes()
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
sendBuffer = append(sendBuffer, junkedHeader)
|
||||
|
||||
err = peer.SendAndCountBuffers(sendBuffer)
|
||||
err = peer.SendBuffer(packet)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
|
@ -194,19 +148,11 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
var buff [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, response)
|
||||
packet := writer.Bytes()
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
err = peer.BeginSymmetricSession()
|
||||
if err != nil {
|
||||
|
@ -218,57 +164,27 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
|
||||
err = peer.SendBuffer(packet)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) SendHandshakeCookie(
|
||||
initiatingElem *QueueHandshakeElement,
|
||||
) error {
|
||||
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
||||
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
||||
|
||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
||||
msgType := DefaultMessageCookieReplyType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
var err error
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageCookieReplyType)
|
||||
device.awg.Mux.RUnlock()
|
||||
if err != nil {
|
||||
device.log.Errorf("Get message type for cookie reply: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
reply, err := device.cookieChecker.CreateReply(
|
||||
initiatingElem.packet,
|
||||
sender,
|
||||
initiatingElem.endpoint.DstToBytes(),
|
||||
msgType,
|
||||
)
|
||||
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to create cookie reply: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
var buff [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
|
||||
junkedHeader = append(junkedHeader, writer.Bytes()...)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
|
||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -277,12 +193,17 @@ func (peer *Peer) keepKeyFreshSending() {
|
|||
if keypair == nil {
|
||||
return
|
||||
}
|
||||
nonce := keypair.sendNonce.Load()
|
||||
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
* into staged queue for peer
|
||||
*
|
||||
* Obs. Single instance per TUN device
|
||||
*/
|
||||
func (device *Device) RoutineReadFromTUN() {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: TUN reader - stopped")
|
||||
|
@ -292,53 +213,49 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
|
||||
device.log.Verbosef("Routine: TUN reader - started")
|
||||
|
||||
var (
|
||||
batchSize = device.BatchSize()
|
||||
readErr error
|
||||
elems = make([]*QueueOutboundElement, batchSize)
|
||||
bufs = make([][]byte, batchSize)
|
||||
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
||||
count = 0
|
||||
sizes = make([]int, batchSize)
|
||||
offset = MessageTransportHeaderSize
|
||||
)
|
||||
var elem *QueueOutboundElement
|
||||
|
||||
for i := range elems {
|
||||
elems[i] = device.NewOutboundElement()
|
||||
bufs[i] = elems[i].buffer[:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, elem := range elems {
|
||||
for {
|
||||
if elem != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
}
|
||||
}()
|
||||
elem = device.NewOutboundElement()
|
||||
|
||||
for {
|
||||
// read packets
|
||||
count, readErr = device.tun.device.Read(bufs, sizes, offset)
|
||||
for i := 0; i < count; i++ {
|
||||
if sizes[i] < 1 {
|
||||
// read packet
|
||||
|
||||
offset := MessageTransportHeaderSize
|
||||
size, err := device.tun.device.Read(elem.buffer[:], offset)
|
||||
if err != nil {
|
||||
if !device.isClosed() {
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", err)
|
||||
}
|
||||
go device.Close()
|
||||
}
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if size == 0 || size > MaxContentSize {
|
||||
continue
|
||||
}
|
||||
|
||||
elem := elems[i]
|
||||
elem.packet = bufs[i][offset : offset+sizes[i]]
|
||||
elem.packet = elem.buffer[offset : offset+size]
|
||||
|
||||
// lookup peer
|
||||
|
||||
var peer *Peer
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case 6:
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
@ -352,63 +269,25 @@ func (device *Device) RoutineReadFromTUN() {
|
|||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetOutboundElementsContainer()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||
elems[i] = device.NewOutboundElement()
|
||||
bufs[i] = elems[i].buffer[:]
|
||||
}
|
||||
|
||||
for peer, elemsForPeer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.StagePackets(elemsForPeer)
|
||||
if peer.isRunning.Get() {
|
||||
peer.StagePacket(elem)
|
||||
elem = nil
|
||||
peer.SendStagedPackets()
|
||||
} else {
|
||||
for _, elem := range elemsForPeer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsForPeer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, tun.ErrTooManySegments) {
|
||||
// TODO: record stat for this
|
||||
// This will happen if MSS is surprisingly small (< 576)
|
||||
// coincident with reasonably high throughput.
|
||||
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
||||
continue
|
||||
}
|
||||
if !device.isClosed() {
|
||||
if !errors.Is(readErr, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
||||
}
|
||||
go device.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||
func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
|
||||
for {
|
||||
select {
|
||||
case peer.queue.staged <- elems:
|
||||
case peer.queue.staged <- elem:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case tooOld := <-peer.queue.staged:
|
||||
for _, elem := range tooOld.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(tooOld)
|
||||
peer.device.PutMessageBuffer(tooOld.buffer)
|
||||
peer.device.PutOutboundElement(tooOld)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
@ -421,60 +300,33 @@ top:
|
|||
}
|
||||
|
||||
keypair := peer.keypairs.Current()
|
||||
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
||||
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
||||
peer.SendHandshakeInitiation(false)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
var elemsContainerOOO *QueueOutboundElementsContainer
|
||||
select {
|
||||
case elemsContainer := <-peer.queue.staged:
|
||||
i := 0
|
||||
for _, elem := range elemsContainer.elems {
|
||||
case elem := <-peer.queue.staged:
|
||||
elem.peer = peer
|
||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
|
||||
if elem.nonce >= RejectAfterMessages {
|
||||
keypair.sendNonce.Store(RejectAfterMessages)
|
||||
if elemsContainerOOO == nil {
|
||||
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
||||
}
|
||||
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
||||
continue
|
||||
} else {
|
||||
elemsContainer.elems[i] = elem
|
||||
i++
|
||||
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
|
||||
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
||||
goto top
|
||||
}
|
||||
|
||||
elem.keypair = keypair
|
||||
}
|
||||
elemsContainer.Lock()
|
||||
elemsContainer.elems = elemsContainer.elems[:i]
|
||||
|
||||
if elemsContainerOOO != nil {
|
||||
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
||||
}
|
||||
|
||||
if len(elemsContainer.elems) == 0 {
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
goto top
|
||||
}
|
||||
elem.Lock()
|
||||
|
||||
// add to parallel and sequential queue
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.outbound.c <- elemsContainer
|
||||
peer.device.queue.encryption.c <- elemsContainer
|
||||
if peer.isRunning.Get() {
|
||||
peer.queue.outbound.c <- elem
|
||||
peer.device.queue.encryption.c <- elem
|
||||
} else {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
}
|
||||
|
||||
if elemsContainerOOO != nil {
|
||||
goto top
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -484,12 +336,9 @@ top:
|
|||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
case elemsContainer := <-peer.queue.staged:
|
||||
for _, elem := range elemsContainer.elems {
|
||||
case elem := <-peer.queue.staged:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -523,8 +372,7 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
||||
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
||||
|
||||
for elemsContainer := range device.queue.encryption.c {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
for elem := range device.queue.encryption.c {
|
||||
// populate header fields
|
||||
header := elem.buffer[:MessageTransportHeaderSize]
|
||||
|
||||
|
@ -532,25 +380,12 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
fieldReceiver := header[4:8]
|
||||
fieldNonce := header[8:16]
|
||||
|
||||
msgType := DefaultMessageTransportType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
var err error
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageTransportType)
|
||||
device.awg.Mux.RUnlock()
|
||||
if err != nil {
|
||||
device.log.Errorf("get message type for transport: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint32(fieldType, msgType)
|
||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||
|
||||
// pad content to multiple of 16
|
||||
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
|
||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||
|
||||
// encrypt content and release to consumer
|
||||
|
@ -562,12 +397,16 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
elem.packet,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
elemsContainer.Unlock()
|
||||
elem.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||
/* Sequentially reads packets from queue and sends to endpoint
|
||||
*
|
||||
* Obs. Single instance per peer.
|
||||
* The routine terminates then the outbound queue is closed.
|
||||
*/
|
||||
func (peer *Peer) RoutineSequentialSender() {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
||||
|
@ -575,67 +414,36 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
||||
|
||||
bufs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elemsContainer := range peer.queue.outbound.c {
|
||||
bufs = bufs[:0]
|
||||
if elemsContainer == nil {
|
||||
for elem := range peer.queue.outbound.c {
|
||||
if elem == nil {
|
||||
return
|
||||
}
|
||||
if !peer.isRunning.Load() {
|
||||
elem.Lock()
|
||||
if !peer.isRunning.Get() {
|
||||
// peer has been stopped; return re-usable elems to the shared pool.
|
||||
// This is an optimization only. It is possible for the peer to be stopped
|
||||
// immediately after this check, in which case, elem will get processed.
|
||||
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||
// The timers and SendBuffer code are resilient to a few stragglers.
|
||||
// TODO: rework peer shutdown order to ensure
|
||||
// that we never accidentally keep timers alive longer than necessary.
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
continue
|
||||
}
|
||||
dataSent := false
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
|
||||
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = append(junkedHeader, elem.packet...)
|
||||
}
|
||||
bufs = append(bufs, elem.packet)
|
||||
}
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err := peer.SendAndCountBuffers(bufs)
|
||||
if dataSent {
|
||||
// send message and return buffer to pool
|
||||
|
||||
err := peer.SendBuffer(elem.packet)
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
if err != nil {
|
||||
var errGSO conn.ErrUDPGSODisabled
|
||||
if errors.As(err, &errGSO) {
|
||||
device.log.Verbosef(err.Error())
|
||||
err = errGSO.RetryErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||
device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
|
@ -9,7 +9,7 @@
|
|||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code remains platform dependent.
|
||||
* So this code is remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -20,15 +20,12 @@ import (
|
|||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
if !conn.StdNetSupportsStickySockets {
|
||||
return nil, nil
|
||||
}
|
||||
if _, ok := bind.(*conn.StdNetBind); !ok {
|
||||
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -47,7 +44,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
|
|||
return netlinkCancel, nil
|
||||
}
|
||||
|
||||
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *conn.Endpoint
|
||||
|
@ -110,17 +107,17 @@ func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlink
|
|||
if !ok {
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.Lock()
|
||||
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
pePtr.peer.Lock()
|
||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
|
@ -134,18 +131,18 @@ func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlink
|
|||
device.peers.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val == nil {
|
||||
peer.endpoint.Unlock()
|
||||
peer.RLock()
|
||||
if peer.endpoint == nil {
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
|
||||
if nativeEP == nil {
|
||||
peer.endpoint.Unlock()
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||
peer.endpoint.Unlock()
|
||||
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
||||
peer.RUnlock()
|
||||
break
|
||||
}
|
||||
nlmsg := struct {
|
||||
|
@ -172,12 +169,12 @@ func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlink
|
|||
Len: 8,
|
||||
Type: unix.RTA_DST,
|
||||
},
|
||||
nativeEP.DstIP().As4(),
|
||||
nativeEP.Dst4().Addr,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_SRC,
|
||||
},
|
||||
nativeEP.SrcIP().As4(),
|
||||
nativeEP.Src4().Src,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_MARK,
|
||||
|
@ -188,10 +185,10 @@ func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlink
|
|||
reqPeerLock.Lock()
|
||||
reqPeer[i] = peerEndpointPtr{
|
||||
peer: peer,
|
||||
endpoint: &peer.endpoint.val,
|
||||
endpoint: &peer.endpoint,
|
||||
}
|
||||
reqPeerLock.Unlock()
|
||||
peer.endpoint.Unlock()
|
||||
peer.RUnlock()
|
||||
i++
|
||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
if err != nil {
|
||||
|
@ -207,7 +204,7 @@ func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlink
|
|||
}
|
||||
|
||||
func createNetlinkRouteSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
|
@ -9,6 +9,7 @@ package device
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
@ -73,11 +74,11 @@ func (timer *Timer) IsPending() bool {
|
|||
}
|
||||
|
||||
func (peer *Peer) timersActive() bool {
|
||||
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
|
||||
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
|
||||
}
|
||||
|
||||
func expiredRetransmitHandshake(peer *Peer) {
|
||||
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
|
||||
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
||||
|
||||
if peer.timersActive() {
|
||||
|
@ -96,11 +97,15 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
||||
}
|
||||
} else {
|
||||
peer.timers.handshakeAttempts.Add(1)
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
|
||||
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
|
||||
peer.SendHandshakeInitiation(true)
|
||||
}
|
||||
|
@ -108,8 +113,8 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||
|
||||
func expiredSendKeepalive(peer *Peer) {
|
||||
peer.SendKeepalive()
|
||||
if peer.timers.needAnotherKeepalive.Load() {
|
||||
peer.timers.needAnotherKeepalive.Store(false)
|
||||
if peer.timers.needAnotherKeepalive.Get() {
|
||||
peer.timers.needAnotherKeepalive.Set(false)
|
||||
if peer.timersActive() {
|
||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||
}
|
||||
|
@ -119,7 +124,11 @@ func expiredSendKeepalive(peer *Peer) {
|
|||
func expiredNewHandshake(peer *Peer) {
|
||||
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
|
||||
|
@ -129,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
|
|||
}
|
||||
|
||||
func expiredPersistentKeepalive(peer *Peer) {
|
||||
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +156,7 @@ func (peer *Peer) timersDataReceived() {
|
|||
if !peer.timers.sendKeepalive.IsPending() {
|
||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||
} else {
|
||||
peer.timers.needAnotherKeepalive.Store(true)
|
||||
peer.timers.needAnotherKeepalive.Set(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -178,9 +187,9 @@ func (peer *Peer) timersHandshakeComplete() {
|
|||
if peer.timersActive() {
|
||||
peer.timers.retransmitHandshake.Del()
|
||||
}
|
||||
peer.timers.handshakeAttempts.Store(0)
|
||||
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||
peer.lastHandshakeNano.Store(time.Now().UnixNano())
|
||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
||||
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
||||
|
@ -192,7 +201,7 @@ func (peer *Peer) timersSessionDerived() {
|
|||
|
||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||
keepalive := peer.persistentKeepaliveInterval.Load()
|
||||
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
|
||||
if keepalive > 0 && peer.timersActive() {
|
||||
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||
}
|
||||
|
@ -207,9 +216,9 @@ func (peer *Peer) timersInit() {
|
|||
}
|
||||
|
||||
func (peer *Peer) timersStart() {
|
||||
peer.timers.handshakeAttempts.Store(0)
|
||||
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||
peer.timers.needAnotherKeepalive.Store(false)
|
||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
||||
peer.timers.needAnotherKeepalive.Set(false)
|
||||
}
|
||||
|
||||
func (peer *Peer) timersStop() {
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
const DefaultMTU = 1420
|
||||
|
@ -32,7 +33,7 @@ func (device *Device) RoutineTUNEventReader() {
|
|||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
||||
mtu = MaxContentSize
|
||||
}
|
||||
old := device.tun.mtu.Swap(int32(mtu))
|
||||
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
|
||||
if int(old) != mtu {
|
||||
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
||||
}
|
||||
|
|
258
device/uapi.go
258
device/uapi.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -16,10 +16,10 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
type IPCError struct {
|
||||
|
@ -98,72 +98,35 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.isAWG() {
|
||||
if device.awg.Cfg.JunkPacketCount != 0 {
|
||||
sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
|
||||
}
|
||||
if device.awg.Cfg.JunkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
|
||||
}
|
||||
if device.awg.Cfg.JunkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
|
||||
}
|
||||
if device.awg.Cfg.InitHeaderJunkSize != 0 {
|
||||
sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
|
||||
sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
|
||||
sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
|
||||
}
|
||||
for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min > 4 {
|
||||
if magicHeader.Min == magicHeader.Max {
|
||||
sendf("h%d=%d", i+1, magicHeader.Min)
|
||||
continue
|
||||
}
|
||||
|
||||
sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
|
||||
}
|
||||
}
|
||||
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
}
|
||||
|
||||
for _, peer := range device.peers.keyMap {
|
||||
// Serialize peer state.
|
||||
peer.handshake.mutex.RLock()
|
||||
// Do the work in an anonymous function so that we can use defer.
|
||||
func() {
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
|
||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||
peer.handshake.mutex.RUnlock()
|
||||
sendf("protocol_version=1")
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||
if peer.endpoint != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
nano := peer.lastHandshakeNano.Load()
|
||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
||||
secs := nano / time.Second.Nanoseconds()
|
||||
nano %= time.Second.Nanoseconds()
|
||||
|
||||
sendf("last_handshake_time_sec=%d", secs)
|
||||
sendf("last_handshake_time_nsec=%d", nano)
|
||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
|
||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||
|
||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||
sendf("allowed_ip=%s", prefix.String())
|
||||
return true
|
||||
})
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -190,28 +153,17 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
tempAwg := awg.Protocol{}
|
||||
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
err := device.handlePostConfig(&tempAwg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.handlePostConfig()
|
||||
return nil
|
||||
}
|
||||
key, value, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorProtocol,
|
||||
"failed to parse line %q",
|
||||
line,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
|
||||
}
|
||||
|
||||
if key == "public_key" {
|
||||
|
@ -229,7 +181,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value, &tempAwg)
|
||||
err = device.handleDeviceLine(key, value)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
|
@ -237,10 +189,6 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
err = device.handlePostConfig(&tempAwg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.handlePostConfig()
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
|
@ -249,7 +197,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
|
||||
func (device *Device) handleDeviceLine(key, value string) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
|
@ -290,122 +238,11 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
|||
|
||||
case "replace_peers":
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set replace_peers, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
||||
case "jc":
|
||||
junkPacketCount, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||
tempAwg.Cfg.JunkPacketCount = junkPacketCount
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "jmin":
|
||||
junkPacketMinSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||
tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "jmax":
|
||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||
tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "s1":
|
||||
initPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||
tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "s2":
|
||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||
tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "s3":
|
||||
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
|
||||
tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
|
||||
case "s4":
|
||||
transportPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
|
||||
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h1":
|
||||
initMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h2":
|
||||
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h3":
|
||||
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h4":
|
||||
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
}
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "i1", "i2", "i3", "i4", "i5":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
generators, err := awg.ParseTagJunkGenerator(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
}
|
||||
|
@ -426,7 +263,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
|||
return
|
||||
}
|
||||
if peer.created {
|
||||
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
|
||||
}
|
||||
if peer.device.isUp() {
|
||||
peer.Start()
|
||||
|
@ -437,10 +274,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
func (device *Device) handlePublicKeyLine(
|
||||
peer *ipcSetPeer,
|
||||
value string,
|
||||
) error {
|
||||
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
|
||||
// Load/create the peer we are configuring.
|
||||
var publicKey NoisePublicKey
|
||||
err := publicKey.FromHex(value)
|
||||
|
@ -470,19 +304,12 @@ func (device *Device) handlePublicKeyLine(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handlePeerLine(
|
||||
peer *ipcSetPeer,
|
||||
key, value string,
|
||||
) error {
|
||||
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
|
||||
switch key {
|
||||
case "update_only":
|
||||
// allow disabling of creation
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set update only, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||
}
|
||||
if peer.created && !peer.dummy {
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
|
@ -519,23 +346,19 @@ func (device *Device) handlePeerLine(
|
|||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||
}
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
peer.endpoint.val = endpoint
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
peer.endpoint = endpoint
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||
|
||||
secs, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set persistent keepalive interval: %w",
|
||||
err,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||
}
|
||||
|
||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
||||
|
||||
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
||||
peer.pkaOn = old == 0 && secs != 0
|
||||
|
@ -543,11 +366,7 @@ func (device *Device) handlePeerLine(
|
|||
case "replace_allowed_ips":
|
||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||
if value != "true" {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to replace allowedips, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||
}
|
||||
if peer.dummy {
|
||||
return nil
|
||||
|
@ -555,14 +374,7 @@ func (device *Device) handlePeerLine(
|
|||
device.allowedips.RemoveByPeer(peer.Peer)
|
||||
|
||||
case "allowed_ip":
|
||||
add := true
|
||||
verb := "Adding"
|
||||
if len(value) > 0 && value[0] == '-' {
|
||||
add = false
|
||||
verb = "Removing"
|
||||
value = value[1:]
|
||||
}
|
||||
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
|
||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
|
@ -570,11 +382,7 @@ func (device *Device) handlePeerLine(
|
|||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
if add {
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
} else {
|
||||
device.allowedips.Remove(prefix, peer.Peer)
|
||||
}
|
||||
|
||||
case "protocol_version":
|
||||
if value != "1" {
|
||||
|
@ -626,11 +434,7 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||
return
|
||||
}
|
||||
if nextByte != '\n' {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"trailing character in UAPI get: %q",
|
||||
nextByte,
|
||||
)
|
||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
||||
break
|
||||
}
|
||||
err = device.IpcGetOperation(buffered.Writer)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
package main
|
||||
|
||||
|
|
25
go.mod
25
go.mod
|
@ -1,23 +1,10 @@
|
|||
module github.com/amnezia-vpn/amneziawg-go
|
||||
module golang.zx2c4.com/wireguard
|
||||
|
||||
go 1.24.4
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tevino/abool v1.2.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/net v0.41.0
|
||||
golang.org/x/sys v0.33.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/time v0.9.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f
|
||||
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
|
||||
)
|
||||
|
|
38
go.sum
38
go.sum
|
@ -1,30 +1,8 @@
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
|
||||
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc=
|
||||
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package namedpipe
|
||||
|
||||
|
@ -53,7 +54,7 @@ type file struct {
|
|||
handle windows.Handle
|
||||
wg sync.WaitGroup
|
||||
wgLock sync.RWMutex
|
||||
closing atomic.Bool
|
||||
closing uint32 // used as atomic boolean
|
||||
socket bool
|
||||
readDeadline deadlineHandler
|
||||
writeDeadline deadlineHandler
|
||||
|
@ -64,7 +65,7 @@ type deadlineHandler struct {
|
|||
channel timeoutChan
|
||||
channelLock sync.RWMutex
|
||||
timer *time.Timer
|
||||
timedout atomic.Bool
|
||||
timedout uint32 // used as atomic boolean
|
||||
}
|
||||
|
||||
// makeFile makes a new file from an existing file handle
|
||||
|
@ -88,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) {
|
|||
func (f *file) closeHandle() {
|
||||
f.wgLock.Lock()
|
||||
// Atomically set that we are closing, releasing the resources only once.
|
||||
if f.closing.Swap(true) == false {
|
||||
if atomic.SwapUint32(&f.closing, 1) == 0 {
|
||||
f.wgLock.Unlock()
|
||||
// cancel all IO and wait for it to complete
|
||||
windows.CancelIoEx(f.handle, nil)
|
||||
|
@ -111,7 +112,7 @@ func (f *file) Close() error {
|
|||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||
func (f *file) prepareIo() (*ioOperation, error) {
|
||||
f.wgLock.RLock()
|
||||
if f.closing.Load() {
|
||||
if atomic.LoadUint32(&f.closing) == 1 {
|
||||
f.wgLock.RUnlock()
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
|
@ -143,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
|||
return int(bytes), err
|
||||
}
|
||||
|
||||
if f.closing.Load() {
|
||||
if atomic.LoadUint32(&f.closing) == 1 {
|
||||
windows.CancelIoEx(f.handle, &c.o)
|
||||
}
|
||||
|
||||
|
@ -159,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
|||
case r = <-c.ch:
|
||||
err = r.err
|
||||
if err == windows.ERROR_OPERATION_ABORTED {
|
||||
if f.closing.Load() {
|
||||
if atomic.LoadUint32(&f.closing) == 1 {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
} else if err != nil && f.socket {
|
||||
|
@ -191,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) {
|
|||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.readDeadline.timedout.Load() {
|
||||
if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
|
@ -218,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) {
|
|||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.writeDeadline.timedout.Load() {
|
||||
if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
|
@ -255,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
|||
}
|
||||
d.timer = nil
|
||||
}
|
||||
d.timedout.Store(false)
|
||||
atomic.StoreUint32(&d.timedout, 0)
|
||||
|
||||
select {
|
||||
case <-d.channel:
|
||||
|
@ -270,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
|||
}
|
||||
|
||||
timeoutIO := func() {
|
||||
d.timedout.Store(true)
|
||||
atomic.StoreUint32(&d.timedout, 1)
|
||||
close(d.channel)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
||||
package namedpipe
|
||||
|
@ -28,7 +29,7 @@ type pipe struct {
|
|||
|
||||
type messageBytePipe struct {
|
||||
pipe
|
||||
writeClosed atomic.Bool
|
||||
writeClosed int32
|
||||
readEOF bool
|
||||
}
|
||||
|
||||
|
@ -50,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error {
|
|||
|
||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||
func (f *messageBytePipe) CloseWrite() error {
|
||||
if !f.writeClosed.CompareAndSwap(false, true) {
|
||||
if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
err := f.file.Flush()
|
||||
if err != nil {
|
||||
f.writeClosed.Store(false)
|
||||
atomic.StoreInt32(&f.writeClosed, 0)
|
||||
return err
|
||||
}
|
||||
_, err = f.file.Write(nil)
|
||||
if err != nil {
|
||||
f.writeClosed.Store(false)
|
||||
atomic.StoreInt32(&f.writeClosed, 0)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -69,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error {
|
|||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||
// they are used to implement CloseWrite.
|
||||
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
||||
if f.writeClosed.Load() {
|
||||
if atomic.LoadInt32(&f.writeClosed) != 0 {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if len(b) == 0 {
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package namedpipe_test
|
||||
|
||||
|
@ -20,8 +21,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/ipc/namedpipe"
|
||||
)
|
||||
|
||||
func randomPipePath() string {
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
// Made up sentinel error codes for {js,wasip1}/wasm.
|
||||
// Made up sentinel error codes for the js/wasm platform.
|
||||
const (
|
||||
IpcErrorIO = 1
|
||||
IpcErrorInvalid = 2
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -9,8 +9,8 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
type UAPIListener struct {
|
||||
|
@ -96,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||
}
|
||||
|
||||
go func(l *UAPIListener) {
|
||||
var buf [0]byte
|
||||
var buff [0]byte
|
||||
for {
|
||||
defer uapi.inotifyRWCancel.Close()
|
||||
// start with lstat to avoid race condition
|
||||
|
@ -104,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||
l.connErr <- err
|
||||
return
|
||||
}
|
||||
_, err := uapi.inotifyRWCancel.Read(buf[:])
|
||||
_, err := uapi.inotifyRWCancel.Read(buff[:])
|
||||
if err != nil {
|
||||
l.connErr <- err
|
||||
return
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -26,7 +26,7 @@ const (
|
|||
|
||||
// socketDirectory is variable because it is modified by a linker
|
||||
// flag in wireguard-android.
|
||||
var socketDirectory = "/var/run/amneziawg"
|
||||
var socketDirectory = "/var/run/wireguard"
|
||||
|
||||
func sockPath(iface string) string {
|
||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -8,8 +8,8 @@ package ipc
|
|||
import (
|
||||
"net"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/ipc/namedpipe"
|
||||
)
|
||||
|
||||
// TODO: replace these with actual standard windows error numbers from the win package
|
||||
|
@ -62,7 +62,7 @@ func init() {
|
|||
func UAPIListen(name string) (net.Listener, error) {
|
||||
listener, err := (&namedpipe.ListenConfig{
|
||||
SecurityDescriptor: UAPISecurityDescriptor,
|
||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
|
||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
38
main.go
38
main.go
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -13,12 +13,12 @@ import (
|
|||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -46,20 +46,20 @@ func warning() {
|
|||
return
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
|
||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
|
||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
|
||||
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
|
||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
|
||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
||||
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
|
||||
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
|
||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
||||
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -111,7 +111,7 @@ func main() {
|
|||
|
||||
// open TUN device (or use supplied fd)
|
||||
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tun, err := func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||
|
@ -124,7 +124,7 @@ func main() {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
err = syscall.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func main() {
|
|||
}()
|
||||
|
||||
if err == nil {
|
||||
realInterfaceName, err2 := tdev.Name()
|
||||
realInterfaceName, err2 := tun.Name()
|
||||
if err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func main() {
|
|||
fmt.Sprintf("(%s) ", interfaceName),
|
||||
)
|
||||
|
||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create TUN device: %v", err)
|
||||
|
@ -196,7 +196,7 @@ func main() {
|
|||
files[0], // stdin
|
||||
files[1], // stdout
|
||||
files[2], // stderr
|
||||
tdev.File(),
|
||||
tun.File(),
|
||||
fileUAPI,
|
||||
},
|
||||
Dir: ".",
|
||||
|
@ -222,7 +222,7 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
||||
|
||||
logger.Verbosef("Device started")
|
||||
|
||||
|
@ -250,7 +250,7 @@ func main() {
|
|||
|
||||
// wait for program to terminate
|
||||
|
||||
signal.Notify(term, unix.SIGTERM)
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
signal.Notify(term, os.Interrupt)
|
||||
|
||||
select {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -9,14 +9,13 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,13 +29,13 @@ func main() {
|
|||
}
|
||||
interfaceName := os.Args[1]
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
|
||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
||||
|
||||
logger := device.NewLogger(
|
||||
device.LogLevelVerbose,
|
||||
fmt.Sprintf("(%s) ", interfaceName),
|
||||
)
|
||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||
|
||||
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||
if err == nil {
|
||||
|
@ -82,7 +81,7 @@ func main() {
|
|||
|
||||
signal.Notify(term, os.Interrupt)
|
||||
signal.Notify(term, os.Kill)
|
||||
signal.Notify(term, windows.SIGTERM)
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-term:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue