mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-08 13:03:45 +02:00
Compare commits
63 commits
Author | SHA1 | Date | |
---|---|---|---|
|
1abd24b5b9 | ||
|
3f19f1c657 | ||
|
c207898480 | ||
|
fe75b639fa | ||
|
169ed49a46 | ||
|
eeb8aae13e | ||
|
99f2e6d66f | ||
|
d5359f52f0 | ||
|
6768090667 | ||
|
2cad62c40b | ||
|
8051f17147 | ||
|
ace3e11ef2 | ||
|
8a2b2bf4f4 | ||
|
75d6c67a67 | ||
|
ac8a885a03 | ||
|
6a7c878409 | ||
|
704d57c27a | ||
|
c0b6e6a200 | ||
|
c803ce1e5b | ||
|
deedce495a | ||
|
27e661d68e | ||
|
71be0eb3a6 | ||
|
e3f1273f8a | ||
|
c97b5b7615 | ||
|
668ddfd455 | ||
|
b8da08c106 | ||
|
2e3f7d122c | ||
|
2e7780471a | ||
|
87d8c00f86 | ||
|
c00bda9200 | ||
|
d2b0fc9789 | ||
|
77d39ff3b9 | ||
|
e433d13df6 | ||
|
3ddf952973 | ||
|
3f0a3bcfa0 | ||
|
4dddf62e57 | ||
|
827ec6e14b | ||
|
92e28a0d14 | ||
|
52fed4d362 | ||
|
9c6b3ff332 | ||
|
7de7a9a754 | ||
|
0c347529b8 | ||
|
6705978fc8 | ||
|
032e33f577 | ||
|
59101fd202 | ||
|
8bcfbac230 | ||
|
f0dfb5eacc | ||
|
9195025d8f | ||
|
cbd414dfec | ||
|
7155d20913 | ||
|
bfeb3954f6 | ||
|
e3c9ec8012 | ||
|
ce9d3866a3 | ||
|
e5f355e843 | ||
|
c05b2ee2a3 | ||
|
180c9284f3 | ||
|
015e11875d | ||
|
12269c2761 | ||
|
542e565baa | ||
|
7c20311b3d | ||
|
4ffa9c2032 | ||
|
d0bc03c707 | ||
|
1cf89f5339 |
119 changed files with 4060 additions and 1503 deletions
41
.github/workflows/build-if-tag.yml
vendored
Normal file
41
.github/workflows/build-if-tag.yml
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
name: build-if-tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||
|
||||
env:
|
||||
APP: amneziawg-go
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
name: build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Setup metadata
|
||||
uses: docker/metadata-action@v5
|
||||
id: metadata
|
||||
with:
|
||||
images: amneziavpn/${{ env.APP }}
|
||||
tags: type=semver,pattern={{version}}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
push: true
|
||||
tags: ${{ steps.metadata.outputs.tags }}
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1 @@
|
|||
wireguard-go
|
||||
amneziawg-go
|
18
Dockerfile
Normal file
18
Dockerfile
Normal file
|
@ -0,0 +1,18 @@
|
|||
FROM golang:1.24.4 as awg
|
||||
COPY . /awg
|
||||
WORKDIR /awg
|
||||
RUN go mod download && \
|
||||
go mod verify && \
|
||||
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG AWGTOOLS_RELEASE="1.0.20241018"
|
||||
|
||||
RUN apk --no-cache add iproute2 iptables bash && \
|
||||
cd /usr/bin/ && \
|
||||
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
||||
unzip -j alpine-3.19-amneziawg-tools.zip && \
|
||||
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
|
||||
ln -s /usr/bin/awg /usr/bin/wg && \
|
||||
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
|
||||
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go
|
12
Makefile
12
Makefile
|
@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
|
|||
|
||||
generate-version-and-build:
|
||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
||||
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||
echo "$$ver" > version.go && \
|
||||
git update-index --assume-unchanged version.go || true
|
||||
@$(MAKE) wireguard-go
|
||||
@$(MAKE) amneziawg-go
|
||||
|
||||
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
go build -v -o "$@"
|
||||
|
||||
install: wireguard-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||
install: amneziawg-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
clean:
|
||||
rm -f wireguard-go
|
||||
rm -f amneziawg-go
|
||||
|
||||
.PHONY: all clean test install generate-version-and-build
|
||||
|
|
19
README.md
19
README.md
|
@ -11,17 +11,17 @@ As a result, AmneziaWG maintains high performance while adding an extra layer of
|
|||
Simply run:
|
||||
|
||||
```
|
||||
$ amnezia-wg wg0
|
||||
$ amneziawg-go wg0
|
||||
```
|
||||
|
||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
|
||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
|
||||
|
||||
To run amnezia-wg without forking to the background, pass `-f` or `--foreground`:
|
||||
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
|
||||
|
||||
```
|
||||
$ amnezia-wg -f wg0
|
||||
$ amneziawg-go -f wg0
|
||||
```
|
||||
When an interface is running, you may use [`amnezia-wg-tools `](https://github.com/amnezia-vpn/amnezia-wg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||
|
||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||
|
||||
|
@ -34,11 +34,11 @@ This will run on Linux; you should run amnezia-wg instead of using default linux
|
|||
### macOS
|
||||
|
||||
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
||||
This runs on MacOS, you should use it from [awg-apple](https://github.com/amnezia-vpn/awg-apple)
|
||||
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
|
||||
|
||||
### Windows
|
||||
|
||||
This runs on Windows, you should use it from [awg-windows](https://github.com/amnezia-vpn/awg-windows), which uses this as a module.
|
||||
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
|
||||
|
||||
|
||||
## Building
|
||||
|
@ -46,7 +46,8 @@ This runs on Windows, you should use it from [awg-windows](https://github.com/am
|
|||
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||
|
||||
```
|
||||
$ git clone https://github.com/amnezia-vpn/amnezia-wg
|
||||
$ cd amnezia-wg
|
||||
$ git clone https://github.com/amnezia-vpn/amneziawg-go
|
||||
$ cd amneziawg-go
|
||||
$ make
|
||||
```
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -331,7 +331,7 @@ type ErrUDPGSODisabled struct {
|
|||
}
|
||||
|
||||
func (e ErrUDPGSODisabled) Error() string {
|
||||
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
|
||||
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
|
||||
}
|
||||
|
||||
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn/winrio"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bindtest
|
||||
|
@ -12,7 +12,7 @@ import (
|
|||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
)
|
||||
|
||||
type ChannelBind struct {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package conn implements WireGuard's network connections.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -13,6 +13,35 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
|
||||
func kernelVersion() (major, minor int) {
|
||||
var uname unix.Utsname
|
||||
if err := unix.Uname(&uname); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
values [2]int
|
||||
value, vi int
|
||||
)
|
||||
for _, c := range uname.Release {
|
||||
if '0' <= c && c <= '9' {
|
||||
value = (value * 10) + int(c-'0')
|
||||
} else {
|
||||
// Note that we're assuming N.N.N here.
|
||||
// If we see anything else, we are likely to mis-parse it.
|
||||
values[vi] = value
|
||||
vi++
|
||||
if vi >= len(values) {
|
||||
break
|
||||
}
|
||||
value = 0
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], values[1]
|
||||
}
|
||||
|
||||
func init() {
|
||||
controlFns = append(controlFns,
|
||||
|
||||
|
@ -60,6 +89,17 @@ func init() {
|
|||
|
||||
// Attempt to enable UDP_GRO
|
||||
func(network, address string, c syscall.RawConn) error {
|
||||
// Kernels below 5.12 are missing 98184612aca0 ("net:
|
||||
// udp: Add support for getsockopt(..., ..., UDP_GRO,
|
||||
// ..., ...);"), which means we can't read this back
|
||||
// later. We could pipe the return value through to
|
||||
// the rest of the code, but UDP_GRO is kind of buggy
|
||||
// anyway, so just gate this here.
|
||||
major, minor := kernelVersion()
|
||||
if major < 5 || (major == 5 && minor < 12) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Control(func(fd uintptr) {
|
||||
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
})
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func errShouldDisableUDPGSO(err error) bool {
|
||||
func errShouldDisableUDPGSO(_ error) bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -20,7 +20,9 @@ func errShouldDisableUDPGSO(err error) bool {
|
|||
// See:
|
||||
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||
return serr.Err == unix.EIO
|
||||
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
|
||||
// It occurs when the peer mtu + wg headers is greater than path mtu.
|
||||
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -3,13 +3,13 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
@ -19,8 +19,10 @@ func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
|||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
txOffload = errSyscall == nil
|
||||
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
||||
rxOffload = errSyscall == nil && opt == 1
|
||||
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
|
||||
// use setsockopt workaround
|
||||
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
rxOffload = errSyscall == nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winrio
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -223,19 +223,11 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
|
|||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
node := elem.Value.(*trieEntry)
|
||||
|
||||
func (node *trieEntry) remove() {
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
continue
|
||||
return
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
|
@ -248,12 +240,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
node.zeroizePointers()
|
||||
continue
|
||||
return
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
node.zeroizePointers()
|
||||
continue
|
||||
return
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
|
@ -262,6 +254,37 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|||
*parent.parent.parentBit = child
|
||||
node.zeroizePointers()
|
||||
parent.zeroizePointers()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
var node *trieEntry
|
||||
var exact bool
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else {
|
||||
panic(errors.New("removing unknown address type"))
|
||||
}
|
||||
if !exact || node == nil || peer != node.peer {
|
||||
return
|
||||
}
|
||||
node.remove()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
elem.Value.(*trieEntry).remove()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
var peers []*Peer
|
||||
var allowedIPs AllowedIPs
|
||||
|
||||
rand.Seed(1)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
|
||||
for n := 0; n < NumberOfPeers; n++ {
|
||||
peers = append(peers, &Peer{})
|
||||
|
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
|
|||
|
||||
for n := 0; n < NumberOfAddresses; n++ {
|
||||
var addr4 [4]byte
|
||||
rand.Read(addr4[:])
|
||||
rng.Read(addr4[:])
|
||||
cidr := uint8(rand.Intn(32) + 1)
|
||||
index := rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||
|
||||
var addr6 [16]byte
|
||||
rand.Read(addr6[:])
|
||||
rng.Read(addr6[:])
|
||||
cidr = uint8(rand.Intn(128) + 1)
|
||||
index = rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||
|
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
for p = 0; ; p++ {
|
||||
for n := 0; n < NumberOfTests; n++ {
|
||||
var addr4 [4]byte
|
||||
rand.Read(addr4[:])
|
||||
rng.Read(addr4[:])
|
||||
peer1 := slow4.Lookup(addr4[:])
|
||||
peer2 := allowedIPs.Lookup(addr4[:])
|
||||
if peer1 != peer2 {
|
||||
|
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
|
|||
}
|
||||
|
||||
var addr6 [16]byte
|
||||
rand.Read(addr6[:])
|
||||
rng.Read(addr6[:])
|
||||
peer1 = slow6.Lookup(addr6[:])
|
||||
peer2 = allowedIPs.Lookup(addr6[:])
|
||||
if peer1 != peer2 {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
||||
var trie *trieEntry
|
||||
var peers []*Peer
|
||||
root := parentIndirection{&trie, 2}
|
||||
|
||||
rand.Seed(1)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
|
||||
const AddressLength = 4
|
||||
|
||||
|
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
|||
|
||||
for n := 0; n < addressNumber; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rand.Read(addr[:])
|
||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||
index := rand.Int() % peerNumber
|
||||
rng.Read(addr[:])
|
||||
cidr := uint8(rng.Uint32() % (AddressLength * 8))
|
||||
index := rng.Int() % peerNumber
|
||||
root.insert(addr[:], cidr, peers[index])
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rand.Read(addr[:])
|
||||
rng.Read(addr[:])
|
||||
trie.lookup(addr[:])
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||
if p != peer {
|
||||
|
@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
|
|||
allowedIPs.RemoveByPeer(a)
|
||||
|
||||
assertNEQ(a, 192, 168, 0, 1)
|
||||
|
||||
insert(a, 1, 0, 0, 0, 32)
|
||||
insert(a, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 1, 0, 0, 0)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 32)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(nil, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(b, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 24)
|
||||
assertNEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 1, 0, 0, 0, 32)
|
||||
assertNEQ(a, 1, 0, 0, 0)
|
||||
}
|
||||
|
||||
/* Test ported from kernel implementation:
|
||||
|
@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
|
|||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
|
@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
p := allowedIPs.Lookup(addr)
|
||||
if p == peer {
|
||||
t.Error("Assert NEQ failed")
|
||||
}
|
||||
}
|
||||
|
||||
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||
insert(e, 0, 0, 0, 0, 0)
|
||||
|
@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
|
|||
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||
|
||||
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
}
|
||||
|
|
144
device/awg/awg.go
Normal file
144
device/awg/awg.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
type aSecCfgType struct {
|
||||
IsSet bool
|
||||
JunkPacketCount int
|
||||
JunkPacketMinSize int
|
||||
JunkPacketMaxSize int
|
||||
InitHeaderJunkSize int
|
||||
ResponseHeaderJunkSize int
|
||||
CookieReplyHeaderJunkSize int
|
||||
TransportHeaderJunkSize int
|
||||
InitPacketMagicHeader uint32
|
||||
ResponsePacketMagicHeader uint32
|
||||
UnderloadPacketMagicHeader uint32
|
||||
TransportPacketMagicHeader uint32
|
||||
// InitPacketMagicHeader Limit
|
||||
// ResponsePacketMagicHeader Limit
|
||||
// UnderloadPacketMagicHeader Limit
|
||||
// TransportPacketMagicHeader Limit
|
||||
}
|
||||
|
||||
type Limit struct {
|
||||
Min uint32
|
||||
Max uint32
|
||||
HeaderType uint32
|
||||
}
|
||||
|
||||
func NewLimit(min, max, headerType uint32) (Limit, error) {
|
||||
if min > max {
|
||||
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
|
||||
}
|
||||
|
||||
return Limit{
|
||||
Min: min,
|
||||
Max: max,
|
||||
HeaderType: headerType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
|
||||
// tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType)
|
||||
// var min, max, headerType uint32
|
||||
// _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType)
|
||||
// if err != nil {
|
||||
// return Limit{}, fmt.Errorf("invalid magic header format: %s", value)
|
||||
// }
|
||||
|
||||
limits := strings.Split(value, "-")
|
||||
if len(limits) != 2 {
|
||||
return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
|
||||
}
|
||||
|
||||
min, err := strconv.ParseUint(limits[0], 10, 32)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
max, err := strconv.ParseUint(limits[1], 10, 32)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
|
||||
if err != nil {
|
||||
return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err)
|
||||
}
|
||||
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
type Limits []Limit
|
||||
|
||||
func NewLimits(limits []Limit) Limits {
|
||||
slices.SortFunc(limits, func(a, b Limit) int {
|
||||
if a.Min < b.Min {
|
||||
return -1
|
||||
} else if a.Min > b.Min {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
return Limits(limits)
|
||||
}
|
||||
|
||||
type Protocol struct {
|
||||
IsASecOn abool.AtomicBool
|
||||
// TODO: revision the need of the mutex
|
||||
ASecMux sync.RWMutex
|
||||
ASecCfg aSecCfgType
|
||||
JunkCreator junkCreator
|
||||
|
||||
HandshakeHandler SpecialHandshakeHandler
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
|
||||
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) {
|
||||
extraSize := 0
|
||||
if len(optExtraSize) == 1 {
|
||||
extraSize = optExtraSize[0]
|
||||
}
|
||||
|
||||
var junk []byte
|
||||
protocol.ASecMux.RLock()
|
||||
if junkSize != 0 {
|
||||
buf := make([]byte, 0, junkSize+extraSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
|
||||
if err != nil {
|
||||
protocol.ASecMux.RUnlock()
|
||||
return nil, err
|
||||
}
|
||||
junk = writer.Bytes()
|
||||
}
|
||||
protocol.ASecMux.RUnlock()
|
||||
|
||||
return junk, nil
|
||||
}
|
37
device/awg/internal/mock.go
Normal file
37
device/awg/internal/mock.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package internal
|
||||
|
||||
type mockGenerator struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func NewMockGenerator(size int) mockGenerator {
|
||||
return mockGenerator{size: size}
|
||||
}
|
||||
|
||||
func (m mockGenerator) Generate() []byte {
|
||||
return make([]byte, m.size)
|
||||
}
|
||||
|
||||
func (m mockGenerator) Size() int {
|
||||
return m.size
|
||||
}
|
||||
|
||||
func (m mockGenerator) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
type mockByteGenerator struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewMockByteGenerator(data []byte) mockByteGenerator {
|
||||
return mockByteGenerator{data: data}
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Generate() []byte {
|
||||
return bg.data
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Size() int {
|
||||
return len(bg.data)
|
||||
}
|
70
device/awg/junk_creator.go
Normal file
70
device/awg/junk_creator.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
v2 "math/rand/v2"
|
||||
)
|
||||
|
||||
type junkCreator struct {
|
||||
aSecCfg aSecCfgType
|
||||
cha8Rand *v2.ChaCha8
|
||||
}
|
||||
|
||||
// TODO: refactor param to only pass the junk related params
|
||||
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
|
||||
buf := make([]byte, 32)
|
||||
_, err := crand.Read(buf)
|
||||
if err != nil {
|
||||
return junkCreator{}, err
|
||||
}
|
||||
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
|
||||
if jc.aSecCfg.JunkPacketCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for range jc.aSecCfg.JunkPacketCount {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk, err := jc.randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create junk packet: %v", err)
|
||||
}
|
||||
*junks = append(*junks, junk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomPacketSize() int {
|
||||
return int(
|
||||
jc.cha8Rand.Uint64()%uint64(
|
||||
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
|
||||
),
|
||||
) + jc.aSecCfg.JunkPacketMinSize
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk, err := jc.randomJunkWithSize(size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create header junk: %v", err)
|
||||
}
|
||||
_, err = writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
||||
// TODO: use a memory pool to allocate
|
||||
junk := make([]byte, size)
|
||||
_, err := jc.cha8Rand.Read(junk)
|
||||
return junk, err
|
||||
}
|
115
device/awg/junk_creator_test.go
Normal file
115
device/awg/junk_creator_test.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
||||
jc, err := NewJunkCreator(aSecCfgType{
|
||||
IsSet: true,
|
||||
JunkPacketCount: 5,
|
||||
JunkPacketMinSize: 500,
|
||||
JunkPacketMaxSize: 1000,
|
||||
InitHeaderJunkSize: 30,
|
||||
ResponseHeaderJunkSize: 40,
|
||||
InitPacketMagicHeader: 123456,
|
||||
ResponsePacketMagicHeader: 67543,
|
||||
UnderloadPacketMagicHeader: 32345,
|
||||
TransportPacketMagicHeader: 123123,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to create junk creator %v", err)
|
||||
return junkCreator{}, err
|
||||
}
|
||||
|
||||
return jc, nil
|
||||
}
|
||||
|
||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
|
||||
err := jc.CreateJunkPackets(&got)
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v; failed",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, junk := range got {
|
||||
key := string(junk)
|
||||
if seen[key] {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||
got,
|
||||
junk,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r1, _ := jc.randomJunkWithSize(10)
|
||||
r2, _ := jc.randomJunkWithSize(10)
|
||||
fmt.Printf("%v\n%v\n", r1, r2)
|
||||
if bytes.Equal(r1, r2) {
|
||||
t.Errorf("same junks %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for range [30]struct{}{} {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
|
||||
got > jc.aSecCfg.JunkPacketMaxSize {
|
||||
t.Errorf(
|
||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||
got,
|
||||
jc.aSecCfg.JunkPacketMinSize,
|
||||
jc.aSecCfg.JunkPacketMaxSize,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := jc.AppendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Error("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
73
device/awg/special_handshake_handler.go
Normal file
73
device/awg/special_handshake_handler.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// TODO: atomic?/ and better way to use this
|
||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||
|
||||
// TODO
|
||||
var WaitResponse = struct {
|
||||
Channel chan struct{}
|
||||
ShouldWait *abool.AtomicBool
|
||||
}{
|
||||
make(chan struct{}, 1),
|
||||
abool.New(),
|
||||
}
|
||||
|
||||
type SpecialHandshakeHandler struct {
|
||||
isFirstDone bool
|
||||
SpecialJunk TagJunkPacketGenerators
|
||||
ControlledJunk TagJunkPacketGenerators
|
||||
|
||||
nextItime time.Time
|
||||
ITimeout time.Duration // seconds
|
||||
|
||||
IsSet bool
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
||||
var errs []error
|
||||
if err := handler.SpecialJunk.Validate(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if err := handler.ControlledJunk.Validate(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||
if !handler.SpecialJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: create tests
|
||||
if !handler.isFirstDone {
|
||||
handler.isFirstDone = true
|
||||
} else if !handler.isTimeToSendSpecial() {
|
||||
return nil
|
||||
}
|
||||
|
||||
rv := handler.SpecialJunk.GeneratePackets()
|
||||
handler.nextItime = time.Now().Add(handler.ITimeout)
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool {
|
||||
return time.Now().After(handler.nextItime)
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte {
|
||||
if !handler.ControlledJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return handler.ControlledJunk.GeneratePackets()
|
||||
}
|
190
device/awg/tag_generator.go
Normal file
190
device/awg/tag_generator.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v2 "math/rand/v2"
|
||||
// "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Generator interface {
|
||||
Generate() []byte
|
||||
Size() int
|
||||
}
|
||||
|
||||
type newGenerator func(string) (Generator, error)
|
||||
|
||||
type BytesGenerator struct {
|
||||
value []byte
|
||||
size int
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Generate() []byte {
|
||||
return bg.value
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Size() int {
|
||||
return bg.size
|
||||
}
|
||||
|
||||
func newBytesGenerator(param string) (Generator, error) {
|
||||
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
|
||||
if !hasPrefix {
|
||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||
}
|
||||
|
||||
hex, err := hexToBytes(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hexToBytes: %w", err)
|
||||
}
|
||||
|
||||
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
||||
}
|
||||
|
||||
func hexToBytes(hexStr string) ([]byte, error) {
|
||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||
|
||||
// Ensure even length (pad with leading zero if needed)
|
||||
if len(hexStr)%2 != 0 {
|
||||
hexStr = "0" + hexStr
|
||||
}
|
||||
|
||||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type RandomPacketGenerator struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
func newRandomPacketGenerator(param string) (Generator, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet parse int: %w", err)
|
||||
}
|
||||
|
||||
if size > 1000 {
|
||||
return nil, fmt.Errorf("random packet size must be less than 1000")
|
||||
}
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet crand read: %w", err)
|
||||
}
|
||||
|
||||
return &RandomPacketGenerator{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newTimestampGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &TimestampGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitTimeoutGenerator struct {
|
||||
waitTimeout time.Duration
|
||||
}
|
||||
|
||||
func (wtg *WaitTimeoutGenerator) Generate() []byte {
|
||||
time.Sleep(wtg.waitTimeout)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (wtg *WaitTimeoutGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitTimeoutGenerator(param string) (Generator, error) {
|
||||
timeout, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("timeout parse int: %w", err)
|
||||
}
|
||||
|
||||
if timeout > 5000 {
|
||||
return nil, fmt.Errorf("timeout must be less than 5000ms")
|
||||
}
|
||||
|
||||
return &WaitTimeoutGenerator{
|
||||
waitTimeout: time.Duration(timeout) * time.Millisecond,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
// TODO: better way to handle counter tag
|
||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newPacketCounterGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &PacketCounterGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitResponseGenerator struct {
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Generate() []byte {
|
||||
WaitResponse.ShouldWait.Set()
|
||||
<-WaitResponse.Channel
|
||||
WaitResponse.ShouldWait.UnSet()
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &WaitResponseGenerator{}, nil
|
||||
}
|
189
device/awg/tag_generator_test.go
Normal file
189
device/awg/tag_generator_test.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_newBytesGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "wrong start",
|
||||
args: args{
|
||||
param: "123456",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with X",
|
||||
args: args{
|
||||
param: "0X12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with x",
|
||||
args: args{
|
||||
param: "0x12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "valid hex",
|
||||
args: args{
|
||||
param: "0xf6ab3267fa",
|
||||
},
|
||||
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
{
|
||||
name: "valid hex with odd length",
|
||||
args: args{
|
||||
param: "0xfab3267fa",
|
||||
},
|
||||
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
gotValues := got.Generate()
|
||||
require.Equal(t, tt.want, gotValues)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newRandomPacketGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newRandomPacketGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketCounterGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
param string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid empty param",
|
||||
param: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid non-empty param",
|
||||
param: "anything",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen, err := newPacketCounterGenerator(tc.param)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8, gen.Size())
|
||||
|
||||
// Reset counter to known value for test
|
||||
initialCount := uint64(42)
|
||||
PacketCounter.Store(initialCount)
|
||||
|
||||
output := gen.Generate()
|
||||
require.Equal(t, 8, len(output))
|
||||
|
||||
// Verify counter value in output
|
||||
counterValue := binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount, counterValue)
|
||||
|
||||
// Increment counter and verify change
|
||||
PacketCounter.Add(1)
|
||||
output = gen.Generate()
|
||||
counterValue = binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount+1, counterValue)
|
||||
})
|
||||
}
|
||||
}
|
59
device/awg/tag_junk_packet_generator.go
Normal file
59
device/awg/tag_junk_packet_generator.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TagJunkPacketGenerator struct {
|
||||
name string
|
||||
tagValue string
|
||||
|
||||
packetSize int
|
||||
generators []Generator
|
||||
}
|
||||
|
||||
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{
|
||||
name: name,
|
||||
tagValue: tagValue,
|
||||
generators: make([]Generator, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||
tg.generators = append(tg.generators, generator)
|
||||
tg.packetSize += generator.Size()
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
|
||||
packet := make([]byte, 0, tg.packetSize)
|
||||
for _, generator := range tg.generators {
|
||||
packet = append(packet, generator.Generate()...)
|
||||
}
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) Name() string {
|
||||
return tg.name
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
||||
if len(tg.name) != 2 {
|
||||
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(tg.name[1:2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("name 2 char should be an int %w", err)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||
return IpcFields{
|
||||
Key: tg.name,
|
||||
Value: tg.tagValue,
|
||||
}
|
||||
}
|
210
device/awg/tag_junk_packet_generator_test.go
Normal file
210
device/awg/tag_junk_packet_generator_test.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTagJunkGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
genName string
|
||||
size int
|
||||
expected TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "Create new generator with empty name",
|
||||
genName: "",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with valid name",
|
||||
genName: "T1",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T1",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with non-zero size",
|
||||
genName: "T2",
|
||||
size: 5,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 5),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
|
||||
require.Equal(t, tc.expected.name, result.name)
|
||||
require.Equal(t, tc.expected.packetSize, result.packetSize)
|
||||
require.Equal(t, cap(result.generators), len(tc.expected.generators))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState TagJunkPacketGenerator
|
||||
mockSize int
|
||||
expectedLength int
|
||||
expectedSize int
|
||||
}{
|
||||
{
|
||||
name: "Append to empty generator",
|
||||
initialState: newTagJunkPacketGenerator("T1", "", 0),
|
||||
mockSize: 5,
|
||||
expectedLength: 1,
|
||||
expectedSize: 5,
|
||||
},
|
||||
{
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 10,
|
||||
generators: make([]Generator, 2),
|
||||
},
|
||||
mockSize: 7,
|
||||
expectedLength: 3, // 2 existing + 1 new
|
||||
expectedSize: 17, // 10 + 7
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.initialState
|
||||
mockGen := internal.NewMockGenerator(tc.mockSize)
|
||||
|
||||
tg.append(mockGen)
|
||||
|
||||
require.Equal(t, tc.expectedLength, len(tg.generators))
|
||||
require.Equal(t, tc.expectedSize, tg.packetSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorGenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock generators for testing
|
||||
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
|
||||
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupGenerator func() TagJunkPacketGenerator
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "Generate with empty generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
return newTagJunkPacketGenerator("T1", "", 0)
|
||||
},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Generate with single generator",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T2", "", 0)
|
||||
tg.append(mockGen1)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02},
|
||||
},
|
||||
{
|
||||
name: "Generate with multiple generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T3", "", 0)
|
||||
tg.append(mockGen1)
|
||||
tg.append(mockGen2)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.setupGenerator()
|
||||
result := tg.generatePacket()
|
||||
|
||||
require.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorNameIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
generatorName string
|
||||
expectedIndex int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid name with digit",
|
||||
generatorName: "T5",
|
||||
expectedIndex: 5,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too short",
|
||||
generatorName: "T",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too long",
|
||||
generatorName: "T55",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - non-digit second character",
|
||||
generatorName: "TX",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := TagJunkPacketGenerator{name: tc.generatorName}
|
||||
index, err := tg.nameIndex()
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectedIndex, index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
66
device/awg/tag_junk_packet_generators.go
Normal file
66
device/awg/tag_junk_packet_generators.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package awg
|
||||
|
||||
import "fmt"
|
||||
|
||||
type TagJunkPacketGenerators struct {
|
||||
tagGenerators []TagJunkPacketGenerator
|
||||
length int
|
||||
DefaultJunkCount int // Jc
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||
generator TagJunkPacketGenerator,
|
||||
) {
|
||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||
generators.length++
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) IsDefined() bool {
|
||||
return len(generators.tagGenerators) > 0
|
||||
}
|
||||
|
||||
// validate that packets were defined consecutively
|
||||
func (generators *TagJunkPacketGenerators) Validate() error {
|
||||
seen := make([]bool, len(generators.tagGenerators))
|
||||
for _, generator := range generators.tagGenerators {
|
||||
index, err := generator.nameIndex()
|
||||
if index > len(generators.tagGenerators) {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("name index: %w", err)
|
||||
} else {
|
||||
seen[index-1] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, found := range seen {
|
||||
if !found {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
||||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range generators.tagGenerators {
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
PacketCounter.Inc()
|
||||
}
|
||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||
for _, generator := range tg.tagGenerators {
|
||||
rv = append(rv, generator.IpcGetFields())
|
||||
}
|
||||
|
||||
return rv
|
||||
}
|
149
device/awg/tag_junk_packet_generators_test.go
Normal file
149
device/awg/tag_junk_packet_generators_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generator TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "append single generator",
|
||||
generator: newTagJunkPacketGenerator("t1", "", 10),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
|
||||
// Initial length should be 0
|
||||
require.Equal(t, 0, generators.length)
|
||||
require.Empty(t, generators.tagGenerators)
|
||||
|
||||
// After append, length should be 1 and generator should be added
|
||||
generators.AppendGenerator(tt.generator)
|
||||
require.Equal(t, 1, generators.length)
|
||||
require.Len(t, generators.tagGenerators, 1)
|
||||
require.Equal(t, tt.generator, generators.tagGenerators[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generators []TagJunkPacketGenerator
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "bad start",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "non-consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t2", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
newTagJunkPacketGenerator("t5", "", 10),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nameIndex error",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("error", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "name must be 2 character long",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
for _, gen := range tt.generators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
err := generators.Validate()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errMsg)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
||||
mockByte1 := []byte{0x01, 0x02}
|
||||
mockByte2 := []byte{0x03, 0x04, 0x05}
|
||||
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
||||
mockGen2 := internal.NewMockByteGenerator(mockByte2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupGenerator func() []TagJunkPacketGenerator
|
||||
expected [][]byte
|
||||
}{
|
||||
{
|
||||
name: "generate with no default junk",
|
||||
setupGenerator: func() []TagJunkPacketGenerator {
|
||||
tg1 := newTagJunkPacketGenerator("t1", "", 0)
|
||||
tg1.append(mockGen1)
|
||||
tg1.append(mockGen2)
|
||||
tg2 := newTagJunkPacketGenerator("t2", "", 0)
|
||||
tg2.append(mockGen2)
|
||||
tg2.append(mockGen1)
|
||||
|
||||
return []TagJunkPacketGenerator{tg1, tg2}
|
||||
},
|
||||
expected: [][]byte{
|
||||
append(mockByte1, mockByte2...),
|
||||
append(mockByte2, mockByte1...),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
tagGenerators := tt.setupGenerator()
|
||||
for _, gen := range tagGenerators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
result := generators.GeneratePackets()
|
||||
require.Equal(t, result, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
112
device/awg/tag_parser.go
Normal file
112
device/awg/tag_parser.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IpcFields struct{ Key, Value string }
|
||||
|
||||
type EnumTag string
|
||||
|
||||
const (
|
||||
BytesEnumTag EnumTag = "b"
|
||||
CounterEnumTag EnumTag = "c"
|
||||
TimestampEnumTag EnumTag = "t"
|
||||
RandomBytesEnumTag EnumTag = "r"
|
||||
WaitTimeoutEnumTag EnumTag = "wt"
|
||||
WaitResponseEnumTag EnumTag = "wr"
|
||||
)
|
||||
|
||||
var generatorCreator = map[EnumTag]newGenerator{
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||
// WaitResponseEnumTag: newWaitResponseGenerator,
|
||||
}
|
||||
|
||||
// helper map to determine enumTags are unique
|
||||
var uniqueTags = map[EnumTag]bool{
|
||||
CounterEnumTag: false,
|
||||
TimestampEnumTag: false,
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name EnumTag
|
||||
Param string
|
||||
}
|
||||
|
||||
func parseTag(input string) (Tag, error) {
|
||||
// Regular expression to match <tagname optional_param>
|
||||
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
|
||||
|
||||
match := re.FindStringSubmatch(input)
|
||||
tag := Tag{
|
||||
Name: EnumTag(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
func Parse(name, input string) (TagJunkPacketGenerator, error) {
|
||||
inputSlice := strings.Split(input, "<")
|
||||
if len(inputSlice) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
|
||||
}
|
||||
|
||||
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
|
||||
maps.Copy(uniqueTagCheck, uniqueTags)
|
||||
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"empty tag in input: %s",
|
||||
inputSlice,
|
||||
)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
||||
tag, _ := parseTag(inputParam)
|
||||
creator, ok := generatorCreator[tag.Name]
|
||||
if !ok {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
|
||||
}
|
||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||
if present {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"tag %s needs to be unique",
|
||||
tag.Name,
|
||||
)
|
||||
}
|
||||
uniqueTagCheck[tag.Name] = true
|
||||
}
|
||||
generator, err := creator(tag.Param)
|
||||
if err != nil {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
|
||||
}
|
||||
|
||||
// TODO: handle counter tag
|
||||
// if tag.Name == CounterEnumTag {
|
||||
// packetCounter, ok := generator.(*PacketCounterGenerator)
|
||||
// if !ok {
|
||||
// log.Fatalf("packet counter generator expected, got %T", generator)
|
||||
// }
|
||||
// PacketCounter = packetCounter.counter
|
||||
// }
|
||||
|
||||
rv.append(generator)
|
||||
}
|
||||
|
||||
return rv, nil
|
||||
}
|
77
device/awg/tag_parser_test.go
Normal file
77
device/awg/tag_parser_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
input string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid name",
|
||||
args: args{name: "apple", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{name: "i1", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra >",
|
||||
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra <",
|
||||
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "empty <>",
|
||||
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "invalid tag",
|
||||
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
|
||||
wantErr: fmt.Errorf("invalid tag"),
|
||||
},
|
||||
{
|
||||
name: "counter uniqueness violation",
|
||||
args: args{name: "i1", input: "<c><c>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "timestamp uniqueness violation",
|
||||
args: args{name: "i1", input: "<t><t>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Parse(tt.args.name, tt.args.input)
|
||||
|
||||
// TODO: ErrorAs doesn't work as you think
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -8,7 +8,7 @@ package device
|
|||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
)
|
||||
|
||||
type DummyDatagram struct {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
393
device/device.go
393
device/device.go
|
@ -1,24 +1,60 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
||||
"github.com/amnezia-vpn/amnezia-wg/ratelimiter"
|
||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/tevino/abool/v2"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
)
|
||||
|
||||
type Version uint8
|
||||
|
||||
const (
|
||||
VersionDefault Version = iota
|
||||
VersionAwg
|
||||
VersionAwgSpecialHandshake
|
||||
)
|
||||
|
||||
// TODO:
|
||||
type AtomicVersion struct {
|
||||
value atomic.Uint32
|
||||
}
|
||||
|
||||
func NewAtomicVersion(v Version) *AtomicVersion {
|
||||
av := &AtomicVersion{}
|
||||
av.Store(v)
|
||||
return av
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Load() Version {
|
||||
return Version(av.value.Load())
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Store(v Version) {
|
||||
av.value.Store(uint32(v))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
|
||||
return av.value.CompareAndSwap(uint32(old), uint32(new))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Swap(new Version) Version {
|
||||
return Version(av.value.Swap(uint32(new)))
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
state struct {
|
||||
// state holds the device's state. It is accessed atomically.
|
||||
|
@ -92,22 +128,8 @@ type Device struct {
|
|||
closed chan struct{}
|
||||
log *Logger
|
||||
|
||||
isASecOn abool.AtomicBool
|
||||
aSecMux sync.RWMutex
|
||||
aSecCfg aSecCfgType
|
||||
}
|
||||
|
||||
type aSecCfgType struct {
|
||||
isSet bool
|
||||
junkPacketCount int
|
||||
junkPacketMinSize int
|
||||
junkPacketMaxSize int
|
||||
initPacketJunkSize int
|
||||
responsePacketJunkSize int
|
||||
initPacketMagicHeader uint32
|
||||
responsePacketMagicHeader uint32
|
||||
underloadPacketMagicHeader uint32
|
||||
transportPacketMagicHeader uint32
|
||||
version Version
|
||||
awg awg.Protocol
|
||||
}
|
||||
|
||||
// deviceState represents the state of a Device.
|
||||
|
@ -388,10 +410,10 @@ func (device *Device) RemoveAllPeers() {
|
|||
}
|
||||
|
||||
func (device *Device) Close() {
|
||||
device.ipcMutex.Lock()
|
||||
defer device.ipcMutex.Unlock()
|
||||
device.state.Lock()
|
||||
defer device.state.Unlock()
|
||||
device.ipcMutex.Lock()
|
||||
defer device.ipcMutex.Unlock()
|
||||
if device.isClosed() {
|
||||
return
|
||||
}
|
||||
|
@ -415,6 +437,8 @@ func (device *Device) Close() {
|
|||
|
||||
device.rate.limiter.Close()
|
||||
|
||||
device.resetProtocol()
|
||||
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
}
|
||||
|
@ -481,11 +505,7 @@ func (device *Device) BindSetMark(mark uint32) error {
|
|||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.markEndpointSrcForClearing()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
|
@ -535,11 +555,7 @@ func (device *Device) BindUpdate() error {
|
|||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.markEndpointSrcForClearing()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
|
@ -562,242 +578,261 @@ func (device *Device) BindClose() error {
|
|||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
func (device *Device) isAdvancedSecurityOn() bool {
|
||||
return device.isASecOn.IsSet()
|
||||
func (device *Device) isAWG() bool {
|
||||
return device.version >= VersionAwg
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
if !tempASecCfg.isSet {
|
||||
return err
|
||||
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
isASecOn := false
|
||||
device.aSecMux.Lock()
|
||||
if tempASecCfg.junkPacketCount < 0 {
|
||||
err = ipcErrorf(
|
||||
device.awg.ASecMux.Lock()
|
||||
if tempAwg.ASecCfg.JunkPacketCount < 0 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketCount should be non negative",
|
||||
),
|
||||
)
|
||||
}
|
||||
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
||||
if tempASecCfg.junkPacketCount != 0 {
|
||||
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
if tempAwg.ASecCfg.JunkPacketCount != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
||||
if tempASecCfg.junkPacketMinSize != 0 {
|
||||
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
|
||||
if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if device.aSecCfg.junkPacketCount > 0 &&
|
||||
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
||||
if device.awg.ASecCfg.JunkPacketCount > 0 &&
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
|
||||
|
||||
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
|
||||
}
|
||||
|
||||
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||
device.aSecCfg.junkPacketMinSize = 0
|
||||
device.aSecCfg.junkPacketMaxSize = 1
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
|
||||
device.awg.ASecCfg.JunkPacketMinSize = 0
|
||||
device.awg.ASecCfg.JunkPacketMaxSize = 1
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
)
|
||||
}
|
||||
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d; %w",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
))
|
||||
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d",
|
||||
tempASecCfg.junkPacketMaxSize,
|
||||
tempASecCfg.junkPacketMinSize,
|
||||
)
|
||||
}
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize,
|
||||
tempAwg.ASecCfg.JunkPacketMinSize,
|
||||
))
|
||||
} else {
|
||||
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
||||
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
|
||||
}
|
||||
|
||||
if tempASecCfg.junkPacketMaxSize != 0 {
|
||||
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
|
||||
|
||||
if newInitSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempASecCfg.initPacketJunkSize,
|
||||
tempAwg.ASecCfg.InitHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
||||
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempASecCfg.initPacketJunkSize != 0 {
|
||||
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
MaxSegmentSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
|
||||
|
||||
if newResponseSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempASecCfg.responsePacketJunkSize,
|
||||
tempAwg.ASecCfg.ResponseHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
||||
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempASecCfg.responsePacketJunkSize != 0 {
|
||||
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if tempASecCfg.initPacketMagicHeader > 4 {
|
||||
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
|
||||
|
||||
if newCookieSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
|
||||
|
||||
if newTransportSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.ASecCfg.TransportHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
|
||||
isASecOn = true
|
||||
}
|
||||
|
||||
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
|
||||
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
|
||||
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = 1
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
}
|
||||
|
||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||
if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
|
||||
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
|
||||
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = 2
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
}
|
||||
|
||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||
if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
|
||||
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = 3
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
}
|
||||
|
||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||
if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
|
||||
isASecOn = true
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
|
||||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
|
||||
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = 4
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
isSameMap := map[uint32]bool{}
|
||||
isSameMap[MessageInitiationType] = true
|
||||
isSameMap[MessageResponseType] = true
|
||||
isSameMap[MessageCookieReplyType] = true
|
||||
isSameMap[MessageTransportType] = true
|
||||
isSameHeaderMap := map[uint32]struct{}{
|
||||
MessageInitiationType: {},
|
||||
MessageResponseType: {},
|
||||
MessageCookieReplyType: {},
|
||||
MessageTransportType: {},
|
||||
}
|
||||
|
||||
// size will be different if same values
|
||||
if len(isSameMap) != 4 {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
if len(isSameHeaderMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
||||
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
||||
|
||||
if newInitSize == newResponseSize {
|
||||
if err != nil {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new init size:%d; and new response size:%d; should differ; %w`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new init size:%d; and new response size:%d; should differ`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
)
|
||||
isSameSizeMap := map[int]struct{}{
|
||||
newInitSize: {},
|
||||
newResponseSize: {},
|
||||
newCookieSize: {},
|
||||
newTransportSize: {},
|
||||
}
|
||||
|
||||
if len(isSameSizeMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
newCookieSize,
|
||||
newTransportSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
|
||||
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
|
||||
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
|
||||
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
|
||||
}
|
||||
|
||||
packetSizeToMsgType = map[int]uint32{
|
||||
newInitSize: MessageInitiationType,
|
||||
newResponseSize: MessageResponseType,
|
||||
MessageCookieReplySize: MessageCookieReplyType,
|
||||
MessageTransportSize: MessageTransportType,
|
||||
}
|
||||
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
||||
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
||||
MessageCookieReplyType: 0,
|
||||
MessageTransportType: 0,
|
||||
newCookieSize: MessageCookieReplyType,
|
||||
newTransportSize: MessageTransportType,
|
||||
}
|
||||
}
|
||||
|
||||
device.isASecOn.SetTo(isASecOn)
|
||||
device.aSecMux.Unlock()
|
||||
device.awg.IsASecOn.SetTo(isASecOn)
|
||||
var err error
|
||||
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return err
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||
} else {
|
||||
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
|
||||
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
|
||||
device.version = VersionAwgSpecialHandshake
|
||||
}
|
||||
} else {
|
||||
device.version = VersionAwg
|
||||
}
|
||||
|
||||
device.awg.ASecMux.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
|
|
@ -1,29 +1,32 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn/bindtest"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||
)
|
||||
|
||||
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
||||
|
@ -50,7 +53,7 @@ func uapiCfg(cfg ...string) string {
|
|||
|
||||
// genConfigs generates a pair of configs that connect to each other.
|
||||
// The configs use distinct, probably-usable ports.
|
||||
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||
func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) {
|
||||
var key1, key2 NoisePrivateKey
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
|
@ -62,7 +65,8 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
cfgs[0] = uapiCfg(
|
||||
args0 := append([]string(nil), cfg...)
|
||||
args0 = append(args0, []string{
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -70,12 +74,16 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.2/32",
|
||||
)
|
||||
}...)
|
||||
cfgs[0] = uapiCfg(args0...)
|
||||
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
cfgs[1] = uapiCfg(
|
||||
|
||||
args1 := append([]string(nil), cfg...)
|
||||
args1 = append(args1, []string{
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
|
@ -83,66 +91,9 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.1/32",
|
||||
)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
return
|
||||
}
|
||||
}...)
|
||||
|
||||
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||
var key1, key2 NoisePrivateKey
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
_, err = rand.Read(key2[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "501",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.2/32",
|
||||
)
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
)
|
||||
cfgs[1] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", "0",
|
||||
"replace_peers", "true",
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "501",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h4", "32345",
|
||||
"h3", "123123",
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "1.0.0.1/32",
|
||||
)
|
||||
cfgs[1] = uapiCfg(args1...)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", "127.0.0.1:%d",
|
||||
|
@ -185,9 +136,10 @@ func (pair *testPair) Send(
|
|||
// pong is the new ping
|
||||
p0, p1 = p1, p0
|
||||
}
|
||||
|
||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||
p1.tun.Outbound <- msg
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
timer := time.NewTimer(6 * time.Second)
|
||||
defer timer.Stop()
|
||||
var err error
|
||||
select {
|
||||
|
@ -214,14 +166,12 @@ func (pair *testPair) Send(
|
|||
// genTestPair creates a testPair.
|
||||
func genTestPair(
|
||||
tb testing.TB,
|
||||
realSocket, withASecurity bool,
|
||||
realSocket bool,
|
||||
extraCfg ...string,
|
||||
) (pair testPair) {
|
||||
var cfg, endpointCfg [2]string
|
||||
if withASecurity {
|
||||
cfg, endpointCfg = genASecurityConfigs(tb)
|
||||
} else {
|
||||
cfg, endpointCfg = genConfigs(tb)
|
||||
}
|
||||
cfg, endpointCfg = genConfigs(tb, extraCfg...)
|
||||
|
||||
var binds [2]conn.Bind
|
||||
if realSocket {
|
||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||
|
@ -237,7 +187,7 @@ func genTestPair(
|
|||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||
level = LogLevelError
|
||||
}
|
||||
p.dev = NewDevice(p.tun.TUN(),binds[i],NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||
p.dev.Close()
|
||||
|
@ -265,7 +215,7 @@ func genTestPair(
|
|||
|
||||
func TestTwoDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true, false)
|
||||
pair := genTestPair(t, true)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
})
|
||||
|
@ -274,9 +224,23 @@ func TestTwoDevicePing(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestTwoDevicePingASecurity(t *testing.T) {
|
||||
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
|
||||
func TestAWGDevicePing(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true, true)
|
||||
|
||||
pair := genTestPair(t, true,
|
||||
"jc", "5",
|
||||
"jmin", "500",
|
||||
"jmax", "1000",
|
||||
"s1", "30",
|
||||
"s2", "40",
|
||||
"s3", "50",
|
||||
"s4", "5",
|
||||
"h1", "123456",
|
||||
"h2", "67543",
|
||||
"h3", "123123",
|
||||
"h4", "32345",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
pair.Send(t, Ping, nil)
|
||||
})
|
||||
|
@ -285,16 +249,61 @@ func TestTwoDevicePingASecurity(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Needs to be stopped with Ctrl-C
|
||||
func TestAWGHandshakeDevicePing(t *testing.T) {
|
||||
t.Skip("This test is intended to be run manually, not as part of the test suite.")
|
||||
|
||||
signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
isRunning := atomic.NewBool(true)
|
||||
go func() {
|
||||
<-signalContext.Done()
|
||||
fmt.Println("Waiting to finish")
|
||||
isRunning.Store(false)
|
||||
}()
|
||||
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
"i2", "<b 0xf6ab3267fa><r 100>",
|
||||
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||
"itime", "60",
|
||||
// "jc", "1",
|
||||
// "jmin", "500",
|
||||
// "jmax", "1000",
|
||||
// "s1", "30",
|
||||
// "s2", "40",
|
||||
// "h1", "123456",
|
||||
// "h2", "67543",
|
||||
// "h4", "32345",
|
||||
// "h3", "123123",
|
||||
)
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Ping, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
for isRunning.Load() {
|
||||
pair.Send(t, Pong, nil)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpDown(t *testing.T) {
|
||||
goroutineLeakCheck(t)
|
||||
const itrials = 50
|
||||
const otrials = 10
|
||||
|
||||
for n := 0; n < otrials; n++ {
|
||||
pair := genTestPair(t, false, false)
|
||||
pair := genTestPair(t, false)
|
||||
for i := range pair {
|
||||
for k := range pair[i].dev.peers.keyMap {
|
||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n",hex.EncodeToString(k[:])))
|
||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||
}
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
|
@ -325,7 +334,7 @@ func TestUpDown(t *testing.T) {
|
|||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||
// It is intended to be used with the race detector to catch data races.
|
||||
func TestConcurrencySafety(t *testing.T) {
|
||||
pair := genTestPair(t, true, false)
|
||||
pair := genTestPair(t, true)
|
||||
done := make(chan struct{})
|
||||
|
||||
const warmupIters = 10
|
||||
|
@ -406,7 +415,7 @@ func TestConcurrencySafety(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkLatency(b *testing.B) {
|
||||
pair := genTestPair(b, true, false)
|
||||
pair := genTestPair(b, true)
|
||||
|
||||
// Establish a connection.
|
||||
pair.Send(b, Ping, nil)
|
||||
|
@ -420,7 +429,7 @@ func BenchmarkLatency(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkThroughput(b *testing.B) {
|
||||
pair := genTestPair(b, true, false)
|
||||
pair := genTestPair(b, true)
|
||||
|
||||
// Establish a connection.
|
||||
pair.Send(b, Ping, nil)
|
||||
|
@ -464,7 +473,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkUAPIGet(b *testing.B) {
|
||||
pair := genTestPair(b, true, false)
|
||||
pair := genTestPair(b, true)
|
||||
pair.Send(b, Ping, nil)
|
||||
pair.Send(b, Pong, nil)
|
||||
b.ReportAllocs()
|
||||
|
@ -513,7 +522,7 @@ func (b *fakeBindSized) Open(
|
|||
|
||||
func (b *fakeBindSized) Close() error { return nil }
|
||||
|
||||
func (b *fakeBindSized) SetMark(mark uint32) error {return nil }
|
||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||
|
||||
|
@ -527,7 +536,9 @@ type fakeTUNDeviceSized struct {
|
|||
|
||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||
|
||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil }
|
||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -11,7 +11,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/replay"
|
||||
"github.com/amnezia-vpn/amneziawg-go/replay"
|
||||
)
|
||||
|
||||
/* Due to limitations in Go and /x/crypto there is currently
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
|||
device.net.brokenRoaming = true
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
peer.disableRoaming = peer.endpoint != nil
|
||||
peer.Unlock()
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||
peer.endpoint.Unlock()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -15,7 +15,7 @@ import (
|
|||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/tai64n"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tai64n"
|
||||
)
|
||||
|
||||
type handshakeState int
|
||||
|
@ -52,11 +52,18 @@ const (
|
|||
WGLabelCookie = "cookie--"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = 1
|
||||
MessageResponseType uint32 = 2
|
||||
MessageCookieReplyType uint32 = 3
|
||||
MessageTransportType uint32 = 4
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -75,9 +82,10 @@ const (
|
|||
MessageTransportOffsetContent = 16
|
||||
)
|
||||
|
||||
var packetSizeToMsgType map[int]uint32
|
||||
|
||||
var msgTypeToJunkSize map[uint32]int
|
||||
var (
|
||||
packetSizeToMsgType map[int]uint32
|
||||
msgTypeToJunkSize map[uint32]int
|
||||
)
|
||||
|
||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||
* by marshalling the messages in little-endian byteorder
|
||||
|
@ -197,12 +205,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
msg := MessageInitiation{
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
@ -256,12 +264,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
|
@ -376,9 +384,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
var msg MessageResponse
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
msg.Type = MessageResponseType
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
|
@ -428,12 +436,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
if msg.Type != MessageResponseType {
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -10,8 +10,8 @@ import (
|
|||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||
)
|
||||
|
||||
func TestCurveWrappers(t *testing.T) {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -12,22 +12,26 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
isRunning atomic.Bool
|
||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
||||
keypairs Keypairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
endpoint conn.Endpoint
|
||||
stopping sync.WaitGroup // routines pending stop
|
||||
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||
rxBytes atomic.Uint64 // bytes received from peer
|
||||
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||
|
||||
endpoint struct {
|
||||
sync.Mutex
|
||||
val conn.Endpoint
|
||||
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||
disableRoaming bool
|
||||
}
|
||||
|
||||
timers struct {
|
||||
retransmitHandshake *Timer
|
||||
|
@ -74,8 +78,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
|
||||
// create peer
|
||||
peer := new(Peer)
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
|
||||
peer.cookieGenerator.Init(pk)
|
||||
peer.device = device
|
||||
|
@ -97,7 +99,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
handshake.mutex.Unlock()
|
||||
|
||||
// reset endpoint
|
||||
peer.endpoint = nil
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.val = nil
|
||||
peer.endpoint.disableRoaming = false
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
// init timers
|
||||
peer.timersInit()
|
||||
|
@ -108,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
|
||||
err := peer.SendBuffers(buffers)
|
||||
if err == nil {
|
||||
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
@ -116,14 +132,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
|
||||
if peer.endpoint == nil {
|
||||
peer.endpoint.Lock()
|
||||
endpoint := peer.endpoint.val
|
||||
if endpoint == nil {
|
||||
peer.endpoint.Unlock()
|
||||
return errors.New("no known endpoint for peer")
|
||||
}
|
||||
if peer.endpoint.clearSrcOnTx {
|
||||
endpoint.ClearSrc()
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
err := peer.device.net.bind.Send(buffers, peer.endpoint)
|
||||
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||
if err == nil {
|
||||
var totalLen uint64
|
||||
for _, b := range buffers {
|
||||
|
@ -267,10 +288,20 @@ func (peer *Peer) Stop() {
|
|||
}
|
||||
|
||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||
if peer.disableRoaming {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.disableRoaming {
|
||||
return
|
||||
}
|
||||
peer.Lock()
|
||||
peer.endpoint = endpoint
|
||||
peer.Unlock()
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.val = endpoint
|
||||
}
|
||||
|
||||
func (peer *Peer) markEndpointSrcForClearing() {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.val == nil {
|
||||
return
|
||||
}
|
||||
peer.endpoint.clearSrcOnTx = true
|
||||
}
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type WaitPool struct {
|
||||
pool sync.Pool
|
||||
cond sync.Cond
|
||||
lock sync.Mutex
|
||||
count atomic.Uint32
|
||||
count uint32 // Get calls not yet Put back
|
||||
max uint32
|
||||
}
|
||||
|
||||
|
@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
|||
func (p *WaitPool) Get() any {
|
||||
if p.max != 0 {
|
||||
p.lock.Lock()
|
||||
for p.count.Load() >= p.max {
|
||||
for p.count >= p.max {
|
||||
p.cond.Wait()
|
||||
}
|
||||
p.count.Add(1)
|
||||
p.count++
|
||||
p.lock.Unlock()
|
||||
}
|
||||
return p.pool.Get()
|
||||
|
@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) {
|
|||
if p.max == 0 {
|
||||
return
|
||||
}
|
||||
p.count.Add(^uint32(0))
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
p.count--
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
|
|||
wg.Add(workers)
|
||||
var max atomic.Uint32
|
||||
updateMax := func() {
|
||||
count := p.count.Load()
|
||||
p.lock.Lock()
|
||||
count := p.count
|
||||
p.lock.Unlock()
|
||||
if count > p.max {
|
||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
|
||||
/* Reduce memory consumption for Android */
|
||||
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
|
||||
const (
|
||||
QueueStagedSize = conn.IdealBatchSize
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -13,7 +13,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
|
@ -137,31 +137,45 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := bufsArrs[i][:size]
|
||||
var msgType uint32
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if device.isAWG() {
|
||||
// TODO:
|
||||
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||
// awg.WaitResponse.Channel <- struct{}{}
|
||||
// }
|
||||
|
||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||
// transport size can align with other header types;
|
||||
// making sure we have the right msgType
|
||||
msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
|
||||
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
|
||||
if msgType == assumedMsgType {
|
||||
packet = packet[junkSize:]
|
||||
} else {
|
||||
device.log.Verbosef("Transport packet lined up with another msg type")
|
||||
device.log.Verbosef("transport packet lined up with another msg type")
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
} else {
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
|
||||
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
|
||||
if msgType != MessageTransportType {
|
||||
device.log.Verbosef("ASec: Received message with unknown type")
|
||||
// probably a junk packet
|
||||
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
|
||||
continue
|
||||
}
|
||||
|
||||
// remove junk from bufsArrs by shifting the packet
|
||||
// this buffer is also used for decryption, so it needs to be corrected
|
||||
copy(bufsArrs[i][:size], packet[transportJunkSize:])
|
||||
size -= transportJunkSize
|
||||
// need to reinitialize packet as well
|
||||
packet = packet[:size]
|
||||
}
|
||||
} else {
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
@ -245,7 +259,7 @@ func (device *Device) RoutineReceiveIncoming(
|
|||
default:
|
||||
}
|
||||
}
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
|
@ -304,7 +318,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.aSecMux.RLock()
|
||||
device.awg.ASecMux.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
|
@ -456,7 +470,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.aSecMux.RUnlock()
|
||||
device.awg.ASecMux.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
@ -476,7 +490,10 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
|||
return
|
||||
}
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
validTailPacket := -1
|
||||
dataPacketReceived := false
|
||||
rxBytesLen := uint64(0)
|
||||
for i, elem := range elemsContainer.elems {
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
continue
|
||||
|
@ -486,21 +503,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
|||
continue
|
||||
}
|
||||
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
validTailPacket = i
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
||||
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
continue
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
dataPacketReceived = true
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
|
@ -549,6 +564,17 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
|||
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
||||
)
|
||||
}
|
||||
|
||||
peer.rxBytes.Add(rxBytesLen)
|
||||
if validTailPacket >= 0 {
|
||||
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
}
|
||||
if dataPacketReceived {
|
||||
peer.timersDataReceived()
|
||||
}
|
||||
if len(bufs) > 0 {
|
||||
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
|
|
112
device/send.go
112
device/send.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -9,14 +9,13 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
@ -125,37 +124,50 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
return err
|
||||
}
|
||||
var sendBuffer [][]byte
|
||||
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
junks, err := peer.createJunkPackets()
|
||||
peer.device.aSecMux.RUnlock()
|
||||
if peer.device.version >= VersionAwg {
|
||||
var junks [][]byte
|
||||
if peer.device.version == VersionAwgSpecialHandshake {
|
||||
peer.device.awg.ASecMux.RLock()
|
||||
// set junks depending on packet type
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||
if junks == nil {
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk()
|
||||
if junks != nil {
|
||||
peer.device.log.Verbosef("%v - Controlled junks sent", peer)
|
||||
}
|
||||
} else {
|
||||
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||
}
|
||||
peer.device.awg.ASecMux.RUnlock()
|
||||
} else {
|
||||
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
|
||||
}
|
||||
peer.device.awg.ASecMux.RLock()
|
||||
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
|
||||
peer.device.awg.ASecMux.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(junks) > 0 {
|
||||
err = peer.SendBuffers(junks)
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
peer.device.aSecMux.RUnlock()
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
|
@ -170,7 +182,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
|
||||
sendBuffer = append(sendBuffer, junkedHeader)
|
||||
|
||||
err = peer.SendBuffers(sendBuffer)
|
||||
err = peer.SendAndCountBuffers(sendBuffer)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
|
@ -191,22 +203,13 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
var junkedHeader []byte
|
||||
if peer.device.isAdvancedSecurityOn() {
|
||||
peer.device.aSecMux.RLock()
|
||||
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||
|
||||
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.aSecMux.RUnlock()
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
junkedHeader = writer.Bytes()
|
||||
}
|
||||
peer.device.aSecMux.RUnlock()
|
||||
}
|
||||
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
|
@ -226,7 +229,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendBuffers([][]byte{junkedHeader})
|
||||
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
|
@ -249,11 +252,19 @@ func (device *Device) SendHandshakeCookie(
|
|||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
|
||||
junkedHeader = append(junkedHeader, writer.Bytes()...)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -466,31 +477,6 @@ top:
|
|||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) createJunkPackets() ([][]byte, error) {
|
||||
if peer.device.aSecCfg.junkPacketCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
|
||||
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
|
||||
packetSize := rand.Intn(
|
||||
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
|
||||
) + peer.device.aSecCfg.junkPacketMinSize
|
||||
|
||||
junk, err := randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf(
|
||||
"%v - Failed to create junk packet: %v",
|
||||
peer,
|
||||
err,
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
junks = append(junks, junk)
|
||||
}
|
||||
return junks, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
|
@ -591,6 +577,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
continue
|
||||
}
|
||||
dataSent := false
|
||||
|
@ -598,6 +585,14 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
for _, elem := range elemsContainer.elems {
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
|
||||
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = append(junkedHeader, elem.packet...)
|
||||
}
|
||||
bufs = append(bufs, elem.packet)
|
||||
}
|
||||
|
@ -605,10 +600,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err := peer.SendBuffers(bufs)
|
||||
err := peer.SendAndCountBuffers(bufs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
|
@ -9,7 +9,7 @@
|
|||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
* So this code remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -20,8 +20,8 @@ import (
|
|||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
|
@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
|
|||
return netlinkCancel, nil
|
||||
}
|
||||
|
||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *conn.Endpoint
|
||||
|
@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
|||
if !ok {
|
||||
break
|
||||
}
|
||||
pePtr.peer.Lock()
|
||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||
pePtr.peer.Unlock()
|
||||
pePtr.peer.endpoint.Lock()
|
||||
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
|
@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
|||
device.peers.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.RLock()
|
||||
if peer.endpoint == nil {
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val == nil {
|
||||
peer.endpoint.Unlock()
|
||||
continue
|
||||
}
|
||||
nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
|
||||
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||
if nativeEP == nil {
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Unlock()
|
||||
continue
|
||||
}
|
||||
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
nlmsg := struct {
|
||||
|
@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
|||
reqPeerLock.Lock()
|
||||
reqPeer[i] = peerEndpointPtr{
|
||||
peer: peer,
|
||||
endpoint: &peer.endpoint,
|
||||
endpoint: &peer.endpoint.val,
|
||||
}
|
||||
reqPeerLock.Unlock()
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Unlock()
|
||||
i++
|
||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
|
@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
peer.markEndpointSrcForClearing()
|
||||
|
||||
peer.SendHandshakeInitiation(true)
|
||||
}
|
||||
|
@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
|
|||
func expiredNewHandshake(peer *Peer) {
|
||||
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -8,7 +8,7 @@ package device
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
)
|
||||
|
||||
const DefaultMTU = 1420
|
||||
|
|
256
device/uapi.go
256
device/uapi.go
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -18,7 +18,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
)
|
||||
|
||||
type IPCError struct {
|
||||
|
@ -97,49 +98,66 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.isAdvancedSecurityOn() {
|
||||
if device.aSecCfg.junkPacketCount != 0 {
|
||||
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
||||
if device.isAWG() {
|
||||
if device.awg.ASecCfg.JunkPacketCount != 0 {
|
||||
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
|
||||
}
|
||||
if device.aSecCfg.junkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
||||
if device.awg.ASecCfg.JunkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
|
||||
}
|
||||
if device.aSecCfg.junkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
||||
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
|
||||
}
|
||||
if device.aSecCfg.initPacketJunkSize != 0 {
|
||||
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
||||
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
|
||||
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
|
||||
}
|
||||
if device.aSecCfg.responsePacketJunkSize != 0 {
|
||||
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
||||
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
|
||||
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
|
||||
}
|
||||
if device.aSecCfg.initPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
||||
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
|
||||
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
|
||||
}
|
||||
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
||||
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
|
||||
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
|
||||
}
|
||||
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
||||
if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
|
||||
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
|
||||
}
|
||||
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
||||
if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
|
||||
sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
|
||||
sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
|
||||
}
|
||||
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
|
||||
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
|
||||
}
|
||||
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields()
|
||||
for _, field := range controlledJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
}
|
||||
if device.awg.HandshakeHandler.ITimeout != 0 {
|
||||
sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
for _, peer := range device.peers.keyMap {
|
||||
// Serialize peer state.
|
||||
// Do the work in an anonymous function so that we can use defer.
|
||||
func() {
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
|
||||
peer.handshake.mutex.RLock()
|
||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||
peer.handshake.mutex.RUnlock()
|
||||
sendf("protocol_version=1")
|
||||
if peer.endpoint != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
nano := peer.lastHandshakeNano.Load()
|
||||
secs := nano / time.Second.Nanoseconds()
|
||||
|
@ -151,14 +169,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||
|
||||
device.allowedips.EntriesForPeer(
|
||||
peer,
|
||||
func(prefix netip.Prefix) bool {
|
||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||
sendf("allowed_ip=%s", prefix.String())
|
||||
return true
|
||||
},
|
||||
)
|
||||
}()
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -185,13 +199,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
tempASecCfg := aSecCfgType{}
|
||||
tempAwg := awg.Protocol{}
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
err := device.handlePostConfig(&tempASecCfg)
|
||||
err := device.handlePostConfig(&tempAwg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -222,7 +236,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value, &tempASecCfg)
|
||||
err = device.handleDeviceLine(key, value, &tempAwg)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
|
@ -230,7 +244,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
err = device.handlePostConfig(&tempASecCfg)
|
||||
err = device.handlePostConfig(&tempAwg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -242,7 +256,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
|
||||
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
|
@ -283,7 +297,11 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
|
||||
case "replace_peers":
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set replace_peers, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
@ -291,80 +309,138 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
|||
case "jc":
|
||||
junkPacketCount, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||
tempASecCfg.junkPacketCount = junkPacketCount
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "jmin":
|
||||
junkPacketMinSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "jmax":
|
||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s1":
|
||||
initPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s2":
|
||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s3":
|
||||
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
|
||||
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "s4":
|
||||
transportPacketJunkSize, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
|
||||
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "h1":
|
||||
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "h2":
|
||||
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "h3":
|
||||
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
|
||||
}
|
||||
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
|
||||
case "h4":
|
||||
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
|
||||
}
|
||||
tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempAwg.ASecCfg.IsSet = true
|
||||
case "i1", "i2", "i3", "i4", "i5":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
}
|
||||
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||
tempASecCfg.isSet = true
|
||||
|
||||
generators, err := awg.Parse(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
case "j1", "j2", "j3":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
generators, err := awg.Parse(key, value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
|
||||
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
case "itime":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty itime")
|
||||
return nil
|
||||
}
|
||||
|
||||
itime, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating itime")
|
||||
|
||||
tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
}
|
||||
|
@ -385,8 +461,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
|||
return
|
||||
}
|
||||
if peer.created {
|
||||
peer.disableRoaming = peer.device.net.brokenRoaming &&
|
||||
peer.endpoint != nil
|
||||
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||
}
|
||||
if peer.device.isUp() {
|
||||
peer.Start()
|
||||
|
@ -438,7 +513,11 @@ func (device *Device) handlePeerLine(
|
|||
case "update_only":
|
||||
// allow disabling of creation
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set update only, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
if peer.created && !peer.dummy {
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
|
@ -475,16 +554,20 @@ func (device *Device) handlePeerLine(
|
|||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||
}
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
peer.endpoint = endpoint
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
peer.endpoint.val = endpoint
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||
|
||||
secs, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to set persistent keepalive interval: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||
|
@ -495,7 +578,11 @@ func (device *Device) handlePeerLine(
|
|||
case "replace_allowed_ips":
|
||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"failed to replace allowedips, invalid value: %v",
|
||||
value,
|
||||
)
|
||||
}
|
||||
if peer.dummy {
|
||||
return nil
|
||||
|
@ -503,7 +590,14 @@ func (device *Device) handlePeerLine(
|
|||
device.allowedips.RemoveByPeer(peer.Peer)
|
||||
|
||||
case "allowed_ip":
|
||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||
add := true
|
||||
verb := "Adding"
|
||||
if len(value) > 0 && value[0] == '-' {
|
||||
add = false
|
||||
verb = "Removing"
|
||||
value = value[1:]
|
||||
}
|
||||
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
|
@ -511,7 +605,11 @@ func (device *Device) handlePeerLine(
|
|||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
if add {
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
} else {
|
||||
device.allowedips.Remove(prefix, peer.Peer)
|
||||
}
|
||||
|
||||
case "protocol_version":
|
||||
if value != "1" {
|
||||
|
@ -563,7 +661,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||
return
|
||||
}
|
||||
if nextByte != '\n' {
|
||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
||||
err = ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"trailing character in UAPI get: %q",
|
||||
nextByte,
|
||||
)
|
||||
break
|
||||
}
|
||||
err = device.IpcGetOperation(buffered.Writer)
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func appendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk, err := randomJunkWithSize(size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create header junk: %v", err)
|
||||
}
|
||||
_, err = writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomJunkWithSize(size int) ([]byte, error) {
|
||||
junk := make([]byte, size)
|
||||
_, err := crand.Read(junk)
|
||||
return junk, err
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_randomJunktWithSize(t *testing.T) {
|
||||
junk, err := randomJunkWithSize(30)
|
||||
fmt.Println(string(junk), len(junk), err)
|
||||
}
|
||||
|
||||
func Test_appendJunk(t *testing.T) {
|
||||
t.Run("", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := appendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Errorf("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
package main
|
||||
|
||||
|
|
22
go.mod
22
go.mod
|
@ -1,17 +1,23 @@
|
|||
module github.com/amnezia-vpn/amnezia-wg
|
||||
module github.com/amnezia-vpn/amneziawg-go
|
||||
|
||||
go 1.20
|
||||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tevino/abool v1.2.0
|
||||
github.com/tevino/abool/v2 v2.1.0
|
||||
golang.org/x/crypto v0.13.0
|
||||
golang.org/x/net v0.15.0
|
||||
golang.org/x/sys v0.12.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/net v0.41.0
|
||||
golang.org/x/sys v0.33.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/time v0.9.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
48
go.sum
48
go.sum
|
@ -1,16 +1,40 @@
|
|||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
|
||||
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
|
||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
|
||||
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
|
||||
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
|
||||
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
|
||||
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
|
|
|
@ -20,7 +20,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -26,7 +26,7 @@ const (
|
|||
|
||||
// socketDirectory is variable because it is modified by a linker
|
||||
// flag in wireguard-android.
|
||||
var socketDirectory = "/var/run/wireguard"
|
||||
var socketDirectory = "/var/run/amneziawg"
|
||||
|
||||
func sockPath(iface string) string {
|
||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -8,7 +8,7 @@ package ipc
|
|||
import (
|
||||
"net"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
|
@ -62,7 +62,7 @@ func init() {
|
|||
func UAPIListen(name string) (net.Listener, error) {
|
||||
listener, err := (&namedpipe.ListenConfig{
|
||||
SecurityDescriptor: UAPISecurityDescriptor,
|
||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
|
||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
24
main.go
24
main.go
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -14,10 +14,10 @@ import (
|
|||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
@ -46,20 +46,20 @@ func warning() {
|
|||
return
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
|
||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
|
||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
|
||||
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
|
||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
|
||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
||||
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
|
||||
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
|
||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
||||
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ func main() {
|
|||
fmt.Sprintf("(%s) ", interfaceName),
|
||||
)
|
||||
|
||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create TUN device: %v", err)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -12,11 +12,11 @@ import (
|
|||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
|
||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,13 +30,13 @@ func main() {
|
|||
}
|
||||
interfaceName := os.Args[1]
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
|
||||
|
||||
logger := device.NewLogger(
|
||||
device.LogLevelVerbose,
|
||||
fmt.Sprintf("(%s) ", interfaceName),
|
||||
)
|
||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||
|
||||
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||
if err == nil {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package replay
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package rwcancel implements cancelable read/write operations on
|
||||
|
@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool {
|
|||
|
||||
func (rw *RWCancel) ReadyWrite() bool {
|
||||
closeFd := int32(rw.closingReader.Fd())
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}}
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(pollFds, -1)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
|
122
tun/checksum.go
122
tun/checksum.go
|
@ -1,102 +1,86 @@
|
|||
package tun
|
||||
|
||||
import "encoding/binary"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
|
||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||
ac := initial
|
||||
tmp := make([]byte, 8)
|
||||
binary.NativeEndian.PutUint64(tmp, initial)
|
||||
ac := binary.BigEndian.Uint64(tmp)
|
||||
var carry uint64
|
||||
|
||||
for len(b) >= 128 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
|
||||
ac += carry
|
||||
b = b[128:]
|
||||
}
|
||||
if len(b) >= 64 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac += carry
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac += carry
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac += carry
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac += carry
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
ac += uint64(binary.BigEndian.Uint32(b))
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
|
||||
ac += carry
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
ac += uint64(binary.BigEndian.Uint16(b))
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
|
||||
ac += carry
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
ac += uint64(b[0]) << 8
|
||||
tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
|
||||
ac, carry = bits.Add64(ac, uint64(tmp), 0)
|
||||
ac += carry
|
||||
}
|
||||
|
||||
return ac
|
||||
binary.NativeEndian.PutUint64(tmp, ac)
|
||||
return binary.BigEndian.Uint64(tmp)
|
||||
}
|
||||
|
||||
func checksum(b []byte, initial uint64) uint16 {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue