Compare commits

...

43 commits
show ... master

Author SHA1 Message Date
pokamest
27e661d68e
Merge pull request #70 from marko1777/junk-improvements
Junk improvements
2025-04-07 15:31:41 +01:00
Mark Puha
71be0eb3a6 faster and more secure junk creation 2025-03-18 08:34:23 +01:00
pokamest
e3f1273f8a
Merge pull request #64 from drkivi/master
Patch for golang crypto and net submodules
2025-02-18 11:50:35 +00:00
drkivi
c97b5b7615
Update go.sum
Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:58 +03:30
drkivi
668ddfd455
Update go.mod
Submodules Version Up

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:44:17 +03:30
drkivi
b8da08c106
Update Dockerfile
golang -> 1.23.6
AWGTOOLS_RELEASE -> 1.0.20241018

Signed-off-by: drkivi <115035277+drkivi@users.noreply.github.com>
2025-02-10 21:43:02 +03:30
Iurii Egorov
2e3f7d122c Update Go version in Dockerfile 2024-07-01 13:47:44 +03:00
Iurii Egorov
2e7780471a
Remove GetOffloadInfo() (#32)
* Remove GetOffloadInfo()
* Remove GetOffloadInfo() from bind_windows as well
* Allow lightweight tags to be used in the version
2024-05-24 16:18:23 +01:00
albexk
87d8c00f86 Up go to 1.22.3, up crypto to 0.21.0 2024-05-21 08:09:58 -07:00
albexk
c00bda9200 Fix output of the version command 2024-05-14 03:51:01 -07:00
albexk
d2b0fc9789 Add resetting of message types when closing the device 2024-05-14 03:51:01 -07:00
albexk
77d39ff3b9 Minor naming changes 2024-05-14 03:51:01 -07:00
albexk
e433d13df6 Add disabling UDP GSO when an error occurs due to inconsistent peer mtu 2024-05-14 03:51:01 -07:00
RomikB
3ddf952973 unsafe rebranding: change pipe name 2024-05-13 11:10:42 -07:00
albexk
3f0a3bcfa0 Fix wg reconnection problem after awg connection 2024-03-16 14:16:13 +00:00
AlexanderGalkov
4dddf62e57 Update Dockerfile
add wg and wg-quick symlinks

Signed-off-by: AlexanderGalkov <143902290+AlexanderGalkov@users.noreply.github.com>
2024-02-20 20:32:38 +07:00
tiaga
827ec6e14b
Merge pull request #21 from amnezia-vpn/fix-dockerfile
Fix Dockerfile
2024-02-13 21:47:55 +07:00
tiaga
92e28a0d14 Fix Dockerfile
Fix AmneziaWG tools installation.
2024-02-13 21:44:41 +07:00
tiaga
52fed4d362
Merge pull request #20 from amnezia-vpn/update_dockerfile
Update Dockerfile
2024-02-13 21:28:17 +07:00
tiaga
9c6b3ff332 Update Dockerfile
- rename `wg` and `wg-quick` to `awg` and `awg-quick` accordingly
- add iptables
- update AmneziaWG tools version
2024-02-13 21:27:34 +07:00
pokamest
7de7a9a754
Merge pull request #19 from amnezia-vpn/fix/go_sum
Fix go.sum
2024-02-12 05:31:57 -08:00
albexk
0c347529b8 Fix go.sum 2024-02-12 16:27:56 +03:00
albexk
6705978fc8 Add debug udp offload info 2024-02-12 04:40:23 -08:00
albexk
032e33f577 Fix Android UDP GRO check 2024-02-12 04:40:23 -08:00
albexk
59101fd202 Bump crypto, net, sys modules to the latest versions 2024-02-12 04:40:23 -08:00
tiaga
8bcfbac230
Merge pull request #17 from amnezia-vpn/fix-pipeline
Fix pipeline
2024-02-07 19:19:11 +07:00
tiaga
f0dfb5eacc Fix pipeline
Fix path to GitHub Actions workflow.
2024-02-07 18:53:55 +07:00
tiaga
9195025d8f
Merge pull request #16 from amnezia-vpn/pipeline
Add pipeline
2024-02-07 18:48:12 +07:00
tiaga
cbd414dfec Add pipeline
Build and push Docker image on a tag push.
2024-02-07 18:44:59 +07:00
tiaga
7155d20913
Merge pull request #14 from amnezia-vpn/docker
Update Dockerfile
2024-02-02 23:40:39 +07:00
tiaga
bfeb3954f6 Update Dockerfile
- update Alpine version
- improve `Dockerfile` to use pre-built AmneziaWG tools
2024-02-02 22:56:00 +07:00
Iurii Egorov
e3c9ec8012 Naming unify 2024-01-20 23:18:10 +03:00
pokamest
ce9d3866a3
Merge pull request #10 from amnezia-vpn/upstream-merge
Merge upstream changes
2024-01-14 13:37:00 -05:00
Iurii Egorov
e5f355e843 Fix incorrect configuration handling for zero-valued Jc 2024-01-14 18:22:02 +03:00
Iurii Egorov
c05b2ee2a3 Merge remote-tracking branch 'upstream/master' into upstream-merge 2024-01-09 21:34:30 +03:00
tiaga
180c9284f3
Merge pull request #7 from amnezia-vpn/docker
Add Dockerfile
2023-12-19 18:50:48 +07:00
tiaga
015e11875d Add Dockerfile
Build Docker image with the corresponding wg-tools version.
2023-12-19 18:46:04 +07:00
Martin Basovnik
12269c2761 device: fix possible deadlock in close method
There is a possible deadlock in `device.Close()` when you try to close
the device very soon after its start. The problem is that two different
methods acquire the same locks in different order:

1. device.Close()
 - device.ipcMutex.Lock()
 - device.state.Lock()

2. device.changeState(deviceState)
 - device.state.Lock()
 - device.ipcMutex.Lock()

Reproducer:

    func TestDevice_deadlock(t *testing.T) {
    	d := randDevice(t)
    	d.Close()
    }

Problem:

    $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex
    sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130518)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6
    golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c
    testing.tRunner(0xc00014c000, 0x131d7b0)
    --
    sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?)
            /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25
    sync.(*Mutex).lockSlow(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213
    sync.(*Mutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55
    sync.(*RWMutex).Lock(0xc000130750)
            /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45
    golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500)
            /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72
    golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1)

Signed-off-by: Martin Basovnik <martin.basovnik@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:38:47 +01:00
Jason A. Donenfeld
542e565baa device: do atomic 64-bit add outside of vector loop
Only bother updating the rxBytes counter once we've processed a whole
vector, since additions are atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:35:57 +01:00
Jordan Whited
7c20311b3d device: reduce redundant per-packet overhead in RX path
Peer.RoutineSequentialReceiver() deals with packet vectors and does not
need to perform timer and endpoint operations for every packet in a
given vector. Changing these per-packet operations to per-vector
improves throughput by as much as 10% in some environments.

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
4ffa9c2032 device: change Peer.endpoint locking to reduce contention
Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:34:09 +01:00
Jordan Whited
d0bc03c707 tun: implement UDP GSO/GRO for Linux
Implement UDP GSO and GRO for the Linux tun.Device, which is made
possible by virtio extensions in the kernel's TUN driver starting in
v6.2.

secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is
used to demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly ~13us of round trip
latency between them. secnetperf was invoked with the following command
line options:
-stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0

The first result is from commit 2e0774f without UDP GSO/GRO on the TUN.

[conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us
SendTotalPackets=55859 SendSuspectedLostPackets=61
SendSpuriousLostPackets=59 SendCongestionCount=27
SendEcnCongestionCount=0 RecvTotalPackets=2779122
RecvReorderedPackets=0 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms).

The second result is with UDP GSO/GRO on the TUN.

[conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us
SendTotalPackets=165033 SendSuspectedLostPackets=64
SendSpuriousLostPackets=61 SendCongestionCount=53
SendEcnCongestionCount=0 RecvTotalPackets=11845268
RecvReorderedPackets=25267 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:27:22 +01:00
Jordan Whited
1cf89f5339 tun: fix Device.Read() buf length assumption on Windows
The length of a packet read from the underlying TUN device may exceed
the length of a supplied buffer when MTU exceeds device.MaxMessageSize.

Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-12-11 16:20:49 +01:00
52 changed files with 1764 additions and 897 deletions

41
.github/workflows/build-if-tag.yml vendored Normal file
View file

@ -0,0 +1,41 @@
name: build-if-tag
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
env:
APP: amneziawg-go
jobs:
build:
runs-on: ubuntu-latest
name: build
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup metadata
uses: docker/metadata-action@v5
id: metadata
with:
images: amneziavpn/${{ env.APP }}
tags: type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.metadata.outputs.tags }}

2
.gitignore vendored
View file

@ -1 +1 @@
wireguard-go
amneziawg-go

17
Dockerfile Normal file
View file

@ -0,0 +1,17 @@
FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
go mod verify && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
unzip -j alpine-3.19-amneziawg-tools.zip && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

View file

@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true
@$(MAKE) wireguard-go
@$(MAKE) amneziawg-go
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
go build -v -o "$@"
install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
test:
go test ./...
clean:
rm -f wireguard-go
rm -f amneziawg-go
.PHONY: all clean test install generate-version-and-build

View file

@ -11,17 +11,17 @@ As a result, AmneziaWG maintains high performance while adding an extra layer of
Simply run:
```
$ amnezia-wg wg0
$ amneziawg-go wg0
```
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
To run amnezia-wg without forking to the background, pass `-f` or `--foreground`:
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
```
$ amnezia-wg -f wg0
$ amneziawg-go -f wg0
```
When an interface is running, you may use [`amnezia-wg-tools `](https://github.com/amnezia-vpn/amnezia-wg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@ -34,11 +34,11 @@ This will run on Linux; you should run amnezia-wg instead of using default linux
### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
This runs on MacOS, you should use it from [awg-apple](https://github.com/amnezia-vpn/awg-apple)
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
### Windows
This runs on Windows, you should use it from [awg-windows](https://github.com/amnezia-vpn/awg-windows), which uses this as a module.
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
## Building
@ -46,7 +46,7 @@ This runs on Windows, you should use it from [awg-windows](https://github.com/am
This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://github.com/amnezia-vpn/amnezia-wg
$ cd amnezia-wg
$ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd amneziawg-go
$ make
```

View file

@ -331,7 +331,7 @@ type ErrUDPGSODisabled struct {
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {

View file

@ -17,7 +17,7 @@ import (
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amnezia-wg/conn/winrio"
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
)
const (

View file

@ -12,7 +12,7 @@ import (
"net/netip"
"os"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type ChannelBind struct {

View file

@ -57,13 +57,5 @@ func init() {
}
return err
},
// Attempt to enable UDP_GRO
func(network, address string, c syscall.RawConn) error {
c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
})
return nil
},
)
}

View file

@ -20,7 +20,9 @@ func errShouldDisableUDPGSO(err error) bool {
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
// It occurs when the peer mtu + wg headers is greater than path mtu.
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
}
return false
}

View file

@ -19,8 +19,10 @@ func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
rxOffload = errSyscall == nil && opt == 1
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
// use setsockopt workaround
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
rxOffload = errSyscall == nil
})
if err != nil {
return false, false

View file

@ -8,7 +8,7 @@ package device
import (
"errors"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type DummyDatagram struct {

View file

@ -11,11 +11,11 @@ import (
"sync/atomic"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/ipc"
"github.com/amnezia-vpn/amnezia-wg/ratelimiter"
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
)
@ -95,6 +95,7 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
}
type aSecCfgType struct {
@ -388,10 +389,10 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.state.Lock()
defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
@ -415,6 +416,8 @@ func (device *Device) Close() {
device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed")
close(device.closed)
}
@ -481,11 +484,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@ -535,11 +534,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@ -566,6 +561,14 @@ func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
@ -797,6 +800,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err

View file

@ -20,10 +20,10 @@ import (
"testing"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/conn/bindtest"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@ -237,7 +237,7 @@ func genTestPair(
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
p.dev = NewDevice(p.tun.TUN(),binds[i],NewLogger(level, fmt.Sprintf("dev%d: ", i)))
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
@ -274,7 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
})
}
func TestTwoDevicePingASecurity(t *testing.T) {
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
@ -294,7 +294,7 @@ func TestUpDown(t *testing.T) {
pair := genTestPair(t, false, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n",hex.EncodeToString(k[:])))
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
}
}
var wg sync.WaitGroup
@ -513,7 +513,7 @@ func (b *fakeBindSized) Open(
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error {return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
@ -527,7 +527,9 @@ type fakeTUNDeviceSized struct {
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }

69
device/junk_creator.go Normal file
View file

@ -0,0 +1,69 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

124
device/junk_creator_test.go Normal file
View file

@ -0,0 +1,124 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets()
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -11,7 +11,7 @@ import (
"sync/atomic"
"time"
"github.com/amnezia-vpn/amnezia-wg/replay"
"github.com/amnezia-vpn/amneziawg-go/replay"
)
/* Due to limitations in Go and /x/crypto there is currently

View file

@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
peer.endpoint.Lock()
peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.endpoint.Unlock()
}
device.peers.RUnlock()
}

View file

@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"github.com/amnezia-vpn/amnezia-wg/tai64n"
"github.com/amnezia-vpn/amneziawg-go/tai64n"
)
type handshakeState int

View file

@ -10,8 +10,8 @@ import (
"encoding/binary"
"testing"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {

View file

@ -12,22 +12,25 @@ import (
"sync/atomic"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type Peer struct {
isRunning atomic.Bool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
endpoint struct {
sync.Mutex
val conn.Endpoint
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
disableRoaming bool
}
timers struct {
retransmitHandshake *Timer
@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers
peer.timersInit()
@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return nil
}
peer.RLock()
defer peer.RUnlock()
if peer.endpoint == nil {
peer.endpoint.Lock()
endpoint := peer.endpoint.val
if endpoint == nil {
peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, peer.endpoint)
err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
@ -267,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if peer.disableRoaming {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return
}
peer.Lock()
peer.endpoint = endpoint
peer.Unlock()
peer.endpoint.clearSrcOnTx = false
peer.endpoint.val = endpoint
}
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
}

View file

@ -5,7 +5,7 @@
package device
import "github.com/amnezia-vpn/amnezia-wg/conn"
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */

View file

@ -7,7 +7,7 @@
package device
import "github.com/amnezia-vpn/amnezia-wg/conn"
import "github.com/amnezia-vpn/amneziawg-go/conn"
const (
QueueStagedSize = conn.IdealBatchSize

View file

@ -13,7 +13,7 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -145,7 +145,7 @@ func (device *Device) RoutineReceiveIncoming(
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
@ -476,7 +476,10 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
return
}
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
validTailPacket := -1
dataPacketReceived := false
rxBytesLen := uint64(0)
for i, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
@ -486,21 +489,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
continue
}
peer.SetEndpointFromPacket(elem.endpoint)
validTailPacket = i
if peer.ReceivedWithKeypair(elem.keypair) {
peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
peer.timersDataReceived()
dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
@ -549,6 +550,17 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
}
peer.rxBytes.Add(rxBytesLen)
if validTailPacket >= 0 {
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
}
if dataPacketReceived {
peer.timersDataReceived()
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {

View file

@ -9,14 +9,13 @@ import (
"bytes"
"encoding/binary"
"errors"
"math/rand"
"net"
"os"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -129,7 +128,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.createJunkPackets()
junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil {
@ -137,17 +136,20 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
@ -197,7 +199,7 @@ func (peer *Peer) SendHandshakeResponse() error {
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
@ -466,31 +468,6 @@ top:
}
}
func (peer *Peer) createJunkPackets() ([][]byte, error) {
if peer.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
packetSize := rand.Intn(
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
) + peer.device.aSecCfg.junkPacketMinSize
junk, err := randomJunkWithSize(packetSize)
if err != nil {
peer.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
func (peer *Peer) FlushStagedPackets() {
for {
select {

View file

@ -3,8 +3,8 @@
package device
import (
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {

View file

@ -20,8 +20,8 @@ import (
"golang.org/x/sys/unix"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.endpoint.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock()
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.endpoint.Unlock()
break
}
pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
pePtr.peer.Unlock()
pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil {
peer.RUnlock()
peer.endpoint.Lock()
if peer.endpoint.val == nil {
peer.endpoint.Unlock()
continue
}
nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.RUnlock()
peer.endpoint.Unlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock()
peer.endpoint.Unlock()
break
}
nlmsg := struct {
@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
peer.RUnlock()
peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {

View file

@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
}

View file

@ -8,7 +8,7 @@ package device
import (
"fmt"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
const DefaultMTU = 1420

View file

@ -18,7 +18,7 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amnezia-wg/ipc"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
type IPCError struct {
@ -129,17 +129,16 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
for _, peer := range device.peers.keyMap {
// Serialize peer state.
// Do the work in an anonymous function so that we can use defer.
func() {
peer.RLock()
defer peer.RUnlock()
peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
if peer.endpoint != nil {
sendf("endpoint=%s", peer.endpoint.DstToString())
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
@ -151,14 +150,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(
peer,
func(prefix netip.Prefix) bool {
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
},
)
}()
})
}
}()
@ -385,8 +380,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
peer.disableRoaming = peer.device.net.brokenRoaming &&
peer.endpoint != nil
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
@ -475,9 +469,9 @@ func (device *Device) handlePeerLine(
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.Lock()
defer peer.Unlock()
peer.endpoint = endpoint
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)

View file

@ -1,25 +0,0 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
)
func appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
func randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := crand.Read(junk)
return junk, err
}

View file

@ -1,27 +0,0 @@
package device
import (
"bytes"
"fmt"
"testing"
)
func Test_randomJunktWithSize(t *testing.T) {
junk, err := randomJunkWithSize(30)
fmt.Println(string(junk), len(junk), err)
}
func Test_appendJunk(t *testing.T) {
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

16
go.mod
View file

@ -1,17 +1,17 @@
module github.com/amnezia-vpn/amnezia-wg
module github.com/amnezia-vpn/amneziawg-go
go 1.20
go 1.24
require (
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.13.0
golang.org/x/net v0.15.0
golang.org/x/sys v0.12.0
golang.org/x/crypto v0.36.0
golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
github.com/google/btree v1.1.3 // indirect
golang.org/x/time v0.9.0 // indirect
)

28
go.sum
View file

@ -1,16 +1,20 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=

View file

@ -20,7 +20,7 @@ import (
"testing"
"time"
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)

View file

@ -9,7 +9,7 @@ import (
"net"
"os"
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix"
)

View file

@ -26,7 +26,7 @@ const (
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/wireguard"
var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)

View file

@ -8,7 +8,7 @@ package ipc
import (
"net"
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
@ -62,7 +62,7 @@ func init() {
func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
if err != nil {
return nil, err
}

22
main.go
View file

@ -14,10 +14,10 @@ import (
"runtime"
"strconv"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/device"
"github.com/amnezia-vpn/amnezia-wg/ipc"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/sys/unix"
)
@ -46,20 +46,20 @@ func warning() {
return
}
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
}
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
return
}
@ -145,7 +145,7 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting wireguard-go version %s", Version)
logger.Verbosef("Starting amneziawg-go version %s", Version)
if err != nil {
logger.Errorf("Failed to create TUN device: %v", err)

View file

@ -12,11 +12,11 @@ import (
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/device"
"github.com/amnezia-vpn/amnezia-wg/ipc"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
const (
@ -30,13 +30,13 @@ func main() {
}
interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
logger := device.NewLogger(
device.LogLevelVerbose,
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting wireguard-go version %s", Version)
logger.Verbosef("Starting amneziawg-go version %s", Version)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {

View file

@ -13,9 +13,9 @@ import (
"net/http"
"net/netip"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/device"
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
)
func main() {

View file

@ -14,9 +14,9 @@ import (
"net/http"
"net/netip"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/device"
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
)
func main() {

View file

@ -17,9 +17,9 @@ import (
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/device"
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
)
func main() {

View file

@ -22,7 +22,7 @@ import (
"syscall"
"time"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/buffer"

View file

@ -12,7 +12,7 @@ import (
"io"
"unsafe"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/sys/unix"
)
@ -57,22 +57,23 @@ const (
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
// tcpFlowKey represents the key for a TCP flow.
type tcpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
isV6 bool
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsByFlow map[tcpFlowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize),
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
}
for i := range t.itemsPool {
@ -81,14 +82,15 @@ func newTCPGROTable() *tcpGROTable {
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
key := tcpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
key.isV6 = addrSize == 16
return key
}
@ -96,7 +98,7 @@ func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
@ -108,7 +110,7 @@ func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, t
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
@ -131,7 +133,7 @@ func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
@ -140,7 +142,7 @@ func (t *tcpGROTable) deleteAt(key flowKey, i int) {
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
key tcpFlowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
@ -164,6 +166,103 @@ func (t *tcpGROTable) reset() {
}
}
// udpFlowKey represents the key for a UDP flow.
type udpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
isV6 bool
}
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
type udpGROTable struct {
itemsByFlow map[udpFlowKey][]udpGROItem
itemsPool [][]udpGROItem
}
func newUDPGROTable() *udpGROTable {
u := &udpGROTable{
itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
}
for i := range u.itemsPool {
u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
}
return u
}
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
key := udpFlowKey{}
addrSize := dstAddrOffset - srcAddrOffset
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
key.isV6 = addrSize == 16
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
items, ok := u.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
item := udpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
iphLen: uint8(udphOffset),
cSumKnownInvalid: cSumKnownInvalid,
}
items, ok := u.itemsByFlow[key]
if !ok {
items = u.newItems()
}
items = append(items, item)
u.itemsByFlow[key] = items
}
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
items, _ := u.itemsByFlow[item.key]
items[i] = item
}
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type udpGROItem struct {
key udpFlowKey
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}
func (u *udpGROTable) newItems() []udpGROItem {
var items []udpGROItem
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
return items
}
func (u *udpGROTable) reset() {
for k, items := range u.itemsByFlow {
items = items[:0]
u.itemsPool = append(u.itemsPool, items)
delete(u.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
@ -174,6 +273,61 @@ const (
coalesceAppend canCoalesce = 1
)
// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
// meet all requirements to be merged as part of a GRO operation, otherwise it
// returns false.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
if len(pktA) < 9 || len(pktB) < 9 {
return false
}
if pktA[0]>>4 == 6 {
if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
// cannot coalesce with unequal Traffic class values
return false
}
if pktA[7] != pktB[7] {
// cannot coalesce with unequal Hop limit values
return false
}
} else {
if pktA[1] != pktB[1] {
// cannot coalesce with unequal ToS values
return false
}
if pktA[6]>>5 != pktB[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return false
}
if pktA[8] != pktB[8] {
// cannot coalesce with unequal TTL values
return false
}
}
return true
}
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
// packets involved in the current GRO evaluation. bufsOffset is the offset at
// which packet data begins within bufs.
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
}
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
@ -189,30 +343,9 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
if !ipHeadersCanCoalesce(pkt, pktTarget) {
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
@ -252,16 +385,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
lenForPseudo := uint16(len(pkt) - int(iphLen))
cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
return ^checksum(pkt[iphLen:], cSum) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
@ -276,8 +409,36 @@ const (
coalesceSuccess
)
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
headersLen := item.iphLen + udphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalesceItemInvalidCSum
}
}
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
return coalescePktInvalidCSum
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
item.numMerged++
return coalesceSuccess
}
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
@ -297,11 +458,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
@ -319,11 +480,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
@ -354,52 +515,52 @@ const (
maxUint16 = 1<<16 - 1
)
type tcpGROResult int
type groResult int
const (
tcpGROResultNoop tcpGROResult = iota
tcpGROResultTableInsert
tcpGROResultCoalesced
groResultNoop groResult = iota
groResultTableInsert
groResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return tcpGROResultNoop
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return tcpGROResultNoop
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return tcpGROResultNoop
return groResultNoop
}
}
if len(pkt) < iphLen {
return tcpGROResultNoop
return groResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return tcpGROResultNoop
return groResultNoop
}
if len(pkt) < iphLen+tcphLen {
return tcpGROResultNoop
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return tcpGROResultNoop
return groResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
@ -407,14 +568,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return tcpGROResultNoop
return groResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return tcpGROResultNoop
return groResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
@ -425,7 +586,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return tcpGROResultNoop
return groResultTableInsert
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
@ -443,54 +604,25 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return tcpGROResultCoalesced
return groResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return tcpGROResultNoop
return groResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return tcpGROResultTableInsert
return groResultTableInsert
}
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
@ -505,7 +637,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if isV6 {
if item.key.isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
@ -525,7 +657,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
@ -546,54 +678,245 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
return nil
}
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + udphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 6,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
if item.key.isV6 {
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Recalculate the UDP len field value
binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
// Calculate the pseudo header checksum and place it at the UDP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the udp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
type groCandidateType uint8
const (
notGROCandidate groCandidateType = iota
tcp4GROCandidate
tcp6GROCandidate
udp4GROCandidate
udp6GROCandidate
)
func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
if len(b) < 28 {
return notGROCandidate
}
if b[0]>>4 == 4 {
if b[0]&0x0F != 5 {
// IPv4 packets w/IP options do not coalesce
return notGROCandidate
}
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
return tcp4GROCandidate
}
if b[9] == unix.IPPROTO_UDP && canUDPGRO {
return udp4GROCandidate
}
} else if b[0]>>4 == 6 {
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
return tcp6GROCandidate
}
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
return udp6GROCandidate
}
}
return notGROCandidate
}
const (
udphLen = 8
)
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return groResultNoop
}
}
if len(pkt) < iphLen {
return groResultNoop
}
if len(pkt) < iphLen+udphLen {
return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return groResultNoop
}
}
gsoSize := uint16(len(pkt) - udphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return groResultNoop
}
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
if !existing {
return groResultTableInsert
}
// With UDP we only check the last item, otherwise we could reorder packets
// for a given flow. We must also always insert a new item, or successfully
// coalesce with an existing item, for the same reason.
item := items[len(items)-1]
can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
var pktCSumKnownInvalid bool
if can == coalesceAppend {
result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, len(items)-1)
return groResultCoalesced
case coalesceItemInvalidCSum:
// If the existing item has an invalid csum we take no action. A new
// item will be stored after it, and the existing item will never be
// revisited as part of future coalescing candidacy checks.
case coalescePktInvalidCSum:
// We must insert a new item, but we also mark it as invalid csum
// to prevent a repeat checksum validation.
pktCSumKnownInvalid = true
default:
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
return groResultTableInsert
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
// supported.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result tcpGROResult
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
result = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
result = tcpGRO(bufs, offset, i, tcp6Table, true)
var result groResult
switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
case tcp4GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, false)
case tcp6GROCandidate:
result = tcpGRO(bufs, offset, i, tcpTable, true)
case udp4GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, false)
case udp6GROCandidate:
result = udpGRO(bufs, offset, i, udpTable, true)
}
switch result {
case tcpGROResultNoop:
case groResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case tcpGROResultTableInsert:
case groResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
return errors.Join(err4, err6)
errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
return errors.Join(errTCP, errUDP)
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// gsoSplit splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
if !isV6 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
var firstTCPSeqNum uint32
var protocol uint8
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
protocol = unix.IPPROTO_TCP
firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
} else {
protocol = unix.IPPROTO_UDP
}
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
@ -610,7 +933,7 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
if !isV6 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
@ -627,8 +950,11 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
// copy transport header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
if protocol == unix.IPPROTO_TCP {
// set TCP seq and adjust TCP flags
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
@ -636,16 +962,20 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
} else {
// set UDP header len
binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
// transport checksum
transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}

752
tun/offload_linux_test.go Normal file
View file

@ -0,0 +1,752 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"net/netip"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
offset = virtioNetHdrLen
)
var (
ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
totalLen := 28 + payloadLen
b := make([]byte, offset+int(totalLen), 65535)
ipv4H := header.IPv4(b[offset:])
srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4()
ipFields := &header.IPv4Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
Protocol: unix.IPPROTO_UDP,
TTL: 64,
TotalLength: uint16(totalLen),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv4H.Encode(ipFields)
udpH := header.UDP(b[offset+20:])
udpH.Encode(&header.UDPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
Length: uint16(payloadLen + udphLen),
})
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
return b
}
func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
}
func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
totalLen := 48 + payloadLen
b := make([]byte, offset+int(totalLen), 65535)
ipv6H := header.IPv6(b[offset:])
srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16()
ipFields := &header.IPv6Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
TransportProtocol: unix.IPPROTO_UDP,
HopLimit: 64,
PayloadLength: uint16(payloadLen + udphLen),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv6H.Encode(ipFields)
udpH := header.UDP(b[offset+40:])
udpH.Encode(&header.UDPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
Length: uint16(payloadLen + udphLen),
})
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
return b
}
func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
}
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
totalLen := 40 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv4H := header.IPv4(b[offset:])
srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4()
ipFields := &header.IPv4Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
Protocol: unix.IPPROTO_TCP,
TTL: 64,
TotalLength: uint16(totalLen),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv4H.Encode(ipFields)
tcpH := header.TCP(b[offset+20:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
totalLen := 60 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv6H := header.IPv6(b[offset:])
srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16()
ipFields := &header.IPv6Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
TransportProtocol: unix.IPPROTO_TCP,
HopLimit: 64,
PayloadLength: uint16(segmentSize + 20),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv6H.Encode(ipFields)
tcpH := header.TCP(b[offset+40:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func Test_handleVirtioRead(t *testing.T) {
tests := []struct {
name string
hdr virtioNetHdr
pktIn []byte
wantLens []int
wantErr bool
}{
{
"tcp4",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
gsoSize: 100,
hdrLen: 40,
csumStart: 20,
csumOffset: 16,
},
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{140, 140},
false,
},
{
"tcp6",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
gsoSize: 100,
hdrLen: 60,
csumStart: 40,
csumOffset: 16,
},
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{160, 160},
false,
},
{
"udp4",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
gsoSize: 100,
hdrLen: 28,
csumStart: 20,
csumOffset: 6,
},
udp4Packet(ip4PortA, ip4PortB, 200),
[]int{128, 128},
false,
},
{
"udp6",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
gsoSize: 100,
hdrLen: 48,
csumStart: 40,
csumOffset: 6,
},
udp6Packet(ip6PortA, ip6PortB, 200),
[]int{148, 148},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out := make([][]byte, conn.IdealBatchSize)
sizes := make([]int, conn.IdealBatchSize)
for i := range out {
out[i] = make([]byte, 65535)
}
tt.hdr.encode(tt.pktIn)
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if n != len(tt.wantLens) {
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
}
for i := range tt.wantLens {
if tt.wantLens[i] != sizes[i] {
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
}
}
})
}
}
func flipTCP4Checksum(b []byte) []byte {
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
b[at] ^= 0xFF
b[at+1] ^= 0xFF
return b
}
func flipUDP4Checksum(b []byte) []byte {
at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
b[at] ^= 0xFF
b[at+1] ^= 0xFF
return b
}
func Fuzz_handleGRO(f *testing.F) {
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
toWrite := make([]int, 0, len(pkts))
handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
if len(toWrite) > len(pkts) {
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
}
seenWriteI := make(map[int]bool)
for _, writeI := range toWrite {
if writeI < 0 || writeI > len(pkts)-1 {
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
}
if seenWriteI[writeI] {
t.Errorf("duplicate toWrite value: %d", writeI)
}
seenWriteI[writeI] = true
}
})
}
func Test_handleGRO(t *testing.T) {
tests := []struct {
name string
pktsIn [][]byte
canUDPGRO bool
wantToWrite []int
wantLens []int
wantErr bool
}{
{
"multiple protocols and flows",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
},
true,
[]int{0, 1, 2, 4, 5, 7, 9},
[]int{240, 228, 128, 140, 260, 160, 248},
false,
},
{
"multiple protocols and flows no UDP GRO",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
},
false,
[]int{0, 1, 2, 4, 5, 7, 8, 9, 10},
[]int{240, 128, 128, 140, 260, 160, 128, 148, 148},
false,
},
{
"PSH interleaved",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
},
true,
[]int{0, 2, 4, 6},
[]int{240, 240, 260, 260},
false,
},
{
"coalesceItemInvalidCSum",
[][]byte{
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
udp4Packet(ip4PortA, ip4PortB, 100),
udp4Packet(ip4PortA, ip4PortB, 100),
},
true,
[]int{0, 1, 3, 4},
[]int{140, 240, 128, 228},
false,
},
{
"out of order",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
},
true,
[]int{0},
[]int{340},
false,
},
{
"unequal TTL",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TTL++
}),
udp4Packet(ip4PortA, ip4PortB, 100),
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
fields.TTL++
}),
},
true,
[]int{0, 1, 2, 3},
[]int{140, 140, 128, 128},
false,
},
{
"unequal ToS",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TOS++
}),
udp4Packet(ip4PortA, ip4PortB, 100),
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
fields.TOS++
}),
},
true,
[]int{0, 1, 2, 3},
[]int{140, 140, 128, 128},
false,
},
{
"unequal flags more fragments set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 1
}),
udp4Packet(ip4PortA, ip4PortB, 100),
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
fields.Flags = 1
}),
},
true,
[]int{0, 1, 2, 3},
[]int{140, 140, 128, 128},
false,
},
{
"unequal flags DF set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 2
}),
udp4Packet(ip4PortA, ip4PortB, 100),
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
fields.Flags = 2
}),
},
true,
[]int{0, 1, 2, 3},
[]int{140, 140, 128, 128},
false,
},
{
"ipv6 unequal hop limit",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.HopLimit++
}),
udp6Packet(ip6PortA, ip6PortB, 100),
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
fields.HopLimit++
}),
},
true,
[]int{0, 1, 2, 3},
[]int{160, 160, 148, 148},
false,
},
{
"ipv6 unequal traffic class",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.TrafficClass++
}),
udp6Packet(ip6PortA, ip6PortB, 100),
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
fields.TrafficClass++
}),
},
true,
[]int{0, 1, 2, 3},
[]int{160, 160, 148, 148},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toWrite := make([]int, 0, len(tt.pktsIn))
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if len(toWrite) != len(tt.wantToWrite) {
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
}
for i, pktI := range tt.wantToWrite {
if tt.wantToWrite[i] != toWrite[i] {
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
}
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
}
}
})
}
}
func Test_packetIsGROCandidate(t *testing.T) {
tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
tcp4TooShort := tcp4[:39]
ip4InvalidHeaderLen := make([]byte, len(tcp4))
copy(ip4InvalidHeaderLen, tcp4)
ip4InvalidHeaderLen[0] = 0x46
ip4InvalidProtocol := make([]byte, len(tcp4))
copy(ip4InvalidProtocol, tcp4)
ip4InvalidProtocol[9] = unix.IPPROTO_GRE
tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
tcp6TooShort := tcp6[:59]
ip6InvalidProtocol := make([]byte, len(tcp6))
copy(ip6InvalidProtocol, tcp6)
ip6InvalidProtocol[6] = unix.IPPROTO_GRE
udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
udp4TooShort := udp4[:27]
udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
udp6TooShort := udp6[:47]
tests := []struct {
name string
b []byte
canUDPGRO bool
want groCandidateType
}{
{
"tcp4",
tcp4,
true,
tcp4GROCandidate,
},
{
"tcp6",
tcp6,
true,
tcp6GROCandidate,
},
{
"udp4",
udp4,
true,
udp4GROCandidate,
},
{
"udp4 no support",
udp4,
false,
notGROCandidate,
},
{
"udp6",
udp6,
true,
udp6GROCandidate,
},
{
"udp6 no support",
udp6,
false,
notGROCandidate,
},
{
"udp4 too short",
udp4TooShort,
true,
notGROCandidate,
},
{
"udp6 too short",
udp6TooShort,
true,
notGROCandidate,
},
{
"tcp4 too short",
tcp4TooShort,
true,
notGROCandidate,
},
{
"tcp6 too short",
tcp6TooShort,
true,
notGROCandidate,
},
{
"invalid IP version",
[]byte{0x00},
true,
notGROCandidate,
},
{
"invalid IP header len",
ip4InvalidHeaderLen,
true,
notGROCandidate,
},
{
"ip4 invalid protocol",
ip4InvalidProtocol,
true,
notGROCandidate,
},
{
"ip6 invalid protocol",
ip6InvalidProtocol,
true,
notGROCandidate,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
}
})
}
}
func Test_udpPacketsCanCoalesce(t *testing.T) {
udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
type args struct {
pkt []byte
iphLen uint8
gsoSize uint16
item udpGROItem
bufs [][]byte
bufsOffset int
}
tests := []struct {
name string
args args
want canCoalesce
}{
{
"coalesceAppend equal gso",
args{
pkt: udp4a[offset:],
iphLen: 20,
gsoSize: 100,
item: udpGROItem{
gsoSize: 100,
iphLen: 20,
},
bufs: [][]byte{
udp4a,
udp4b,
},
bufsOffset: offset,
},
coalesceAppend,
},
{
"coalesceAppend smaller gso",
args{
pkt: udp4a[offset : len(udp4a)-90],
iphLen: 20,
gsoSize: 10,
item: udpGROItem{
gsoSize: 100,
iphLen: 20,
},
bufs: [][]byte{
udp4a,
udp4b,
},
bufsOffset: offset,
},
coalesceAppend,
},
{
"coalesceUnavailable smaller gso previously appended",
args{
pkt: udp4a[offset:],
iphLen: 20,
gsoSize: 100,
item: udpGROItem{
gsoSize: 100,
iphLen: 20,
},
bufs: [][]byte{
udp4c,
udp4b,
},
bufsOffset: offset,
},
coalesceUnavailable,
},
{
"coalesceUnavailable larger following smaller",
args{
pkt: udp4c[offset:],
iphLen: 20,
gsoSize: 110,
item: udpGROItem{
gsoSize: 100,
iphLen: 20,
},
bufs: [][]byte{
udp4a,
udp4c,
},
bufsOffset: offset,
},
coalesceUnavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,411 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"net/netip"
"testing"
"github.com/amnezia-vpn/amnezia-wg/conn"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
offset = virtioNetHdrLen
)
var (
ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
totalLen := 40 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv4H := header.IPv4(b[offset:])
srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4()
ipFields := &header.IPv4Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
Protocol: unix.IPPROTO_TCP,
TTL: 64,
TotalLength: uint16(totalLen),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv4H.Encode(ipFields)
tcpH := header.TCP(b[offset+20:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
totalLen := 60 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv6H := header.IPv6(b[offset:])
srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16()
ipFields := &header.IPv6Fields{
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
TransportProtocol: unix.IPPROTO_TCP,
HopLimit: 64,
PayloadLength: uint16(segmentSize + 20),
}
if ipFn != nil {
ipFn(ipFields)
}
ipv6H.Encode(ipFields)
tcpH := header.TCP(b[offset+40:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func Test_handleVirtioRead(t *testing.T) {
tests := []struct {
name string
hdr virtioNetHdr
pktIn []byte
wantLens []int
wantErr bool
}{
{
"tcp4",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
gsoSize: 100,
hdrLen: 40,
csumStart: 20,
csumOffset: 16,
},
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{140, 140},
false,
},
{
"tcp6",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
gsoSize: 100,
hdrLen: 60,
csumStart: 40,
csumOffset: 16,
},
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{160, 160},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out := make([][]byte, conn.IdealBatchSize)
sizes := make([]int, conn.IdealBatchSize)
for i := range out {
out[i] = make([]byte, 65535)
}
tt.hdr.encode(tt.pktIn)
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if n != len(tt.wantLens) {
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
}
for i := range tt.wantLens {
if tt.wantLens[i] != sizes[i] {
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
}
}
})
}
}
func flipTCP4Checksum(b []byte) []byte {
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
b[at] ^= 0xFF
b[at+1] ^= 0xFF
return b
}
func Fuzz_handleGRO(f *testing.F) {
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
toWrite := make([]int, 0, len(pkts))
handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
if len(toWrite) > len(pkts) {
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
}
seenWriteI := make(map[int]bool)
for _, writeI := range toWrite {
if writeI < 0 || writeI > len(pkts)-1 {
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
}
if seenWriteI[writeI] {
t.Errorf("duplicate toWrite value: %d", writeI)
}
seenWriteI[writeI] = true
}
})
}
func Test_handleGRO(t *testing.T) {
tests := []struct {
name string
pktsIn [][]byte
wantToWrite []int
wantLens []int
wantErr bool
}{
{
"multiple flows",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
},
[]int{0, 2, 3, 5},
[]int{240, 140, 260, 160},
false,
},
{
"PSH interleaved",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
},
[]int{0, 2, 4, 6},
[]int{240, 240, 260, 260},
false,
},
{
"coalesceItemInvalidCSum",
[][]byte{
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
},
[]int{0, 1},
[]int{140, 240},
false,
},
{
"out of order",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
},
[]int{0},
[]int{340},
false,
},
{
"tcp4 unequal TTL",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TTL++
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal ToS",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TOS++
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal flags more fragments set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 1
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal flags DF set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 2
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp6 unequal hop limit",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.HopLimit++
}),
},
[]int{0, 1},
[]int{160, 160},
false,
},
{
"tcp6 unequal traffic class",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.TrafficClass++
}),
},
[]int{0, 1},
[]int{160, 160},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toWrite := make([]int, 0, len(tt.pktsIn))
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if len(toWrite) != len(tt.wantToWrite) {
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
}
for i, pktI := range tt.wantToWrite {
if tt.wantToWrite[i] != toWrite[i] {
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
}
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
}
}
})
}
}
func Test_isTCP4NoIPOptions(t *testing.T) {
valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
invalidLen := valid[:39]
invalidHeaderLen := make([]byte, len(valid))
copy(invalidHeaderLen, valid)
invalidHeaderLen[0] = 0x46
invalidProtocol := make([]byte, len(valid))
copy(invalidProtocol, valid)
invalidProtocol[9] = unix.IPPROTO_TCP + 1
tests := []struct {
name string
b []byte
want bool
}{
{
"valid",
valid,
true,
},
{
"invalid length",
invalidLen,
false,
},
{
"invalid version",
[]byte{0x00},
false,
},
{
"invalid header len",
invalidHeaderLen,
false,
},
{
"invalid protocol",
invalidProtocol,
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isTCP4NoIPOptions(tt.b); got != tt.want {
t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,8 +0,0 @@
go test fuzz v1
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
int(34)

View file

@ -1,8 +0,0 @@
go test fuzz v1
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
int(-48)

View file

@ -17,8 +17,8 @@ import (
"time"
"unsafe"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix"
)
@ -38,6 +38,7 @@ type NativeTun struct {
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
udpGSO bool
closeOnce sync.Once
@ -48,9 +49,10 @@ type NativeTun struct {
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
toWrite []int
tcp4GROTable, tcp6GROTable *tcpGROTable
tcpGROTable *tcpGROTable
udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) {
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
tun.tcp4GROTable.reset()
tun.tcp6GROTable.reset()
tun.tcpGROTable.reset()
tun.udpGROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
if err != nil {
return 0, err
}
@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the transport header length and add it onto
// csumStart, which is synonymous for IP header length.
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
hdr.hdrLen = hdr.csumStart + 8
} else {
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the TCP header length and add it onto
// csumStart, which is synonymous for IP header length.
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
}
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int {
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
func (tun *NativeTun) initFromFlags(name string) error {
@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error {
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
// tunTCPOffloads were added in Linux v2.6. We require their support
// if IFF_VNET_HDR is set.
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = conn.IdealBatchSize
// tunUDPOffloads were added in Linux v6.2. We do not return an
// error if they are unsupported at runtime.
tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
} else {
tun.batchSize = 1
}
@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
tcpGROTable: newTCPGROTable(),
udpGROTable: newUDPGROTable(),
toWrite: make([]int, 0, conn.IdealBatchSize),
}
@ -631,8 +644,8 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
tcpGROTable: newTCPGROTable(),
udpGROTable: newUDPGROTable(),
toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()

View file

@ -160,11 +160,10 @@ retry:
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
packetSize := len(packet)
copy(bufs[0][offset:], packet)
sizes[0] = packetSize
n := copy(bufs[0][offset:], packet)
sizes[0] = n
tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize))
tun.rate.update(uint64(n))
return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {

View file

@ -11,7 +11,7 @@ import (
"net/netip"
"os"
"github.com/amnezia-vpn/amnezia-wg/tun"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
func Ping(dst, src netip.Addr) []byte {